- Use _SEH2_YIELD when returning from an exception instead of returning outside the...
[reactos.git] / reactos / ntoskrnl / lpc / reply.c
index 55fa1c0..0233ea1 100644 (file)
@@ -165,6 +165,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
     ULONG ConnectionInfoLength;
+    //PORT_MESSAGE CapturedReplyMessage;
+    LARGE_INTEGER CapturedTimeout;
+
     PAGED_CODE();
     LPCTRACE(LPC_REPLY_DEBUG,
              "Handle: %lx. Messages: %p/%p. Context: %p\n",
@@ -173,8 +176,40 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
              ReceiveMessage,
              PortContext);
 
-    /* If this is a system thread, then let it page out its stack */
-    if (Thread->SystemThread) WaitMode = UserMode;
+    if (KeGetPreviousMode() == UserMode)
+    {
+        _SEH2_TRY
+        {
+            if (ReplyMessage != NULL)
+            {
+                ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
+                /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
+                ReplyMessage = &CapturedReplyMessage;*/
+            }
+
+            if (Timeout != NULL)
+            {
+                ProbeForReadLargeInteger(Timeout);
+                RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER));
+                Timeout = &CapturedTimeout;
+            }
+
+            if (PortContext != NULL)
+                ProbeForWritePointer(PortContext);
+        }
+        _SEH2_EXCEPT(ExSystemExceptionFilter())
+        {
+            DPRINT1("SEH crash [1]\n");
+            DbgBreakPoint();
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        /* If this is a system thread, then let it page out its stack */
+        if (Thread->SystemThread) WaitMode = UserMode;
+    }
 
     /* Check if caller has a reply message */
     if (ReplyMessage)
@@ -388,68 +423,78 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     Thread->LpcReceivedMessageId = Message->Request.MessageId;
     Thread->LpcReceivedMsgIdValid = TRUE;
 
-    /* Check if this was a connection request */
-    if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
+    _SEH2_TRY
     {
-        /* Get the connection message */
-        ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
-        LPCTRACE(LPC_REPLY_DEBUG,
-                 "Request Messages: %p/%p\n",
-                 Message,
-                 ConnectMessage);
-
-        /* Get its length */
-        ConnectionInfoLength = Message->Request.u1.s1.DataLength -
-                               sizeof(LPCP_CONNECTION_MESSAGE);
-
-        /* Return it as the receive message */
-        *ReceiveMessage = Message->Request;
-
-        /* Clear our stack variable so the message doesn't get freed */
-        Message = NULL;
-
-        /* Setup the receive message */
-        ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(LPCP_MESSAGE) +
-                                                     ConnectionInfoLength);
-        ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength;
-        RtlCopyMemory(ReceiveMessage + 1,
-                      ConnectMessage + 1,
-                      ConnectionInfoLength);
-
-        /* Clear the port context if the caller requested one */
-        if (PortContext) *PortContext = NULL;
-    }
-    else if (Message->Request.u2.s2.Type != LPC_REPLY)
-    {
-        /* Otherwise, this is a new message or event */
-        LPCTRACE(LPC_REPLY_DEBUG,
-                 "Non-Reply Messages: %p/%p\n",
-                 &Message->Request,
-                 (&Message->Request) + 1);
-
-        /* Copy it */
-        LpcpMoveMessage(ReceiveMessage,
-                        &Message->Request,
-                        (&Message->Request) + 1,
-                        0,
-                        NULL);
+        /* Check if this was a connection request */
+        if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
+        {
+            /* Get the connection message */
+            ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+            LPCTRACE(LPC_REPLY_DEBUG,
+                     "Request Messages: %p/%p\n",
+                     Message,
+                     ConnectMessage);
 
-        /* Return its context */
-        if (PortContext) *PortContext = Message->PortContext;
+            /* Get its length */
+            ConnectionInfoLength = Message->Request.u1.s1.DataLength -
+                                   sizeof(LPCP_CONNECTION_MESSAGE);
 
-        /* And check if it has data information */
-        if (Message->Request.u2.s2.DataInfoOffset)
-        {
-            /* It does, save it, and don't free the message below */
-            LpcpSaveDataInfoMessage(Port, Message, 1);
+            /* Return it as the receive message */
+            *ReceiveMessage = Message->Request;
+
+            /* Clear our stack variable so the message doesn't get freed */
             Message = NULL;
+
+            /* Setup the receive message */
+            ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(LPCP_MESSAGE) +
+                                                         ConnectionInfoLength);
+            ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength;
+            RtlCopyMemory(ReceiveMessage + 1,
+                          ConnectMessage + 1,
+                          ConnectionInfoLength);
+
+            /* Clear the port context if the caller requested one */
+            if (PortContext) *PortContext = NULL;
+        }
+        else if (LpcpGetMessageType(&Message->Request) != LPC_REPLY)
+        {
+            /* Otherwise, this is a new message or event */
+            LPCTRACE(LPC_REPLY_DEBUG,
+                     "Non-Reply Messages: %p/%p\n",
+                     &Message->Request,
+                     (&Message->Request) + 1);
+
+            /* Copy it */
+            LpcpMoveMessage(ReceiveMessage,
+                            &Message->Request,
+                            (&Message->Request) + 1,
+                            0,
+                            NULL);
+
+            /* Return its context */
+            if (PortContext) *PortContext = Message->PortContext;
+
+            /* And check if it has data information */
+            if (Message->Request.u2.s2.DataInfoOffset)
+            {
+                /* It does, save it, and don't free the message below */
+                LpcpSaveDataInfoMessage(Port, Message, 1);
+                Message = NULL;
+            }
+        }
+        else
+        {
+            /* This is a reply message, should never happen! */
+            ASSERT(FALSE);
         }
     }
-    else
+    _SEH2_EXCEPT(ExSystemExceptionFilter())
     {
-        /* This is a reply message, should never happen! */
-        ASSERT(FALSE);
+        DPRINT1("SEH crash [2]\n");
+        DbgBreakPoint();
+        Status = _SEH2_GetExceptionCode();
     }
+    _SEH2_END;
 
     /* Check if we have a message pointer here */
     if (Message)