- Silence TCPIP.
[reactos.git] / reactos / ntoskrnl / lpc / send.c
index c9f97cf..8725e56 100644 (file)
@@ -1,21 +1,16 @@
-/* $Id: send.c,v 1.7 2001/12/01 20:57:23 phreak Exp $
- * 
+/* $Id$
+ *
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
  * FILE:            ntoskrnl/lpc/send.c
  * PURPOSE:         Communication mechanism
- * PROGRAMMER:      David Welch (welch@cwcom.net)
- * UPDATE HISTORY:
- *                  Created 22/05/98
+ *
+ * PROGRAMMERS:     David Welch (welch@cwcom.net)
  */
 
 /* INCLUDES *****************************************************************/
 
-#include <ddk/ntddk.h>
-#include <internal/ob.h>
-#include <internal/port.h>
-#include <internal/dbg.h>
-#include <internal/safe.h>
+#include <ntoskrnl.h>
 
 #define NDEBUG
 #include <internal/debug.h>
@@ -23,6 +18,7 @@
 
 /**********************************************************************
  * NAME
+ *     LpcSendTerminationPort/2
  *
  * DESCRIPTION
  *
  * RETURN VALUE
  *
  * REVISIONS
- *
  */
-NTSTATUS STDCALL 
+NTSTATUS STDCALL
 LpcSendTerminationPort (IN PEPORT Port,
-                       IN TIME CreationTime)
+                       IN LARGE_INTEGER CreateTime)
 {
   NTSTATUS Status;
-  LPC_TERMINATION_MESSAGE Msg;
-   
-  Msg.CreationTime = CreationTime;
-  Status = LpcRequestPort (Port, &Msg.Header);
+  CLIENT_DIED_MSG Msg;
+
+#ifdef __USE_NT_LPC__
+  Msg.h.u2.s2.Type = LPC_CLIENT_DIED;
+#endif
+  Msg.h.u1.s1.TotalLength = sizeof(Msg);
+  Msg.h.u1.s1.DataLength = sizeof(Msg) - sizeof(PORT_MESSAGE);
+  Msg.CreateTime = CreateTime;
+  Status = LpcRequestPort (Port, &Msg.h);
   return(Status);
 }
 
 
 /**********************************************************************
  * NAME
+ *     LpcSendDebugMessagePort/3
  *
  * DESCRIPTION
  *
@@ -56,19 +57,18 @@ LpcSendTerminationPort (IN PEPORT Port,
  * RETURN VALUE
  *
  * REVISIONS
- *
  */
-NTSTATUS STDCALL 
+NTSTATUS STDCALL
 LpcSendDebugMessagePort (IN PEPORT Port,
-                        IN PLPC_DBG_MESSAGE Message,
-                        OUT PLPC_DBG_MESSAGE Reply)
+                        IN PDBGKM_MSG Message,
+                        OUT PDBGKM_MSG Reply)
 {
    NTSTATUS Status;
    KIRQL oldIrql;
    PQUEUEDMESSAGE ReplyMessage;
-   
-   Status = EiReplyOrRequestPort(Port, 
-                                &Message->Header, 
+
+   Status = EiReplyOrRequestPort(Port,
+                                &Message->h,
                                 LPC_REQUEST,
                                 Port);
    if (!NT_SUCCESS(Status))
@@ -86,14 +86,14 @@ LpcSendDebugMessagePort (IN PEPORT Port,
                         UserMode,
                         FALSE,
                         NULL);
-   
+
    /*
     * Dequeue the reply
     */
    KeAcquireSpinLock(&Port->Lock, &oldIrql);
    ReplyMessage = EiDequeueMessagePort(Port);
    KeReleaseSpinLock(&Port->Lock, oldIrql);
-   memcpy(Reply, &ReplyMessage->Message, ReplyMessage->Message.MessageSize);
+   memcpy(Reply, &ReplyMessage->Message, ReplyMessage->Message.u1.s1.TotalLength);
    ExFreePool(ReplyMessage);
 
    return(STATUS_SUCCESS);
@@ -102,6 +102,7 @@ LpcSendDebugMessagePort (IN PEPORT Port,
 
 /**********************************************************************
  * NAME
+ *     LpcRequestPort/2
  *
  * DESCRIPTION
  *
@@ -110,17 +111,44 @@ LpcSendDebugMessagePort (IN PEPORT Port,
  * RETURN VALUE
  *
  * REVISIONS
+ *     2002-03-01 EA
+ *     I investigated this function a bit more in depth.
+ *     It looks like the legal values for the MessageType field in the
+ *     message to send are in the range LPC_NEW_MESSAGE .. LPC_CLIENT_DIED,
+ *     but LPC_DATAGRAM is explicitly forbidden.
  *
+ * @implemented
  */
 NTSTATUS STDCALL LpcRequestPort (IN    PEPORT          Port,
-                                IN     PLPC_MESSAGE    LpcMessage)
+                                IN     PPORT_MESSAGE   LpcMessage)
 {
    NTSTATUS Status;
-   
-   DPRINT("LpcRequestPort(PortHandle %x LpcMessage %x)\n", Port, LpcMessage);
-   
-   Status = EiReplyOrRequestPort(Port, 
-                                LpcMessage, 
+
+   DPRINT("LpcRequestPort(PortHandle %08x, LpcMessage %08x)\n", Port, LpcMessage);
+
+#ifdef __USE_NT_LPC__
+   /* Check the message's type */
+   if (LPC_NEW_MESSAGE == LpcMessage->u2.s2.Type)
+   {
+      LpcMessage->u2.s2.Type = LPC_DATAGRAM;
+   }
+   else if (LPC_DATAGRAM == LpcMessage->u2.s2.Type)
+   {
+      return STATUS_INVALID_PARAMETER;
+   }
+   else if (LpcMessage->u2.s2.Type > LPC_CLIENT_DIED)
+   {
+      return STATUS_INVALID_PARAMETER;
+   }
+   /* Check the range offset */
+   if (0 != LpcMessage->VirtualRangesOffset)
+   {
+      return STATUS_INVALID_PARAMETER;
+   }
+#endif
+
+   Status = EiReplyOrRequestPort(Port,
+                                LpcMessage,
                                 LPC_DATAGRAM,
                                 Port);
    KeReleaseSemaphore( &Port->Semaphore, IO_NO_INCREMENT, 1, FALSE );
@@ -131,6 +159,7 @@ NTSTATUS STDCALL LpcRequestPort (IN PEPORT          Port,
 
 /**********************************************************************
  * NAME
+ *     NtRequestPort/2
  *
  * DESCRIPTION
  *
@@ -140,19 +169,20 @@ NTSTATUS STDCALL LpcRequestPort (IN       PEPORT          Port,
  *
  * REVISIONS
  *
+ * @implemented
  */
 NTSTATUS STDCALL NtRequestPort (IN     HANDLE          PortHandle,
-                               IN      PLPC_MESSAGE    LpcMessage)
+                               IN      PPORT_MESSAGE   LpcMessage)
 {
    NTSTATUS Status;
    PEPORT Port;
-   
-   DPRINT("NtRequestPort(PortHandle %x LpcMessage %x)\n", PortHandle, 
+
+   DPRINT("NtRequestPort(PortHandle %x LpcMessage %x)\n", PortHandle,
          LpcMessage);
-   
+
    Status = ObReferenceObjectByHandle(PortHandle,
                                      PORT_ALL_ACCESS,
-                                     ExPortType,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
@@ -162,9 +192,9 @@ NTSTATUS STDCALL NtRequestPort (IN  HANDLE          PortHandle,
        return(Status);
      }
 
-   Status = LpcRequestPort(Port->OtherPort, 
+   Status = LpcRequestPort(Port->OtherPort,
                           LpcMessage);
-   
+
    ObDereferenceObject(Port);
    return(Status);
 }
@@ -172,6 +202,7 @@ NTSTATUS STDCALL NtRequestPort (IN  HANDLE          PortHandle,
 
 /**********************************************************************
  * NAME
+ *     NtRequestWaitReplyPort/3
  *
  * DESCRIPTION
  *
@@ -181,25 +212,59 @@ NTSTATUS STDCALL NtRequestPort (IN        HANDLE          PortHandle,
  *
  * REVISIONS
  *
+ * @implemented
  */
-NTSTATUS STDCALL 
+NTSTATUS STDCALL
 NtRequestWaitReplyPort (IN HANDLE PortHandle,
-                       PLPC_MESSAGE UnsafeLpcRequest,    
-                       PLPC_MESSAGE UnsafeLpcReply)
+                       PPORT_MESSAGE UnsafeLpcRequest,
+                       PPORT_MESSAGE UnsafeLpcReply)
 {
-   NTSTATUS Status;
+   PETHREAD CurrentThread;
+   struct _KPROCESS *AttachedProcess;
    PEPORT Port;
    PQUEUEDMESSAGE Message;
    KIRQL oldIrql;
-   PLPC_MESSAGE LpcRequest;
-   USHORT LpcRequestMessageSize;
+   PPORT_MESSAGE LpcRequest;
+   USHORT LpcRequestMessageSize = 0, LpcRequestDataSize = 0;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS Status = STATUS_SUCCESS;
+   
+   PreviousMode = ExGetPreviousMode();
+   
+   if (PreviousMode != KernelMode)
+     {
+       _SEH_TRY
+         {
+           ProbeForRead(UnsafeLpcRequest,
+                        sizeof(PORT_MESSAGE),
+                        1);
+           ProbeForWrite(UnsafeLpcReply,
+                         sizeof(PORT_MESSAGE),
+                         1);
+           LpcRequestMessageSize = UnsafeLpcRequest->u1.s1.TotalLength;
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           return Status;
+         }
+     }
+   else
+     {
+       LpcRequestMessageSize = UnsafeLpcRequest->u1.s1.TotalLength;
+     }
 
    DPRINT("NtRequestWaitReplyPort(PortHandle %x, LpcRequest %x, "
          "LpcReply %x)\n", PortHandle, UnsafeLpcRequest, UnsafeLpcReply);
 
    Status = ObReferenceObjectByHandle(PortHandle,
-                                     PORT_ALL_ACCESS, 
-                                     ExPortType,
+                                     PORT_ALL_ACCESS,
+                                     LpcPortObjectType,
                                      UserMode,
                                      (PVOID*)&Port,
                                      NULL);
@@ -208,91 +273,188 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
        return(Status);
      }
 
-   Status = MmCopyFromCaller(&LpcRequestMessageSize,
-                            &UnsafeLpcRequest->MessageSize,
-                            sizeof(USHORT));
-   if (!NT_SUCCESS(Status))
+   if (EPORT_DISCONNECTED == Port->State)
      {
-       ObDereferenceObject(Port);
-       return(Status);
+       ObDereferenceObject(Port);
+       return STATUS_PORT_DISCONNECTED;
+     }
+
+   /* win32k sometimes needs to KeAttach() the CSRSS process in order to make
+      the PortHandle valid. Now that we've got the EPORT structure from the
+      handle we can undo this, so everything is normal again. Need to
+      re-KeAttach() before returning though */
+   CurrentThread = PsGetCurrentThread();
+   if (&CurrentThread->ThreadsProcess->Pcb == CurrentThread->Tcb.ApcState.Process)
+     {
+       AttachedProcess = NULL;
      }
-   if (LpcRequestMessageSize > (sizeof(LPC_MESSAGE) + MAX_MESSAGE_DATA))
+   else
      {
+       AttachedProcess = CurrentThread->Tcb.ApcState.Process;
+       KeDetachProcess();
+     }
+
+   if (LpcRequestMessageSize > LPC_MAX_MESSAGE_LENGTH)
+     {
+       if (NULL != AttachedProcess)
+         {
+           KeAttachProcess(AttachedProcess);
+         }
        ObDereferenceObject(Port);
        return(STATUS_PORT_MESSAGE_TOO_LONG);
      }
    LpcRequest = ExAllocatePool(NonPagedPool, LpcRequestMessageSize);
    if (LpcRequest == NULL)
      {
+       if (NULL != AttachedProcess)
+         {
+           KeAttachProcess(AttachedProcess);
+         }
        ObDereferenceObject(Port);
        return(STATUS_NO_MEMORY);
      }
-   Status = MmCopyFromCaller(LpcRequest, UnsafeLpcRequest,
-                            LpcRequestMessageSize);
-   if (!NT_SUCCESS(Status))
+   if (PreviousMode != KernelMode)
      {
-       ExFreePool(LpcRequest);
-       ObDereferenceObject(Port);
-       return(Status);
+       _SEH_TRY
+         {
+           RtlCopyMemory(LpcRequest,
+                         UnsafeLpcRequest,
+                         LpcRequestMessageSize);
+           LpcRequestMessageSize = LpcRequest->u1.s1.TotalLength;
+           LpcRequestDataSize = LpcRequest->u1.s1.DataLength;
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           ExFreePool(LpcRequest);
+           if (NULL != AttachedProcess)
+             {
+               KeAttachProcess(AttachedProcess);
+             }
+           ObDereferenceObject(Port);
+           return(Status);
+         }
      }
-   LpcRequestMessageSize = LpcRequest->MessageSize;
-   if (LpcRequestMessageSize > (sizeof(LPC_MESSAGE) + MAX_MESSAGE_DATA))
+   else
+     {
+       RtlCopyMemory(LpcRequest,
+                     UnsafeLpcRequest,
+                     LpcRequestMessageSize);
+       LpcRequestMessageSize = LpcRequest->u1.s1.TotalLength;
+       LpcRequestDataSize = LpcRequest->u1.s1.DataLength;
+     }
+
+   if (LpcRequestMessageSize > LPC_MAX_MESSAGE_LENGTH)
      {
        ExFreePool(LpcRequest);
+       if (NULL != AttachedProcess)
+         {
+           KeAttachProcess(AttachedProcess);
+         }
        ObDereferenceObject(Port);
        return(STATUS_PORT_MESSAGE_TOO_LONG);
      }
-   if (LpcRequest->DataSize != (LpcRequest->MessageSize - sizeof(LPC_MESSAGE)))
+   if (LpcRequestDataSize > LPC_MAX_DATA_LENGTH)
      {
        ExFreePool(LpcRequest);
+       if (NULL != AttachedProcess)
+         {
+           KeAttachProcess(AttachedProcess);
+         }
        ObDereferenceObject(Port);
        return(STATUS_PORT_MESSAGE_TOO_LONG);
      }
 
-   Status = EiReplyOrRequestPort(Port->OtherPort, 
-                                LpcRequest, 
+   Status = EiReplyOrRequestPort(Port->OtherPort,
+                                LpcRequest,
                                 LPC_REQUEST,
                                 Port);
    if (!NT_SUCCESS(Status))
      {
-       DbgPrint("Enqueue failed\n");
+       DPRINT1("Enqueue failed\n");
        ExFreePool(LpcRequest);
+        if (NULL != AttachedProcess)
+          {
+            KeAttachProcess(AttachedProcess);
+          }
        ObDereferenceObject(Port);
        return(Status);
      }
    ExFreePool(LpcRequest);
-   KeReleaseSemaphore (&Port->OtherPort->Semaphore, IO_NO_INCREMENT, 
-                      1, FALSE);   
-   
+   KeReleaseSemaphore (&Port->OtherPort->Semaphore, IO_NO_INCREMENT,
+                      1, FALSE);
+
    /*
     * Wait for a reply
     */
-   KeWaitForSingleObject(&Port->Semaphore,
-                        UserRequest,
-                        UserMode,
-                        FALSE,
-                        NULL);
-   
-   /*
-    * Dequeue the reply
-    */
-   KeAcquireSpinLock(&Port->Lock, &oldIrql);
-   Message = EiDequeueMessagePort(Port);
-   KeReleaseSpinLock(&Port->Lock, oldIrql);
-   DPRINT("Message->Message.MessageSize %d\n",
-          Message->Message.MessageSize);
-   Status = MmCopyToCaller(UnsafeLpcReply, &Message->Message, 
-                          Message->Message.MessageSize);
-   ExFreePool(Message);
-   
+   Status = KeWaitForSingleObject(&Port->Semaphore,
+                                 UserRequest,
+                                 UserMode,
+                                 FALSE,
+                                 NULL);
+   if (Status == STATUS_SUCCESS)
+     {
+
+       /*
+        * Dequeue the reply
+        */
+       KeAcquireSpinLock(&Port->Lock, &oldIrql);
+       Message = EiDequeueMessagePort(Port);
+       KeReleaseSpinLock(&Port->Lock, oldIrql);
+       if (Message)
+         {
+           DPRINT("Message->Message.u1.s1.TotalLength %d\n",
+                 Message->Message.u1.s1.TotalLength);
+           if (PreviousMode != KernelMode)
+             {
+               _SEH_TRY
+                 {
+                   RtlCopyMemory(UnsafeLpcReply,
+                                 &Message->Message,
+                                 Message->Message.u1.s1.TotalLength);
+                 }
+               _SEH_HANDLE
+                 {
+                   Status = _SEH_GetExceptionCode();
+                 }
+               _SEH_END;
+             }
+           else
+             {
+               RtlCopyMemory(UnsafeLpcReply,
+                             &Message->Message,
+                             Message->Message.u1.s1.TotalLength);
+             }
+           ExFreePool(Message);
+         }
+       else
+         Status = STATUS_UNSUCCESSFUL;
+     }
+   else
+     {
+       if (NT_SUCCESS(Status))
+         {
+          Status = STATUS_UNSUCCESSFUL;
+        }
+     }
+   if (NULL != AttachedProcess)
+     {
+       KeAttachProcess(AttachedProcess);
+     }
    ObDereferenceObject(Port);
-   
+
    return(Status);
 }
 
 
 /**********************************************************************
  * NAME
+ *     NtWriteRequestData/6
  *
  * DESCRIPTION
  *
@@ -301,16 +463,16 @@ NtRequestWaitReplyPort (IN HANDLE PortHandle,
  * RETURN VALUE
  *
  * REVISIONS
- *
  */
 NTSTATUS STDCALL NtWriteRequestData (HANDLE            PortHandle,
-                                    PLPC_MESSAGE       Message,
+                                    PPORT_MESSAGE      Message,
                                     ULONG              Index,
                                     PVOID              Buffer,
                                     ULONG              BufferLength,
                                     PULONG             ReturnLength)
 {
    UNIMPLEMENTED;
+   return(STATUS_NOT_IMPLEMENTED);
 }