[NTOS:PNP] Avoid a fixed-length stack buffer in IopActionConfigureChildServices....
[reactos.git] / ntoskrnl / lpc / reply.c
index ef0b2cd..8454754 100644 (file)
@@ -99,6 +99,7 @@ LpcpFindDataInfoMessage(
 {
     PLPCP_MESSAGE Message;
     PLIST_ENTRY ListEntry;
+
     PAGED_CODE();
 
     /* Check if the port we want is the connection port */
@@ -141,13 +142,14 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination,
                 IN ULONG MessageType,
                 IN PCLIENT_ID ClientId)
 {
-    /* Set the Message size */
     LPCTRACE((LPC_REPLY_DEBUG | LPC_SEND_DEBUG),
              "Destination/Origin: %p/%p. Data: %p. Length: %lx\n",
              Destination,
              Origin,
              Data,
              Origin->u1.Length);
+
+    /* Set the Message size */
     Destination->u1.Length = Origin->u1.Length;
 
     /* Set the Message Type */
@@ -188,12 +190,12 @@ NTAPI
 NtReplyPort(IN HANDLE PortHandle,
             IN PPORT_MESSAGE ReplyMessage)
 {
-    PLPCP_PORT_OBJECT Port;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    PORT_MESSAGE CapturedReplyMessage;
+    PLPCP_PORT_OBJECT Port;
     PLPCP_MESSAGE Message;
     PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
-    //PORT_MESSAGE CapturedReplyMessage;
 
     PAGED_CODE();
     LPCTRACE(LPC_REPLY_DEBUG,
@@ -201,33 +203,35 @@ NtReplyPort(IN HANDLE PortHandle,
              PortHandle,
              ReplyMessage);
 
-    if (KeGetPreviousMode() == UserMode)
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
     {
         _SEH2_TRY
         {
-            ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
-            /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
-            ReplyMessage = &CapturedReplyMessage;*/
+            ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
+            CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
         }
-        _SEH2_EXCEPT(ExSystemExceptionFilter())
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            DPRINT1("SEH crash [1]\n");
-            DbgBreakPoint();
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
         _SEH2_END;
     }
+    else
+    {
+        CapturedReplyMessage = *ReplyMessage;
+    }
 
     /* Validate its length */
-    if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-        (ULONG)ReplyMessage->u1.s1.TotalLength)
+    if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)CapturedReplyMessage.u1.s1.TotalLength)
     {
         /* Fail */
         return STATUS_INVALID_PARAMETER;
     }
 
     /* Make sure it has a valid ID */
-    if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
+    if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
 
     /* Get the Port object */
     Status = ObReferenceObjectByHandle(PortHandle,
@@ -239,9 +243,9 @@ NtReplyPort(IN HANDLE PortHandle,
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate its length in respect to the port object */
-    if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
-        ((ULONG)ReplyMessage->u1.s1.TotalLength <=
-        (ULONG)ReplyMessage->u1.s1.DataLength))
+    if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
+         (ULONG)CapturedReplyMessage.u1.s1.DataLength))
     {
         /* Too large, fail */
         ObDereferenceObject(Port);
@@ -249,7 +253,7 @@ NtReplyPort(IN HANDLE PortHandle,
     }
 
     /* Get the ETHREAD corresponding to it */
-    Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
+    Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
                                         NULL,
                                         &WakeupThread);
     if (!NT_SUCCESS(Status))
@@ -273,10 +277,10 @@ NtReplyPort(IN HANDLE PortHandle,
     KeAcquireGuardedMutex(&LpcpLock);
 
     /* Make sure this is the reply the thread is waiting for */
-    if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
+    if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
         ((LpcpGetMessageFromThread(WakeupThread)) &&
-        (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->
-        Request) != LPC_REQUEST)))
+        (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)-> Request)
+            != LPC_REQUEST)))
     {
         /* It isn't, fail */
         LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
@@ -289,14 +293,14 @@ NtReplyPort(IN HANDLE PortHandle,
     _SEH2_TRY
     {
         LpcpMoveMessage(&Message->Request,
-                        ReplyMessage,
+                        &CapturedReplyMessage,
                         ReplyMessage + 1,
                         LPC_REPLY,
                         NULL);
     }
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
-        /* Fail */
+        /* Cleanup and return the exception code */
         LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
         ObDereferenceObject(WakeupThread);
         ObDereferenceObject(Port);
@@ -323,7 +327,7 @@ NtReplyPort(IN HANDLE PortHandle,
 
     /* Check if this is the message the thread had received */
     if ((Thread->LpcReceivedMsgIdValid) &&
-        (Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
+        (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
     {
         /* Clear this data */
         Thread->LpcReceivedMessageId = 0;
@@ -332,9 +336,9 @@ NtReplyPort(IN HANDLE PortHandle,
 
     /* Free any data information */
     LpcpFreeDataInfoMessage(Port,
-                            ReplyMessage->MessageId,
-                            ReplyMessage->CallbackId,
-                            ReplyMessage->ClientId);
+                            CapturedReplyMessage.MessageId,
+                            CapturedReplyMessage.CallbackId,
+                            CapturedReplyMessage.ClientId);
 
     /* Release the lock and release the LPC semaphore to wake up waiters */
     KeReleaseGuardedMutex(&LpcpLock);
@@ -359,15 +363,15 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
                          OUT PPORT_MESSAGE ReceiveMessage,
                          IN PLARGE_INTEGER Timeout OPTIONAL)
 {
-    PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
     NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
+    PORT_MESSAGE CapturedReplyMessage;
+    LARGE_INTEGER CapturedTimeout;
+    PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
     PLPCP_MESSAGE Message;
     PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
     ULONG ConnectionInfoLength;
-    //PORT_MESSAGE CapturedReplyMessage;
-    LARGE_INTEGER CapturedTimeout;
 
     PAGED_CODE();
     LPCTRACE(LPC_REPLY_DEBUG,
@@ -377,31 +381,29 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
              ReceiveMessage,
              PortContext);
 
-    if (KeGetPreviousMode() == UserMode)
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
     {
         _SEH2_TRY
         {
+            if (PortContext != NULL)
+                ProbeForWritePointer(PortContext);
+
             if (ReplyMessage != NULL)
             {
-                ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
-                /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
-                ReplyMessage = &CapturedReplyMessage;*/
+                ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
+                CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
             }
 
             if (Timeout != NULL)
             {
                 ProbeForReadLargeInteger(Timeout);
-                RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER));
+                CapturedTimeout = *(volatile LARGE_INTEGER*)Timeout;
                 Timeout = &CapturedTimeout;
             }
-
-            if (PortContext != NULL)
-                ProbeForWritePointer(PortContext);
         }
-        _SEH2_EXCEPT(ExSystemExceptionFilter())
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            DPRINT1("SEH crash [1]\n");
-            DbgBreakPoint();
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
         _SEH2_END;
@@ -410,21 +412,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     {
         /* If this is a system thread, then let it page out its stack */
         if (Thread->SystemThread) WaitMode = UserMode;
+
+        if (ReplyMessage != NULL)
+            CapturedReplyMessage = *ReplyMessage;
     }
 
     /* Check if caller has a reply message */
     if (ReplyMessage)
     {
         /* Validate its length */
-        if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-            (ULONG)ReplyMessage->u1.s1.TotalLength)
+        if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+             (ULONG)CapturedReplyMessage.u1.s1.TotalLength)
         {
             /* Fail */
             return STATUS_INVALID_PARAMETER;
         }
 
         /* Make sure it has a valid ID */
-        if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
+        if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
     }
 
     /* Get the Port object */
@@ -440,9 +445,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     if (ReplyMessage)
     {
         /* Validate its length in respect to the port object */
-        if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
-            ((ULONG)ReplyMessage->u1.s1.TotalLength <=
-             (ULONG)ReplyMessage->u1.s1.DataLength))
+        if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
+            ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
+             (ULONG)CapturedReplyMessage.u1.s1.DataLength))
         {
             /* Too large, fail */
             ObDereferenceObject(Port);
@@ -490,7 +495,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     if (ReplyMessage)
     {
         /* Get the ETHREAD corresponding to it */
-        Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
+        Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
                                             NULL,
                                             &WakeupThread);
         if (!NT_SUCCESS(Status))
@@ -516,10 +521,10 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         KeAcquireGuardedMutex(&LpcpLock);
 
         /* Make sure this is the reply the thread is waiting for */
-        if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
+        if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
             ((LpcpGetMessageFromThread(WakeupThread)) &&
-             (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->
-                                 Request) != LPC_REQUEST)))
+             (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->Request)
+                != LPC_REQUEST)))
         {
             /* It isn't, fail */
             LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
@@ -530,11 +535,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         }
 
         /* Copy the message */
-        LpcpMoveMessage(&Message->Request,
-                        ReplyMessage,
-                        ReplyMessage + 1,
-                        LPC_REPLY,
-                        NULL);
+        _SEH2_TRY
+        {
+            LpcpMoveMessage(&Message->Request,
+                            &CapturedReplyMessage,
+                            ReplyMessage + 1,
+                            LPC_REPLY,
+                            NULL);
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            /* Cleanup and return the exception code */
+            LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+            ObDereferenceObject(WakeupThread);
+            ObDereferenceObject(Port);
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
 
         /* Reference the thread while we use it */
         ObReferenceObject(WakeupThread);
@@ -555,7 +573,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
 
         /* Check if this is the message the thread had received */
         if ((Thread->LpcReceivedMsgIdValid) &&
-            (Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
+            (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
         {
             /* Clear this data */
             Thread->LpcReceivedMessageId = 0;
@@ -564,9 +582,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
 
         /* Free any data information */
         LpcpFreeDataInfoMessage(Port,
-                                ReplyMessage->MessageId,
-                                ReplyMessage->CallbackId,
-                                ReplyMessage->ClientId);
+                                CapturedReplyMessage.MessageId,
+                                CapturedReplyMessage.CallbackId,
+                                CapturedReplyMessage.ClientId);
 
         /* Release the lock and release the LPC semaphore to wake up waiters */
         KeReleaseGuardedMutex(&LpcpLock);
@@ -590,7 +608,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         if (ReceivePort->Flags & LPCP_WAITABLE_PORT)
         {
             /* Reset its event */
-            KeResetEvent(&ReceivePort->WaitEvent);
+            KeClearEvent(&ReceivePort->WaitEvent);
         }
 
         /* Release the lock and fail */
@@ -601,8 +619,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     }
 
     /* Get the message on the queue */
-    Message = CONTAINING_RECORD(RemoveHeadList(&ReceivePort->
-                                               MsgQueue.ReceiveHead),
+    Message = CONTAINING_RECORD(RemoveHeadList(&ReceivePort->MsgQueue.ReceiveHead),
                                 LPCP_MESSAGE,
                                 Entry);
 
@@ -613,7 +630,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         if (ReceivePort->Flags & LPCP_WAITABLE_PORT)
         {
             /* Reset its event */
-            KeResetEvent(&ReceivePort->WaitEvent);
+            KeClearEvent(&ReceivePort->WaitEvent);
         }
     }
 
@@ -689,10 +706,8 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
             ASSERT(FALSE);
         }
     }
-    _SEH2_EXCEPT(ExSystemExceptionFilter())
+    _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
-        DPRINT1("SEH crash [2]\n");
-        DbgBreakPoint();
         Status = _SEH2_GetExceptionCode();
     }
     _SEH2_END;
@@ -759,39 +774,38 @@ LpcpCopyRequestData(
     IN ULONG Index,
     IN PVOID Buffer,
     IN ULONG BufferLength,
-    OUT PULONG Returnlength)
+    OUT PULONG ReturnLength)
 {
-    KPROCESSOR_MODE PreviousMode;
+    NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PORT_MESSAGE CapturedMessage;
     PLPCP_PORT_OBJECT Port = NULL;
     PETHREAD ClientThread = NULL;
-    ULONG LocalReturnlength;
+    SIZE_T LocalReturnLength;
     PLPCP_MESSAGE InfoMessage;
     PLPCP_DATA_INFO DataInfo;
     PVOID DataInfoBaseAddress;
-    NTSTATUS Status;
+
     PAGED_CODE();
 
-    /* Check the previous mode */
-    PreviousMode = ExGetPreviousMode();
-    if (PreviousMode == KernelMode)
-    {
-        CapturedMessage = *Message;
-    }
-    else
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
     {
         _SEH2_TRY
         {
             ProbeForRead(Message, sizeof(*Message), sizeof(PVOID));
-            CapturedMessage = *Message;
+            CapturedMessage = *(volatile PORT_MESSAGE*)Message;
         }
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            DPRINT1("Got exception!\n");
-            return _SEH2_GetExceptionCode();
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
         _SEH2_END;
     }
+    else
+    {
+        CapturedMessage = *Message;
+    }
 
     /* Make sure there is any data to copy */
     if (CapturedMessage.u2.s2.DataInfoOffset == 0)
@@ -892,7 +906,7 @@ LpcpCopyRequestData(
                                      DataInfoBaseAddress,
                                      BufferLength,
                                      PreviousMode,
-                                     &LocalReturnlength);
+                                     &LocalReturnLength);
     }
     else
     {
@@ -903,7 +917,7 @@ LpcpCopyRequestData(
                                      Buffer,
                                      BufferLength,
                                      PreviousMode,
-                                     &LocalReturnlength);
+                                     &LocalReturnLength);
     }
 
     if (!NT_SUCCESS(Status))
@@ -913,16 +927,16 @@ LpcpCopyRequestData(
     }
 
     /* Check if the caller asked to return the copied length */
-    if (Returnlength != NULL)
+    if (ReturnLength != NULL)
     {
         _SEH2_TRY
         {
-            *Returnlength = LocalReturnlength;
+            *ReturnLength = LocalReturnLength;
         }
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
             /* Ignore */
-            DPRINT1("Exception writing Returnlength, ignoring\n");
+            DPRINT1("Exception writing ReturnLength, ignoring\n");
         }
         _SEH2_END;
     }
@@ -944,7 +958,7 @@ CleanupWithLock:
 }
 
 /*
- * @unimplemented
+ * @implemented
  */
 NTSTATUS
 NTAPI
@@ -966,7 +980,7 @@ NtReadRequestData(IN HANDLE PortHandle,
 }
 
 /*
- * @unimplemented
+ * @implemented
  */
 NTSTATUS
 NTAPI