[NTOS:LPC]
[reactos.git] / reactos / ntoskrnl / lpc / reply.c
index 910d0cd..d9fa834 100644 (file)
@@ -192,7 +192,7 @@ NtReplyPort(IN HANDLE PortHandle,
 {
     NTSTATUS Status;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
-    // PORT_MESSAGE CapturedReplyMessage;
+    PORT_MESSAGE CapturedReplyMessage;
     PLPCP_PORT_OBJECT Port;
     PLPCP_MESSAGE Message;
     PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
@@ -203,32 +203,35 @@ NtReplyPort(IN HANDLE PortHandle,
              PortHandle,
              ReplyMessage);
 
-    if (KeGetPreviousMode() == UserMode)
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
     {
         _SEH2_TRY
         {
-            ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
-            /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
-            ReplyMessage = &CapturedReplyMessage;*/
+            ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
+            CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
         }
-        _SEH2_EXCEPT(ExSystemExceptionFilter())
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            DPRINT1("SEH crash [1]\n");
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
         _SEH2_END;
     }
+    else
+    {
+        CapturedReplyMessage = *ReplyMessage;
+    }
 
     /* Validate its length */
-    if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-        (ULONG)ReplyMessage->u1.s1.TotalLength)
+    if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)CapturedReplyMessage.u1.s1.TotalLength)
     {
         /* Fail */
         return STATUS_INVALID_PARAMETER;
     }
 
     /* Make sure it has a valid ID */
-    if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
+    if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
 
     /* Get the Port object */
     Status = ObReferenceObjectByHandle(PortHandle,
@@ -240,9 +243,9 @@ NtReplyPort(IN HANDLE PortHandle,
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Validate its length in respect to the port object */
-    if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
-        ((ULONG)ReplyMessage->u1.s1.TotalLength <=
-        (ULONG)ReplyMessage->u1.s1.DataLength))
+    if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
+         (ULONG)CapturedReplyMessage.u1.s1.DataLength))
     {
         /* Too large, fail */
         ObDereferenceObject(Port);
@@ -250,7 +253,7 @@ NtReplyPort(IN HANDLE PortHandle,
     }
 
     /* Get the ETHREAD corresponding to it */
-    Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
+    Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
                                         NULL,
                                         &WakeupThread);
     if (!NT_SUCCESS(Status))
@@ -274,7 +277,7 @@ NtReplyPort(IN HANDLE PortHandle,
     KeAcquireGuardedMutex(&LpcpLock);
 
     /* Make sure this is the reply the thread is waiting for */
-    if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
+    if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
         ((LpcpGetMessageFromThread(WakeupThread)) &&
         (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)-> Request)
             != LPC_REQUEST)))
@@ -290,7 +293,7 @@ NtReplyPort(IN HANDLE PortHandle,
     _SEH2_TRY
     {
         LpcpMoveMessage(&Message->Request,
-                        ReplyMessage,
+                        &CapturedReplyMessage,
                         ReplyMessage + 1,
                         LPC_REPLY,
                         NULL);
@@ -324,7 +327,7 @@ NtReplyPort(IN HANDLE PortHandle,
 
     /* Check if this is the message the thread had received */
     if ((Thread->LpcReceivedMsgIdValid) &&
-        (Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
+        (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
     {
         /* Clear this data */
         Thread->LpcReceivedMessageId = 0;
@@ -333,9 +336,9 @@ NtReplyPort(IN HANDLE PortHandle,
 
     /* Free any data information */
     LpcpFreeDataInfoMessage(Port,
-                            ReplyMessage->MessageId,
-                            ReplyMessage->CallbackId,
-                            ReplyMessage->ClientId);
+                            CapturedReplyMessage.MessageId,
+                            CapturedReplyMessage.CallbackId,
+                            CapturedReplyMessage.ClientId);
 
     /* Release the lock and release the LPC semaphore to wake up waiters */
     KeReleaseGuardedMutex(&LpcpLock);
@@ -362,7 +365,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
 {
     NTSTATUS Status;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
-    // PORT_MESSAGE CapturedReplyMessage;
+    PORT_MESSAGE CapturedReplyMessage;
     LARGE_INTEGER CapturedTimeout;
     PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
     PLPCP_MESSAGE Message;
@@ -378,30 +381,29 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
              ReceiveMessage,
              PortContext);
 
-    if (KeGetPreviousMode() == UserMode)
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
     {
         _SEH2_TRY
         {
+            if (PortContext != NULL)
+                ProbeForWritePointer(PortContext);
+
             if (ReplyMessage != NULL)
             {
-                ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
-                /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
-                ReplyMessage = &CapturedReplyMessage;*/
+                ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
+                CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
             }
 
             if (Timeout != NULL)
             {
                 ProbeForReadLargeInteger(Timeout);
-                RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER));
+                CapturedTimeout = *(volatile LARGE_INTEGER*)Timeout;
                 Timeout = &CapturedTimeout;
             }
-
-            if (PortContext != NULL)
-                ProbeForWritePointer(PortContext);
         }
-        _SEH2_EXCEPT(ExSystemExceptionFilter())
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            DPRINT1("SEH crash [1]\n");
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
         _SEH2_END;
@@ -410,21 +412,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     {
         /* If this is a system thread, then let it page out its stack */
         if (Thread->SystemThread) WaitMode = UserMode;
+
+        if (ReplyMessage != NULL)
+            CapturedReplyMessage = *ReplyMessage;
     }
 
     /* Check if caller has a reply message */
     if (ReplyMessage)
     {
         /* Validate its length */
-        if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-            (ULONG)ReplyMessage->u1.s1.TotalLength)
+        if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+             (ULONG)CapturedReplyMessage.u1.s1.TotalLength)
         {
             /* Fail */
             return STATUS_INVALID_PARAMETER;
         }
 
         /* Make sure it has a valid ID */
-        if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
+        if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
     }
 
     /* Get the Port object */
@@ -440,9 +445,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     if (ReplyMessage)
     {
         /* Validate its length in respect to the port object */
-        if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
-            ((ULONG)ReplyMessage->u1.s1.TotalLength <=
-             (ULONG)ReplyMessage->u1.s1.DataLength))
+        if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
+            ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
+             (ULONG)CapturedReplyMessage.u1.s1.DataLength))
         {
             /* Too large, fail */
             ObDereferenceObject(Port);
@@ -490,7 +495,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     if (ReplyMessage)
     {
         /* Get the ETHREAD corresponding to it */
-        Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
+        Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
                                             NULL,
                                             &WakeupThread);
         if (!NT_SUCCESS(Status))
@@ -516,7 +521,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         KeAcquireGuardedMutex(&LpcpLock);
 
         /* Make sure this is the reply the thread is waiting for */
-        if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
+        if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
             ((LpcpGetMessageFromThread(WakeupThread)) &&
              (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->Request)
                 != LPC_REQUEST)))
@@ -530,11 +535,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         }
 
         /* Copy the message */
-        LpcpMoveMessage(&Message->Request,
-                        ReplyMessage,
-                        ReplyMessage + 1,
-                        LPC_REPLY,
-                        NULL);
+        _SEH2_TRY
+        {
+            LpcpMoveMessage(&Message->Request,
+                            &CapturedReplyMessage,
+                            ReplyMessage + 1,
+                            LPC_REPLY,
+                            NULL);
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            /* Cleanup and return the exception code */
+            LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+            ObDereferenceObject(WakeupThread);
+            ObDereferenceObject(Port);
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
 
         /* Reference the thread while we use it */
         ObReferenceObject(WakeupThread);
@@ -555,7 +573,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
 
         /* Check if this is the message the thread had received */
         if ((Thread->LpcReceivedMsgIdValid) &&
-            (Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
+            (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
         {
             /* Clear this data */
             Thread->LpcReceivedMessageId = 0;
@@ -564,9 +582,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
 
         /* Free any data information */
         LpcpFreeDataInfoMessage(Port,
-                                ReplyMessage->MessageId,
-                                ReplyMessage->CallbackId,
-                                ReplyMessage->ClientId);
+                                CapturedReplyMessage.MessageId,
+                                CapturedReplyMessage.CallbackId,
+                                CapturedReplyMessage.ClientId);
 
         /* Release the lock and release the LPC semaphore to wake up waiters */
         KeReleaseGuardedMutex(&LpcpLock);
@@ -688,9 +706,8 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
             ASSERT(FALSE);
         }
     }
-    _SEH2_EXCEPT(ExSystemExceptionFilter())
+    _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
-        DPRINT1("SEH crash [2]\n");
         Status = _SEH2_GetExceptionCode();
     }
     _SEH2_END;
@@ -771,26 +788,24 @@ LpcpCopyRequestData(
 
     PAGED_CODE();
 
-    /* Check the previous mode */
-    PreviousMode = ExGetPreviousMode();
-    if (PreviousMode == KernelMode)
-    {
-        CapturedMessage = *Message;
-    }
-    else
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
     {
         _SEH2_TRY
         {
             ProbeForRead(Message, sizeof(*Message), sizeof(PVOID));
-            CapturedMessage = *Message;
+            CapturedMessage = *(volatile PORT_MESSAGE*)Message;
         }
         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
-            DPRINT1("Got exception!\n");
             _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
         _SEH2_END;
     }
+    else
+    {
+        CapturedMessage = *Message;
+    }
 
     /* Make sure there is any data to copy */
     if (CapturedMessage.u2.s2.DataInfoOffset == 0)