[NTOSKRNL]
authorTimo Kreuzer <timo.kreuzer@reactos.org>
Sun, 16 Feb 2014 09:54:05 +0000 (09:54 +0000)
committerTimo Kreuzer <timo.kreuzer@reactos.org>
Sun, 16 Feb 2014 09:54:05 +0000 (09:54 +0000)
Implement NtReadRequestData and NtWriteRequestData

svn path=/trunk/; revision=62209

reactos/ntoskrnl/include/internal/lpc.h
reactos/ntoskrnl/include/internal/lpc_x.h
reactos/ntoskrnl/include/internal/mm.h
reactos/ntoskrnl/lpc/reply.c
reactos/ntoskrnl/lpc/send.c

index c6d3a14..e862a4f 100644 (file)
 #define LPCP_LOCK_HELD      1
 #define LPCP_LOCK_RELEASE   2
 
 #define LPCP_LOCK_HELD      1
 #define LPCP_LOCK_RELEASE   2
 
+
+typedef struct _LPCP_DATA_INFO
+{
+    ULONG NumberOfEntries;
+    struct
+    {
+        PVOID BaseAddress;
+        ULONG DataLength;
+    } Entries[1];
+} LPCP_DATA_INFO, *PLPCP_DATA_INFO;
+
+
 //
 // Internal Port Management
 //
 //
 // Internal Port Management
 //
@@ -131,6 +143,13 @@ LpcInitSystem(
     VOID
 );
 
     VOID
 );
 
+BOOLEAN
+NTAPI
+LpcpValidateClientPort(
+    PETHREAD ClientThread,
+    PLPCP_PORT_OBJECT Port);
+
+
 //
 // Global data inside the Process Manager
 //
 //
 // Global data inside the Process Manager
 //
index 816f979..0d04988 100644 (file)
@@ -164,3 +164,10 @@ LpcpSetPortToThread(IN PETHREAD Thread,
     Thread->LpcWaitingOnPort = (PVOID)(((ULONG_PTR)Port) |
                                        LPCP_THREAD_FLAG_IS_PORT);
 }
     Thread->LpcWaitingOnPort = (PVOID)(((ULONG_PTR)Port) |
                                        LPCP_THREAD_FLAG_IS_PORT);
 }
+
+FORCEINLINE
+PLPCP_DATA_INFO
+LpcpGetDataInfoFromMessage(PPORT_MESSAGE Message)
+{
+    return (PLPCP_DATA_INFO)((PUCHAR)Message + Message->u2.s2.DataInfoOffset);
+}
index 4e9ddc0..6a89494 100644 (file)
@@ -1798,3 +1798,17 @@ VOID
 NTAPI
 MmSetSessionLocaleId(
     _In_ LCID LocaleId);
 NTAPI
 MmSetSessionLocaleId(
     _In_ LCID LocaleId);
+
+
+/* virtual.c *****************************************************************/
+
+NTSTATUS
+NTAPI
+MmCopyVirtualMemory(IN PEPROCESS SourceProcess,
+                    IN PVOID SourceAddress,
+                    IN PEPROCESS TargetProcess,
+                    OUT PVOID TargetAddress,
+                    IN SIZE_T BufferSize,
+                    IN KPROCESSOR_MODE PreviousMode,
+                    OUT PSIZE_T ReturnSize);
+
index 744c910..d07152f 100644 (file)
@@ -90,6 +90,49 @@ LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
     if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
 }
 
     if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
 }
 
+PLPCP_MESSAGE
+NTAPI
+LpcpFindDataInfoMessage(
+    IN PLPCP_PORT_OBJECT Port,
+    IN ULONG MessageId,
+    IN LPC_CLIENT_ID ClientId)
+{
+    PLPCP_MESSAGE Message;
+    PLIST_ENTRY ListEntry;
+    PAGED_CODE();
+
+    /* Check if the port we want is the connection port */
+    if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT)
+    {
+        /* Use it */
+        Port = Port->ConnectionPort;
+        if (!Port)
+        {
+            /* Return NULL */
+            return NULL;
+        }
+    }
+
+    /* Loop all entries in the list */
+    for (ListEntry = Port->LpcDataInfoChainHead.Flink;
+         ListEntry != &Port->LpcDataInfoChainHead;
+         ListEntry = ListEntry->Flink)
+    {
+        Message = CONTAINING_RECORD(ListEntry, LPCP_MESSAGE, Entry);
+
+        /* Check if this is the desired message */
+        if ((Message->Request.MessageId == MessageId) &&
+            (Message->Request.ClientId.UniqueProcess == ClientId.UniqueProcess) &&
+            (Message->Request.ClientId.UniqueThread == ClientId.UniqueThread))
+        {
+            /* It is, return it */
+            return Message;
+        }
+    }
+
+    return NULL;
+}
+
 VOID
 NTAPI
 LpcpMoveMessage(IN PPORT_MESSAGE Destination,
 VOID
 NTAPI
 LpcpMoveMessage(IN PPORT_MESSAGE Destination,
@@ -132,7 +175,7 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination,
     /* Copy the Message Data */
     RtlCopyMemory(Destination + 1,
                   Data,
     /* Copy the Message Data */
     RtlCopyMemory(Destination + 1,
                   Data,
-                  ((Destination->u1.Length & 0xFFFF) + 3) &~3);
+                  ALIGN_UP_BY(Destination->u1.s1.DataLength, sizeof(ULONG)));
 }
 
 /* PUBLIC FUNCTIONS **********************************************************/
 }
 
 /* PUBLIC FUNCTIONS **********************************************************/
@@ -710,6 +753,199 @@ NtReplyWaitReplyPort(IN HANDLE PortHandle,
     return STATUS_NOT_IMPLEMENTED;
 }
 
     return STATUS_NOT_IMPLEMENTED;
 }
 
+NTSTATUS
+NTAPI
+LpcpCopyRequestData(
+    IN BOOLEAN Write,
+    IN HANDLE PortHandle,
+    IN PPORT_MESSAGE Message,
+    IN ULONG Index,
+    IN PVOID Buffer,
+    IN ULONG BufferLength,
+    OUT PULONG Returnlength)
+{
+    KPROCESSOR_MODE PreviousMode;
+    PORT_MESSAGE CapturedMessage;
+    PLPCP_PORT_OBJECT Port = NULL;
+    PETHREAD ClientThread = NULL;
+    ULONG LocalReturnlength;
+    PLPCP_MESSAGE InfoMessage;
+    PLPCP_DATA_INFO DataInfo;
+    PVOID DataInfoBaseAddress;
+    NTSTATUS Status;
+    PAGED_CODE();
+
+    /* Check the previous mode */
+    PreviousMode = ExGetPreviousMode();
+    if (PreviousMode == KernelMode)
+    {
+        CapturedMessage = *Message;
+    }
+    else
+    {
+        _SEH2_TRY
+        {
+            ProbeForRead(Message, sizeof(*Message), sizeof(PVOID));
+            CapturedMessage = *Message;
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            DPRINT1("Got exception!\n");
+            return _SEH2_GetExceptionCode();
+        }
+        _SEH2_END;
+    }
+
+    /* Make sure there is any data to copy */
+    if (CapturedMessage.u2.s2.DataInfoOffset == 0)
+    {
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Reference the port handle */
+    Status = ObReferenceObjectByHandle(PortHandle,
+                                       PORT_ALL_ACCESS,
+                                       LpcPortObjectType,
+                                       PreviousMode,
+                                       (PVOID*)&Port,
+                                       NULL);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to reference port handle: 0x%ls\n", Status);
+        return Status;
+    }
+
+    /* Look up the client thread */
+    Status = PsLookupProcessThreadByCid(&CapturedMessage.ClientId,
+                                        NULL,
+                                        &ClientThread);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to lookup client thread for [0x%lx:0x%lx]: 0x%ls\n",
+                CapturedMessage.ClientId.UniqueProcess,
+                CapturedMessage.ClientId.UniqueThread, Status);
+        goto Cleanup;
+    }
+
+    /* Acquire the global LPC lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Check for message id mismatch */
+    if ((ClientThread->LpcReplyMessageId != CapturedMessage.MessageId) ||
+        (CapturedMessage.MessageId == 0))
+    {
+        DPRINT1("LpcReplyMessageId mismatch: 0x%lx/0x%lx.\n",
+                ClientThread->LpcReplyMessageId, CapturedMessage.MessageId);
+        Status = STATUS_REPLY_MESSAGE_MISMATCH;
+        goto CleanupWithLock;
+    }
+
+    /* Validate the port */
+    if (!LpcpValidateClientPort(ClientThread, Port))
+    {
+        DPRINT1("LpcpValidateClientPort failed\n");
+        Status = STATUS_REPLY_MESSAGE_MISMATCH;
+        goto CleanupWithLock;
+    }
+
+    /* Find the message with the data */
+    InfoMessage = LpcpFindDataInfoMessage(Port,
+                                          CapturedMessage.MessageId,
+                                          CapturedMessage.ClientId);
+    if (InfoMessage == NULL)
+    {
+        DPRINT1("LpcpFindDataInfoMessage failed\n");
+        Status = STATUS_INVALID_PARAMETER;
+        goto CleanupWithLock;
+    }
+
+    /* Get the data info */
+    DataInfo = LpcpGetDataInfoFromMessage(&InfoMessage->Request);
+
+    /* Check if the index is within bounds */
+    if (Index >= DataInfo->NumberOfEntries)
+    {
+        DPRINT1("Message data index %lu out of bounds (%lu in msg)\n",
+                Index, DataInfo->NumberOfEntries);
+        Status = STATUS_INVALID_PARAMETER;
+        goto CleanupWithLock;
+    }
+
+    /* Check if the caller wants to read/write more data than expected */
+    if (BufferLength > DataInfo->Entries[Index].DataLength)
+    {
+        DPRINT1("Trying to read more data (%lu) than available (%lu)\n",
+                BufferLength, DataInfo->Entries[Index].DataLength);
+        Status = STATUS_INVALID_PARAMETER;
+        goto CleanupWithLock;
+    }
+
+    /* Get the data pointer */
+    DataInfoBaseAddress = DataInfo->Entries[Index].BaseAddress;
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+
+    if (Write)
+    {
+        /* Copy data from the caller to the message sender */
+        Status = MmCopyVirtualMemory(PsGetCurrentProcess(),
+                                     Buffer,
+                                     ClientThread->ThreadsProcess,
+                                     DataInfoBaseAddress,
+                                     BufferLength,
+                                     PreviousMode,
+                                     &LocalReturnlength);
+    }
+    else
+    {
+        /* Copy data from the message sender to the caller */
+        Status = MmCopyVirtualMemory(ClientThread->ThreadsProcess,
+                                     DataInfoBaseAddress,
+                                     PsGetCurrentProcess(),
+                                     Buffer,
+                                     BufferLength,
+                                     PreviousMode,
+                                     &LocalReturnlength);
+    }
+
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("MmCopyVirtualMemory failed: 0x%ls\n", Status);
+        goto Cleanup;
+    }
+
+    /* Check if the caller asked to return the copied length */
+    if (Returnlength != NULL)
+    {
+        _SEH2_TRY
+        {
+            *Returnlength = LocalReturnlength;
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            /* Ignore */
+            DPRINT1("Exception writing Returnlength, ignoring\n");
+        }
+        _SEH2_END;
+    }
+
+Cleanup:
+
+    if (ClientThread != NULL)
+        ObDereferenceObject(ClientThread);
+
+    ObDereferenceObject(Port);
+
+    return Status;
+
+CleanupWithLock:
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+    goto Cleanup;
+}
+
 /*
  * @unimplemented
  */
 /*
  * @unimplemented
  */
@@ -720,10 +956,16 @@ NtReadRequestData(IN HANDLE PortHandle,
                   IN ULONG Index,
                   IN PVOID Buffer,
                   IN ULONG BufferLength,
                   IN ULONG Index,
                   IN PVOID Buffer,
                   IN ULONG BufferLength,
-                  OUT PULONG Returnlength)
+                  OUT PULONG ReturnLength)
 {
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    /* Call the internal function */
+    return LpcpCopyRequestData(FALSE,
+                               PortHandle,
+                               Message,
+                               Index,
+                               Buffer,
+                               BufferLength,
+                               ReturnLength);
 }
 
 /*
 }
 
 /*
@@ -738,8 +980,14 @@ NtWriteRequestData(IN HANDLE PortHandle,
                    IN ULONG BufferLength,
                    OUT PULONG ReturnLength)
 {
                    IN ULONG BufferLength,
                    OUT PULONG ReturnLength)
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    /* Call the internal function */
+    return LpcpCopyRequestData(TRUE,
+                               PortHandle,
+                               Message,
+                               Index,
+                               Buffer,
+                               BufferLength,
+                               ReturnLength);
 }
 
 /* EOF */
 }
 
 /* EOF */
index 7a3cd45..a9aad6a 100644 (file)
@@ -207,34 +207,34 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
     {
         /* No type */
         case 0:
     {
         /* No type */
         case 0:
-            
+
             /* Assume LPC request */
             MessageType = LPC_REQUEST;
             break;
             /* Assume LPC request */
             MessageType = LPC_REQUEST;
             break;
-        
+
         /* LPC request callback */
         case LPC_REQUEST:
         /* LPC request callback */
         case LPC_REQUEST:
-            
+
             /* This is a callback */
             Callback = TRUE;
             break;
             /* This is a callback */
             Callback = TRUE;
             break;
-        
+
         /* Anything else */
         case LPC_CLIENT_DIED:
         case LPC_PORT_CLOSED:
         case LPC_EXCEPTION:
         case LPC_DEBUG_EVENT:
         case LPC_ERROR_EVENT:
         /* Anything else */
         case LPC_CLIENT_DIED:
         case LPC_PORT_CLOSED:
         case LPC_EXCEPTION:
         case LPC_DEBUG_EVENT:
         case LPC_ERROR_EVENT:
-            
+
             /* Nothing to do */
             break;
             /* Nothing to do */
             break;
-            
+
         default:
         default:
-            
+
             /* Invalid message type */
             return STATUS_INVALID_PARAMETER;
     }
             /* Invalid message type */
             return STATUS_INVALID_PARAMETER;
     }
-    
+
     /* Set the request type */
     LpcRequest->u2.s2.Type = MessageType;
 
     /* Set the request type */
     LpcRequest->u2.s2.Type = MessageType;
 
@@ -401,10 +401,10 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
                             (&Message->Request) + 1,
                             0,
                             NULL);
                             (&Message->Request) + 1,
                             0,
                             NULL);
-            
+
             /* Acquire the lock */
             KeAcquireGuardedMutex(&LpcpLock);
             /* Acquire the lock */
             KeAcquireGuardedMutex(&LpcpLock);
-            
+
             /* Check if we replied to a thread */
             if (Message->RepliedToThread)
             {
             /* Check if we replied to a thread */
             if (Message->RepliedToThread)
             {
@@ -633,6 +633,47 @@ NtRequestPort(IN HANDLE PortHandle,
     return Status;
 }
 
     return Status;
 }
 
+NTSTATUS
+NTAPI
+LpcpVerifyMessageDataInfo(
+    _In_ PPORT_MESSAGE Message,
+    _Out_ PULONG NumberOfDataEntries)
+{
+    PLPCP_DATA_INFO DataInfo;
+    PUCHAR EndOfEntries;
+
+    /* Check if we have no data info at all */
+    if (Message->u2.s2.DataInfoOffset == 0)
+    {
+        *NumberOfDataEntries = 0;
+        return STATUS_SUCCESS;
+    }
+
+    /* Make sure the data info structure is within the message */
+    if (((ULONG)Message->u1.s1.TotalLength <
+            sizeof(PORT_MESSAGE) + sizeof(LPCP_DATA_INFO)) ||
+        ((ULONG)Message->u2.s2.DataInfoOffset < sizeof(PORT_MESSAGE)) ||
+        ((ULONG)Message->u2.s2.DataInfoOffset >
+            ((ULONG)Message->u1.s1.TotalLength - sizeof(LPCP_DATA_INFO))))
+    {
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Get a pointer to the data info */
+    DataInfo = LpcpGetDataInfoFromMessage(Message);
+
+    /* Make sure the full data info with all entries is within the message */
+    EndOfEntries = (PUCHAR)&DataInfo->Entries[DataInfo->NumberOfEntries];
+    if ((EndOfEntries > ((PUCHAR)Message + (ULONG)Message->u1.s1.TotalLength)) ||
+        (EndOfEntries < (PUCHAR)Message))
+    {
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    *NumberOfDataEntries = DataInfo->NumberOfEntries;
+    return STATUS_SUCCESS;
+}
+
 /*
  * @implemented
  */
 /*
  * @implemented
  */
@@ -642,6 +683,8 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
                        IN PPORT_MESSAGE LpcRequest,
                        IN OUT PPORT_MESSAGE LpcReply)
 {
                        IN PPORT_MESSAGE LpcRequest,
                        IN OUT PPORT_MESSAGE LpcReply)
 {
+    PORT_MESSAGE LocalLpcRequest;
+    ULONG NumberOfDataEntries;
     PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
     PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
@@ -650,6 +693,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     BOOLEAN Callback;
     PKSEMAPHORE Semaphore;
     ULONG MessageType;
     BOOLEAN Callback;
     PKSEMAPHORE Semaphore;
     ULONG MessageType;
+    PLPCP_DATA_INFO DataInfo;
     PAGED_CODE();
     LPCTRACE(LPC_SEND_DEBUG,
              "Handle: %p. Messages: %p/%p. Type: %lx\n",
     PAGED_CODE();
     LPCTRACE(LPC_SEND_DEBUG,
              "Handle: %p. Messages: %p/%p. Type: %lx\n",
@@ -661,32 +705,78 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     /* Check if the thread is dying */
     if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;
 
     /* Check if the thread is dying */
     if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;
 
+    /* Check for user mode access */
+    if (PreviousMode != KernelMode)
+    {
+        _SEH2_TRY
+        {
+            /* Probe the full request message and copy the base structure */
+            ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
+            ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG));
+            LocalLpcRequest = *LpcRequest;
+
+            /* Probe the reply message for write */
+            ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG));
+
+            /* Make sure the data entries in the request message are valid */
+            Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
+            if (!NT_SUCCESS(Status))
+            {
+                DPRINT1("LpcpVerifyMessageDataInfo failed\n");
+                return Status;
+            }
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            DPRINT1("Got exception\n");
+            return _SEH2_GetExceptionCode();
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        LocalLpcRequest = *LpcRequest;
+        Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
+        if (!NT_SUCCESS(Status))
+        {
+            DPRINT1("LpcpVerifyMessageDataInfo failed\n");
+            return Status;
+        }
+    }
+
     /* Check if this is an LPC Request */
     /* Check if this is an LPC Request */
-    if (LpcpGetMessageType(LpcRequest) == LPC_REQUEST)
+    if (LpcpGetMessageType(&LocalLpcRequest) == LPC_REQUEST)
     {
         /* Then it's a callback */
         Callback = TRUE;
     }
     {
         /* Then it's a callback */
         Callback = TRUE;
     }
-    else if (LpcpGetMessageType(LpcRequest))
+    else if (LpcpGetMessageType(&LocalLpcRequest))
     {
         /* This is a not kernel-mode message */
     {
         /* This is a not kernel-mode message */
+        DPRINT1("Not a kernel-mode message!\n");
         return STATUS_INVALID_PARAMETER;
     }
     else
     {
         /* This is a kernel-mode message without a callback */
         return STATUS_INVALID_PARAMETER;
     }
     else
     {
         /* This is a kernel-mode message without a callback */
-        LpcRequest->u2.s2.Type |= LPC_REQUEST;
+        LocalLpcRequest.u2.s2.Type |= LPC_REQUEST;
         Callback = FALSE;
     }
 
     /* Get the message type */
         Callback = FALSE;
     }
 
     /* Get the message type */
-    MessageType = LpcRequest->u2.s2.Type;
+    MessageType = LocalLpcRequest.u2.s2.Type;
+
+    /* Due to the above probe, we know that TotalLength is positive */
+    NT_ASSERT(LocalLpcRequest.u1.s1.TotalLength >= 0);
 
     /* Validate the length */
 
     /* Validate the length */
-    if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-         (ULONG)LpcRequest->u1.s1.TotalLength)
+    if ((((ULONG)(USHORT)LocalLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)LocalLpcRequest.u1.s1.TotalLength))
     {
         /* Fail */
     {
         /* Fail */
+        DPRINT1("Invalid message length: %u, %u\n",
+                LocalLpcRequest.u1.s1.DataLength,
+                LocalLpcRequest.u1.s1.TotalLength);
         return STATUS_INVALID_PARAMETER;
     }
 
         return STATUS_INVALID_PARAMETER;
     }
 
@@ -700,10 +790,13 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate the message length */
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate the message length */
-    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
-        ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
+    if (((ULONG)LocalLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)LocalLpcRequest.u1.s1.TotalLength <= (ULONG)LocalLpcRequest.u1.s1.DataLength))
     {
         /* Fail */
     {
         /* Fail */
+        DPRINT1("Invalid message length: %u, %u\n",
+                LocalLpcRequest.u1.s1.DataLength,
+                LocalLpcRequest.u1.s1.TotalLength);
         ObDereferenceObject(Port);
         return STATUS_PORT_MESSAGE_TOO_LONG;
     }
         ObDereferenceObject(Port);
         return STATUS_PORT_MESSAGE_TOO_LONG;
     }
@@ -713,6 +806,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     if (!Message)
     {
         /* Fail if we couldn't allocate a message */
     if (!Message)
     {
         /* Fail if we couldn't allocate a message */
+        DPRINT1("Failed to allocate a message!\n");
         ObDereferenceObject(Port);
         return STATUS_NO_MEMORY;
     }
         ObDereferenceObject(Port);
         return STATUS_NO_MEMORY;
     }
@@ -729,6 +823,22 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
         /* No callback, just copy the message */
         _SEH2_TRY
         {
         /* No callback, just copy the message */
         _SEH2_TRY
         {
+            /* Check if we have data info entries */
+            if (LpcRequest->u2.s2.DataInfoOffset != 0)
+            {
+                /* Get the data info and check if the number of entries matches
+                   what we expect */
+                DataInfo = LpcpGetDataInfoFromMessage(LpcRequest);
+                if (DataInfo->NumberOfEntries != NumberOfDataEntries)
+                {
+                    LpcpFreeToPortZone(Message, 0);
+                    ObDereferenceObject(Port);
+                    DPRINT1("NumberOfEntries has changed: %u, %u\n",
+                            DataInfo->NumberOfEntries, NumberOfDataEntries);
+                    return STATUS_INVALID_PARAMETER;
+                }
+            }
+
             /* Copy it */
             LpcpMoveMessage(&Message->Request,
                             LpcRequest,
             /* Copy it */
             LpcpMoveMessage(&Message->Request,
                             LpcRequest,
@@ -739,6 +849,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
             /* Fail */
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
             /* Fail */
+            DPRINT1("Got exception!\n");
             LpcpFreeToPortZone(Message, 0);
             ObDereferenceObject(Port);
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
             LpcpFreeToPortZone(Message, 0);
             ObDereferenceObject(Port);
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
@@ -759,6 +870,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             if (!QueuePort)
             {
                 /* We have no connected port, fail */
             if (!QueuePort)
             {
                 /* We have no connected port, fail */
+                DPRINT1("No connected port\n");
                 LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                 ObDereferenceObject(Port);
                 return STATUS_PORT_DISCONNECTED;
                 LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                 ObDereferenceObject(Port);
                 return STATUS_PORT_DISCONNECTED;
@@ -767,28 +879,21 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             /* This will be the rundown port */
             ReplyPort = QueuePort;
 
             /* This will be the rundown port */
             ReplyPort = QueuePort;
 
-            /* Check if this is a communication port */
+            /* Check if this is a client port */
             if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
             {
             if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
             {
-                /* Copy the port context and use the connection port */
+                /* Copy the port context */
                 Message->PortContext = QueuePort->PortContext;
                 Message->PortContext = QueuePort->PortContext;
-                ConnectionPort = QueuePort = Port->ConnectionPort;
-                if (!ConnectionPort)
-                {
-                    /* Fail */
-                    LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
-                    ObDereferenceObject(Port);
-                    return STATUS_PORT_DISCONNECTED;
-                }
             }
             }
-            else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
-                      LPCP_COMMUNICATION_PORT)
+
+            if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
             {
                 /* Use the connection port for anything but communication ports */
                 ConnectionPort = QueuePort = Port->ConnectionPort;
                 if (!ConnectionPort)
                 {
                     /* Fail */
             {
                 /* Use the connection port for anything but communication ports */
                 ConnectionPort = QueuePort = Port->ConnectionPort;
                 if (!ConnectionPort)
                 {
                     /* Fail */
+                    DPRINT1("No connection port\n");
                     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                     ObDereferenceObject(Port);
                     return STATUS_PORT_DISCONNECTED;
                     LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
                     ObDereferenceObject(Port);
                     return STATUS_PORT_DISCONNECTED;
@@ -883,6 +988,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             }
             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
             {
             }
             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
             {
+                DPRINT1("Got exception!\n");
                 Status = _SEH2_GetExceptionCode();
             }
             _SEH2_END;
                 Status = _SEH2_GetExceptionCode();
             }
             _SEH2_END;