[BOOTDATA] Register the CriticalDeviceCoInstaller as a class co-installer for critica...
[reactos.git] / ntoskrnl / lpc / send.c
index 8bf0efa..d6b2206 100644 (file)
@@ -26,7 +26,10 @@ LpcRequestPort(IN PVOID PortObject,
     ULONG MessageType;
     PLPCP_MESSAGE Message;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    PETHREAD Thread = PsGetCurrentThread();
+
     PAGED_CODE();
+
     LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", Port, LpcMessage);
 
     /* Check if this is a non-datagram message */
@@ -59,7 +62,7 @@ LpcRequestPort(IN PVOID PortObject,
     /* Can't have data information on this type of call */
     if (LpcMessage->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
 
-    /* Validate message sizes */
+    /* Validate the message length */
     if (((ULONG)LpcMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
         ((ULONG)LpcMessage->u1.s1.TotalLength <= (ULONG)LpcMessage->u1.s1.DataLength))
     {
@@ -80,7 +83,7 @@ LpcRequestPort(IN PVOID PortObject,
                     LpcMessage,
                     LpcMessage + 1,
                     MessageType,
-                    &PsGetCurrentThread()->Cid);
+                    &Thread->Cid);
 
     /* Acquire the LPC lock */
     KeAcquireGuardedMutex(&LpcpLock);
@@ -101,7 +104,7 @@ LpcRequestPort(IN PVOID PortObject,
                 if (!ConnectionPort)
                 {
                     /* Fail */
-                    LpcpFreeToPortZone(Message, 3);
+                    LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                     return STATUS_PORT_DISCONNECTED;
                 }
             }
@@ -112,7 +115,7 @@ LpcRequestPort(IN PVOID PortObject,
                 if (!ConnectionPort)
                 {
                     /* Fail */
-                    LpcpFreeToPortZone(Message, 3);
+                    LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                     return STATUS_PORT_DISCONNECTED;
                 }
             }
@@ -131,17 +134,17 @@ LpcRequestPort(IN PVOID PortObject,
     if (QueuePort)
     {
         /* Generate the Message ID and set it */
-        Message->Request.MessageId =  LpcpNextMessageId++;
+        Message->Request.MessageId = LpcpNextMessageId++;
         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
         Message->Request.CallbackId = 0;
 
         /* No Message ID for the thread */
-        PsGetCurrentThread()->LpcReplyMessageId = 0;
+        Thread->LpcReplyMessageId = 0;
 
         /* Insert the message in our chain */
         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
 
-        /* Release the lock and release the semaphore */
+        /* Release the lock and the semaphore */
         KeEnterCriticalRegion();
         KeReleaseGuardedMutex(&LpcpLock);
         LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
@@ -153,15 +156,16 @@ LpcRequestPort(IN PVOID PortObject,
             KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
         }
 
-        /* We're done */
         KeLeaveCriticalRegion();
+
+        /* We're done */
         if (ConnectionPort) ObDereferenceObject(ConnectionPort);
         LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
         return STATUS_SUCCESS;
     }
 
     /* If we got here, then free the message and fail */
-    LpcpFreeToPortZone(Message, 3);
+    LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
     if (ConnectionPort) ObDereferenceObject(ConnectionPort);
     return STATUS_PORT_DISCONNECTED;
 }
@@ -175,17 +179,17 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
                         IN PPORT_MESSAGE LpcRequest,
                         OUT PPORT_MESSAGE LpcReply)
 {
-    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status = STATUS_SUCCESS;
-    PLPCP_MESSAGE Message;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PETHREAD Thread = PsGetCurrentThread();
+    PLPCP_PORT_OBJECT Port = (PLPCP_PORT_OBJECT)PortObject;
+    PLPCP_PORT_OBJECT QueuePort, ReplyPort, ConnectionPort = NULL;
+    USHORT MessageType;
+    PLPCP_MESSAGE Message;
     BOOLEAN Callback = FALSE;
     PKSEMAPHORE Semaphore;
-    USHORT MessageType;
-    PAGED_CODE();
 
-    Port = (PLPCP_PORT_OBJECT)PortObject;
+    PAGED_CODE();
 
     LPCTRACE(LPC_SEND_DEBUG,
              "Port: %p. Messages: %p/%p. Type: %lx\n",
@@ -201,36 +205,29 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
     MessageType = LpcpGetMessageType(LpcRequest);
     switch (MessageType)
     {
-        /* No type */
+        /* No type, assume LPC request */
         case 0:
-            
-            /* Assume LPC request */
             MessageType = LPC_REQUEST;
             break;
-        
+
         /* LPC request callback */
         case LPC_REQUEST:
-            
-            /* This is a callback */
             Callback = TRUE;
             break;
-        
-        /* Anything else */
+
+        /* Anything else, nothing to do */
         case LPC_CLIENT_DIED:
         case LPC_PORT_CLOSED:
         case LPC_EXCEPTION:
         case LPC_DEBUG_EVENT:
         case LPC_ERROR_EVENT:
-            
-            /* Nothing to do */
             break;
-            
+
+        /* Invalid message type */
         default:
-            
-            /* Invalid message type */
             return STATUS_INVALID_PARAMETER;
     }
-    
+
     /* Set the request type */
     LpcRequest->u2.s2.Type = MessageType;
 
@@ -281,7 +278,7 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
             if (!QueuePort)
             {
                 /* We have no connected port, fail */
-                LpcpFreeToPortZone(Message, 3);
+                LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                 return STATUS_PORT_DISCONNECTED;
             }
 
@@ -297,7 +294,7 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
                 if (!ConnectionPort)
                 {
                     /* Fail */
-                    LpcpFreeToPortZone(Message, 3);
+                    LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                     return STATUS_PORT_DISCONNECTED;
                 }
             }
@@ -309,7 +306,7 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
                 if (!ConnectionPort)
                 {
                     /* Fail */
-                    LpcpFreeToPortZone(Message, 3);
+                    LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                     return STATUS_PORT_DISCONNECTED;
                 }
             }
@@ -328,7 +325,7 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
         Message->SenderPort = Port;
 
         /* Generate the Message ID and set it */
-        Message->Request.MessageId =  LpcpNextMessageId++;
+        Message->Request.MessageId = LpcpNextMessageId++;
         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
         Message->Request.CallbackId = 0;
 
@@ -397,10 +394,10 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
                             (&Message->Request) + 1,
                             0,
                             NULL);
-            
+
             /* Acquire the lock */
             KeAcquireGuardedMutex(&LpcpLock);
-            
+
             /* Check if we replied to a thread */
             if (Message->RepliedToThread)
             {
@@ -409,9 +406,8 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
                 Message->RepliedToThread = NULL;
             }
 
-
             /* Free the message */
-            LpcpFreeToPortZone(Message, 3);
+            LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
         }
         else
         {
@@ -427,7 +423,7 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
 
     /* All done */
     LPCTRACE(LPC_SEND_DEBUG,
-             "Port: %p. Status: %p\n",
+             "Port: %p. Status: %d\n",
              Port,
              Status);
 
@@ -444,30 +440,51 @@ NTAPI
 NtRequestPort(IN HANDLE PortHandle,
               IN PPORT_MESSAGE LpcRequest)
 {
-    PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
-    PLPCP_MESSAGE Message;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PETHREAD Thread = PsGetCurrentThread();
-
-    PKSEMAPHORE Semaphore;
+    PORT_MESSAGE CapturedLpcRequest;
+    PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
     ULONG MessageType;
+    PLPCP_MESSAGE Message;
+
     PAGED_CODE();
     LPCTRACE(LPC_SEND_DEBUG,
-             "Handle: %lx. Message: %p. Type: %lx\n",
+             "Handle: %p. Message: %p. Type: %lx\n",
              PortHandle,
              LpcRequest,
              LpcpGetMessageType(LpcRequest));
 
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
+    {
+        _SEH2_TRY
+        {
+            /* Probe and capture the LpcRequest */
+            ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
+            CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        /* Access the LpcRequest directly */
+        CapturedLpcRequest = *LpcRequest;
+    }
+
     /* Get the message type */
-    MessageType = LpcRequest->u2.s2.Type | LPC_DATAGRAM;
+    MessageType = CapturedLpcRequest.u2.s2.Type | LPC_DATAGRAM;
 
     /* Can't have data information on this type of call */
-    if (LpcRequest->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
+    if (CapturedLpcRequest.u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
 
     /* Validate the length */
-    if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-         (ULONG)LpcRequest->u1.s1.TotalLength)
+    if (((ULONG)CapturedLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)CapturedLpcRequest.u1.s1.TotalLength)
     {
         /* Fail */
         return STATUS_INVALID_PARAMETER;
@@ -483,8 +500,8 @@ NtRequestPort(IN HANDLE PortHandle,
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate the message length */
-    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
-        ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
+    if (((ULONG)CapturedLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)CapturedLpcRequest.u1.s1.TotalLength <= (ULONG)CapturedLpcRequest.u1.s1.DataLength))
     {
         /* Fail */
         ObDereferenceObject(Port);
@@ -505,14 +522,14 @@ NtRequestPort(IN HANDLE PortHandle,
     {
         /* Copy it */
         LpcpMoveMessage(&Message->Request,
-            LpcRequest,
-            LpcRequest + 1,
-            MessageType,
-            &Thread->Cid);
+                        &CapturedLpcRequest,
+                        LpcRequest + 1,
+                        MessageType,
+                        &Thread->Cid);
     }
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
-        /* Fail */
+        /* Cleanup and return the exception code */
         LpcpFreeToPortZone(Message, 0);
         ObDereferenceObject(Port);
         _SEH2_YIELD(return _SEH2_GetExceptionCode());
@@ -533,7 +550,7 @@ NtRequestPort(IN HANDLE PortHandle,
         if (!QueuePort)
         {
             /* We have no connected port, fail */
-            LpcpFreeToPortZone(Message, 3);
+            LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
             ObDereferenceObject(Port);
             return STATUS_PORT_DISCONNECTED;
         }
@@ -547,20 +564,19 @@ NtRequestPort(IN HANDLE PortHandle,
             if (!ConnectionPort)
             {
                 /* Fail */
-                LpcpFreeToPortZone(Message, 3);
+                LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                 ObDereferenceObject(Port);
                 return STATUS_PORT_DISCONNECTED;
             }
         }
-        else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
-            LPCP_COMMUNICATION_PORT)
+        else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
         {
             /* Use the connection port for anything but communication ports */
             ConnectionPort = QueuePort = Port->ConnectionPort;
             if (!ConnectionPort)
             {
                 /* Fail */
-                LpcpFreeToPortZone(Message, 3);
+                LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                 ObDereferenceObject(Port);
                 return STATUS_PORT_DISCONNECTED;
             }
@@ -582,23 +598,20 @@ NtRequestPort(IN HANDLE PortHandle,
         Message->SenderPort = Port;
 
         /* Generate the Message ID and set it */
-        Message->Request.MessageId =  LpcpNextMessageId++;
+        Message->Request.MessageId = LpcpNextMessageId++;
         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
         Message->Request.CallbackId = 0;
 
         /* No Message ID for the thread */
-        PsGetCurrentThread()->LpcReplyMessageId = 0;
+        Thread->LpcReplyMessageId = 0;
 
         /* Insert the message in our chain */
         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
 
-        /* Release the lock and get the semaphore we'll use later */
+        /* Release the lock and the semaphore */
         KeEnterCriticalRegion();
         KeReleaseGuardedMutex(&LpcpLock);
-
-        /* Now release the semaphore */
-        Semaphore = QueuePort->MsgQueue.Semaphore;
-        LpcpCompleteWait(Semaphore);
+        LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
 
         /* If this is a waitable port, wake it up */
         if (QueuePort->Flags & LPCP_WAITABLE_PORT)
@@ -621,18 +634,59 @@ NtRequestPort(IN HANDLE PortHandle,
 
     /* All done with a failure*/
     LPCTRACE(LPC_SEND_DEBUG,
-             "Port: %p. Status: %p\n",
+             "Port: %p. Status: %d\n",
              Port,
              Status);
 
     /* The wait failed, free the message */
-    if (Message) LpcpFreeToPortZone(Message, 3);
+    if (Message) LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
 
     ObDereferenceObject(Port);
     if (ConnectionPort) ObDereferenceObject(ConnectionPort);
     return Status;
 }
 
+NTSTATUS
+NTAPI
+LpcpVerifyMessageDataInfo(
+    _In_ PPORT_MESSAGE Message,
+    _Out_ PULONG NumberOfDataEntries)
+{
+    PLPCP_DATA_INFO DataInfo;
+    PUCHAR EndOfEntries;
+
+    /* Check if we have no data info at all */
+    if (Message->u2.s2.DataInfoOffset == 0)
+    {
+        *NumberOfDataEntries = 0;
+        return STATUS_SUCCESS;
+    }
+
+    /* Make sure the data info structure is within the message */
+    if (((ULONG)Message->u1.s1.TotalLength <
+            sizeof(PORT_MESSAGE) + sizeof(LPCP_DATA_INFO)) ||
+        ((ULONG)Message->u2.s2.DataInfoOffset < sizeof(PORT_MESSAGE)) ||
+        ((ULONG)Message->u2.s2.DataInfoOffset >
+            ((ULONG)Message->u1.s1.TotalLength - sizeof(LPCP_DATA_INFO))))
+    {
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Get a pointer to the data info */
+    DataInfo = LpcpGetDataInfoFromMessage(Message);
+
+    /* Make sure the full data info with all entries is within the message */
+    EndOfEntries = (PUCHAR)&DataInfo->Entries[DataInfo->NumberOfEntries];
+    if ((EndOfEntries > ((PUCHAR)Message + (ULONG)Message->u1.s1.TotalLength)) ||
+        (EndOfEntries < (PUCHAR)Message))
+    {
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    *NumberOfDataEntries = DataInfo->NumberOfEntries;
+    return STATUS_SUCCESS;
+}
+
 /*
  * @implemented
  */
@@ -642,17 +696,21 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
                        IN PPORT_MESSAGE LpcRequest,
                        IN OUT PPORT_MESSAGE LpcReply)
 {
-    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
+    PORT_MESSAGE CapturedLpcRequest;
+    ULONG NumberOfDataEntries;
+    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
     PLPCP_MESSAGE Message;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PETHREAD Thread = PsGetCurrentThread();
     BOOLEAN Callback;
     PKSEMAPHORE Semaphore;
     ULONG MessageType;
+    PLPCP_DATA_INFO DataInfo;
+
     PAGED_CODE();
     LPCTRACE(LPC_SEND_DEBUG,
-             "Handle: %lx. Messages: %p/%p. Type: %lx\n",
+             "Handle: %p. Messages: %p/%p. Type: %lx\n",
              PortHandle,
              LpcRequest,
              LpcReply,
@@ -661,32 +719,80 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     /* Check if the thread is dying */
     if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;
 
+    /* Check for user mode access */
+    if (PreviousMode != KernelMode)
+    {
+        _SEH2_TRY
+        {
+            /* Probe and capture the LpcRequest */
+            ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
+            CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
+
+            /* Probe the reply message for write */
+            ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG));
+
+            /* Make sure the data entries in the request message are valid */
+            Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
+            if (!NT_SUCCESS(Status))
+            {
+                DPRINT1("LpcpVerifyMessageDataInfo failed\n");
+                _SEH2_YIELD(return Status);
+            }
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            DPRINT1("Got exception\n");
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        CapturedLpcRequest = *LpcRequest;
+        Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
+        if (!NT_SUCCESS(Status))
+        {
+            DPRINT1("LpcpVerifyMessageDataInfo failed\n");
+            return Status;
+        }
+    }
+
+    /* This flag is undocumented. Remove it before continuing */
+    CapturedLpcRequest.u2.s2.Type &= ~0x4000;
+
     /* Check if this is an LPC Request */
-    if (LpcpGetMessageType(LpcRequest) == LPC_REQUEST)
+    if (LpcpGetMessageType(&CapturedLpcRequest) == LPC_REQUEST)
     {
         /* Then it's a callback */
         Callback = TRUE;
     }
-    else if (LpcpGetMessageType(LpcRequest))
+    else if (LpcpGetMessageType(&CapturedLpcRequest))
     {
         /* This is a not kernel-mode message */
+        DPRINT1("Not a kernel-mode message!\n");
         return STATUS_INVALID_PARAMETER;
     }
     else
     {
         /* This is a kernel-mode message without a callback */
-        LpcRequest->u2.s2.Type |= LPC_REQUEST;
+        CapturedLpcRequest.u2.s2.Type |= LPC_REQUEST;
         Callback = FALSE;
     }
 
     /* Get the message type */
-    MessageType = LpcRequest->u2.s2.Type;
+    MessageType = CapturedLpcRequest.u2.s2.Type;
+
+    /* Due to the above probe, we know that TotalLength is positive */
+    ASSERT(CapturedLpcRequest.u1.s1.TotalLength >= 0);
 
     /* Validate the length */
-    if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-         (ULONG)LpcRequest->u1.s1.TotalLength)
+    if ((((ULONG)(USHORT)CapturedLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)CapturedLpcRequest.u1.s1.TotalLength))
     {
         /* Fail */
+        DPRINT1("Invalid message length: %u, %u\n",
+                CapturedLpcRequest.u1.s1.DataLength,
+                CapturedLpcRequest.u1.s1.TotalLength);
         return STATUS_INVALID_PARAMETER;
     }
 
@@ -700,10 +806,13 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate the message length */
-    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
-        ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
+    if (((ULONG)CapturedLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)CapturedLpcRequest.u1.s1.TotalLength <= (ULONG)CapturedLpcRequest.u1.s1.DataLength))
     {
         /* Fail */
+        DPRINT1("Invalid message length: %u, %u\n",
+                CapturedLpcRequest.u1.s1.DataLength,
+                CapturedLpcRequest.u1.s1.TotalLength);
         ObDereferenceObject(Port);
         return STATUS_PORT_MESSAGE_TOO_LONG;
     }
@@ -713,6 +822,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     if (!Message)
     {
         /* Fail if we couldn't allocate a message */
+        DPRINT1("Failed to allocate a message!\n");
         ObDereferenceObject(Port);
         return STATUS_NO_MEMORY;
     }
@@ -729,16 +839,33 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
         /* No callback, just copy the message */
         _SEH2_TRY
         {
+            /* Check if we have data info entries */
+            if (LpcRequest->u2.s2.DataInfoOffset != 0)
+            {
+                /* Get the data info and check if the number of entries matches
+                   what we expect */
+                DataInfo = LpcpGetDataInfoFromMessage(LpcRequest);
+                if (DataInfo->NumberOfEntries != NumberOfDataEntries)
+                {
+                    LpcpFreeToPortZone(Message, 0);
+                    ObDereferenceObject(Port);
+                    DPRINT1("NumberOfEntries has changed: %u, %u\n",
+                            DataInfo->NumberOfEntries, NumberOfDataEntries);
+                    _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
+                }
+            }
+
             /* Copy it */
             LpcpMoveMessage(&Message->Request,
-                            LpcRequest,
+                            &CapturedLpcRequest,
                             LpcRequest + 1,
                             MessageType,
                             &Thread->Cid);
         }
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            /* Fail */
+            /* Cleanup and return the exception code */
+            DPRINT1("Got exception!\n");
             LpcpFreeToPortZone(Message, 0);
             ObDereferenceObject(Port);
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
@@ -759,7 +886,8 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             if (!QueuePort)
             {
                 /* We have no connected port, fail */
-                LpcpFreeToPortZone(Message, 3);
+                DPRINT1("No connected port\n");
+                LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                 ObDereferenceObject(Port);
                 return STATUS_PORT_DISCONNECTED;
             }
@@ -767,29 +895,22 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             /* This will be the rundown port */
             ReplyPort = QueuePort;
 
-            /* Check if this is a communication port */
+            /* Check if this is a client port */
             if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
             {
-                /* Copy the port context and use the connection port */
+                /* Copy the port context */
                 Message->PortContext = QueuePort->PortContext;
-                ConnectionPort = QueuePort = Port->ConnectionPort;
-                if (!ConnectionPort)
-                {
-                    /* Fail */
-                    LpcpFreeToPortZone(Message, 3);
-                    ObDereferenceObject(Port);
-                    return STATUS_PORT_DISCONNECTED;
-                }
             }
-            else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
-                      LPCP_COMMUNICATION_PORT)
+
+            if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
             {
                 /* Use the connection port for anything but communication ports */
                 ConnectionPort = QueuePort = Port->ConnectionPort;
                 if (!ConnectionPort)
                 {
                     /* Fail */
-                    LpcpFreeToPortZone(Message, 3);
+                    DPRINT1("No connection port\n");
+                    LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                     ObDereferenceObject(Port);
                     return STATUS_PORT_DISCONNECTED;
                 }
@@ -809,7 +930,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
         Message->SenderPort = Port;
 
         /* Generate the Message ID and set it */
-        Message->Request.MessageId =  LpcpNextMessageId++;
+        Message->Request.MessageId = LpcpNextMessageId++;
         if (!LpcpNextMessageId) LpcpNextMessageId = 1;
         Message->Request.CallbackId = 0;
 
@@ -883,6 +1004,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             }
             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
             {
+                DPRINT1("Got exception!\n");
                 Status = _SEH2_GetExceptionCode();
             }
             _SEH2_END;
@@ -914,7 +1036,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
 
     /* All done */
     LPCTRACE(LPC_SEND_DEBUG,
-             "Port: %p. Status: %p\n",
+             "Port: %p. Status: %d\n",
              Port,
              Status);
     ObDereferenceObject(Port);