[BOOTDATA] Register the CriticalDeviceCoInstaller as a class co-installer for critica...
[reactos.git] / ntoskrnl / lpc / send.c
index a9aad6a..d6b2206 100644 (file)
@@ -179,17 +179,17 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
                         IN PPORT_MESSAGE LpcRequest,
                         OUT PPORT_MESSAGE LpcReply)
 {
-    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status = STATUS_SUCCESS;
-    PLPCP_MESSAGE Message;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PETHREAD Thread = PsGetCurrentThread();
+    PLPCP_PORT_OBJECT Port = (PLPCP_PORT_OBJECT)PortObject;
+    PLPCP_PORT_OBJECT QueuePort, ReplyPort, ConnectionPort = NULL;
+    USHORT MessageType;
+    PLPCP_MESSAGE Message;
     BOOLEAN Callback = FALSE;
     PKSEMAPHORE Semaphore;
-    USHORT MessageType;
-    PAGED_CODE();
 
-    Port = (PLPCP_PORT_OBJECT)PortObject;
+    PAGED_CODE();
 
     LPCTRACE(LPC_SEND_DEBUG,
              "Port: %p. Messages: %p/%p. Type: %lx\n",
@@ -205,33 +205,26 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
     MessageType = LpcpGetMessageType(LpcRequest);
     switch (MessageType)
     {
-        /* No type */
+        /* No type, assume LPC request */
         case 0:
-
-            /* Assume LPC request */
             MessageType = LPC_REQUEST;
             break;
 
         /* LPC request callback */
         case LPC_REQUEST:
-
-            /* This is a callback */
             Callback = TRUE;
             break;
 
-        /* Anything else */
+        /* Anything else, nothing to do */
         case LPC_CLIENT_DIED:
         case LPC_PORT_CLOSED:
         case LPC_EXCEPTION:
         case LPC_DEBUG_EVENT:
         case LPC_ERROR_EVENT:
-
-            /* Nothing to do */
             break;
 
+        /* Invalid message type */
         default:
-
-            /* Invalid message type */
             return STATUS_INVALID_PARAMETER;
     }
 
@@ -448,29 +441,50 @@ NtRequestPort(IN HANDLE PortHandle,
               IN PPORT_MESSAGE LpcRequest)
 {
     NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    PETHREAD Thread = PsGetCurrentThread();
+    PORT_MESSAGE CapturedLpcRequest;
     PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
     ULONG MessageType;
     PLPCP_MESSAGE Message;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
-    PETHREAD Thread = PsGetCurrentThread();
 
     PAGED_CODE();
-
     LPCTRACE(LPC_SEND_DEBUG,
              "Handle: %p. Message: %p. Type: %lx\n",
              PortHandle,
              LpcRequest,
              LpcpGetMessageType(LpcRequest));
 
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
+    {
+        _SEH2_TRY
+        {
+            /* Probe and capture the LpcRequest */
+            ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
+            CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        /* Access the LpcRequest directly */
+        CapturedLpcRequest = *LpcRequest;
+    }
+
     /* Get the message type */
-    MessageType = LpcRequest->u2.s2.Type | LPC_DATAGRAM;
+    MessageType = CapturedLpcRequest.u2.s2.Type | LPC_DATAGRAM;
 
     /* Can't have data information on this type of call */
-    if (LpcRequest->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
+    if (CapturedLpcRequest.u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
 
     /* Validate the length */
-    if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-         (ULONG)LpcRequest->u1.s1.TotalLength)
+    if (((ULONG)CapturedLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)CapturedLpcRequest.u1.s1.TotalLength)
     {
         /* Fail */
         return STATUS_INVALID_PARAMETER;
@@ -486,8 +500,8 @@ NtRequestPort(IN HANDLE PortHandle,
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate the message length */
-    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
-        ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
+    if (((ULONG)CapturedLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)CapturedLpcRequest.u1.s1.TotalLength <= (ULONG)CapturedLpcRequest.u1.s1.DataLength))
     {
         /* Fail */
         ObDereferenceObject(Port);
@@ -508,14 +522,14 @@ NtRequestPort(IN HANDLE PortHandle,
     {
         /* Copy it */
         LpcpMoveMessage(&Message->Request,
-                        LpcRequest,
+                        &CapturedLpcRequest,
                         LpcRequest + 1,
                         MessageType,
                         &Thread->Cid);
     }
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
-        /* Fail */
+        /* Cleanup and return the exception code */
         LpcpFreeToPortZone(Message, 0);
         ObDereferenceObject(Port);
         _SEH2_YIELD(return _SEH2_GetExceptionCode());
@@ -555,8 +569,7 @@ NtRequestPort(IN HANDLE PortHandle,
                 return STATUS_PORT_DISCONNECTED;
             }
         }
-        else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
-            LPCP_COMMUNICATION_PORT)
+        else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
         {
             /* Use the connection port for anything but communication ports */
             ConnectionPort = QueuePort = Port->ConnectionPort;
@@ -683,17 +696,18 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
                        IN PPORT_MESSAGE LpcRequest,
                        IN OUT PPORT_MESSAGE LpcReply)
 {
-    PORT_MESSAGE LocalLpcRequest;
+    NTSTATUS Status;
+    PORT_MESSAGE CapturedLpcRequest;
     ULONG NumberOfDataEntries;
     PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
-    NTSTATUS Status;
     PLPCP_MESSAGE Message;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PETHREAD Thread = PsGetCurrentThread();
     BOOLEAN Callback;
     PKSEMAPHORE Semaphore;
     ULONG MessageType;
     PLPCP_DATA_INFO DataInfo;
+
     PAGED_CODE();
     LPCTRACE(LPC_SEND_DEBUG,
              "Handle: %p. Messages: %p/%p. Type: %lx\n",
@@ -710,10 +724,9 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     {
         _SEH2_TRY
         {
-            /* Probe the full request message and copy the base structure */
+            /* Probe and capture the LpcRequest */
             ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
-            ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG));
-            LocalLpcRequest = *LpcRequest;
+            CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
 
             /* Probe the reply message for write */
             ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG));
@@ -723,19 +736,19 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             if (!NT_SUCCESS(Status))
             {
                 DPRINT1("LpcpVerifyMessageDataInfo failed\n");
-                return Status;
+                _SEH2_YIELD(return Status);
             }
         }
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
             DPRINT1("Got exception\n");
-            return _SEH2_GetExceptionCode();
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
         _SEH2_END;
     }
     else
     {
-        LocalLpcRequest = *LpcRequest;
+        CapturedLpcRequest = *LpcRequest;
         Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
         if (!NT_SUCCESS(Status))
         {
@@ -744,13 +757,16 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
         }
     }
 
+    /* This flag is undocumented. Remove it before continuing */
+    CapturedLpcRequest.u2.s2.Type &= ~0x4000;
+
     /* Check if this is an LPC Request */
-    if (LpcpGetMessageType(&LocalLpcRequest) == LPC_REQUEST)
+    if (LpcpGetMessageType(&CapturedLpcRequest) == LPC_REQUEST)
     {
         /* Then it's a callback */
         Callback = TRUE;
     }
-    else if (LpcpGetMessageType(&LocalLpcRequest))
+    else if (LpcpGetMessageType(&CapturedLpcRequest))
     {
         /* This is a not kernel-mode message */
         DPRINT1("Not a kernel-mode message!\n");
@@ -759,24 +775,24 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     else
     {
         /* This is a kernel-mode message without a callback */
-        LocalLpcRequest.u2.s2.Type |= LPC_REQUEST;
+        CapturedLpcRequest.u2.s2.Type |= LPC_REQUEST;
         Callback = FALSE;
     }
 
     /* Get the message type */
-    MessageType = LocalLpcRequest.u2.s2.Type;
+    MessageType = CapturedLpcRequest.u2.s2.Type;
 
     /* Due to the above probe, we know that TotalLength is positive */
-    NT_ASSERT(LocalLpcRequest.u1.s1.TotalLength >= 0);
+    ASSERT(CapturedLpcRequest.u1.s1.TotalLength >= 0);
 
     /* Validate the length */
-    if ((((ULONG)(USHORT)LocalLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-         (ULONG)LocalLpcRequest.u1.s1.TotalLength))
+    if ((((ULONG)(USHORT)CapturedLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)CapturedLpcRequest.u1.s1.TotalLength))
     {
         /* Fail */
         DPRINT1("Invalid message length: %u, %u\n",
-                LocalLpcRequest.u1.s1.DataLength,
-                LocalLpcRequest.u1.s1.TotalLength);
+                CapturedLpcRequest.u1.s1.DataLength,
+                CapturedLpcRequest.u1.s1.TotalLength);
         return STATUS_INVALID_PARAMETER;
     }
 
@@ -790,13 +806,13 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate the message length */
-    if (((ULONG)LocalLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
-        ((ULONG)LocalLpcRequest.u1.s1.TotalLength <= (ULONG)LocalLpcRequest.u1.s1.DataLength))
+    if (((ULONG)CapturedLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)CapturedLpcRequest.u1.s1.TotalLength <= (ULONG)CapturedLpcRequest.u1.s1.DataLength))
     {
         /* Fail */
         DPRINT1("Invalid message length: %u, %u\n",
-                LocalLpcRequest.u1.s1.DataLength,
-                LocalLpcRequest.u1.s1.TotalLength);
+                CapturedLpcRequest.u1.s1.DataLength,
+                CapturedLpcRequest.u1.s1.TotalLength);
         ObDereferenceObject(Port);
         return STATUS_PORT_MESSAGE_TOO_LONG;
     }
@@ -835,20 +851,20 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
                     ObDereferenceObject(Port);
                     DPRINT1("NumberOfEntries has changed: %u, %u\n",
                             DataInfo->NumberOfEntries, NumberOfDataEntries);
-                    return STATUS_INVALID_PARAMETER;
+                    _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
                 }
             }
 
             /* Copy it */
             LpcpMoveMessage(&Message->Request,
-                            LpcRequest,
+                            &CapturedLpcRequest,
                             LpcRequest + 1,
                             MessageType,
                             &Thread->Cid);
         }
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            /* Fail */
+            /* Cleanup and return the exception code */
             DPRINT1("Got exception!\n");
             LpcpFreeToPortZone(Message, 0);
             ObDereferenceObject(Port);