- Use _SEH2_YIELD when returning from an exception instead of returning outside the...
[reactos.git] / reactos / ntoskrnl / lpc / reply.c
index 751564b..0233ea1 100644 (file)
@@ -18,7 +18,8 @@ VOID
 NTAPI
 LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
                         IN ULONG MessageId,
-                        IN ULONG CallbackId)
+                        IN ULONG CallbackId,
+                        IN CLIENT_ID ClientId)
 {
     PLPCP_MESSAGE Message;
     PLIST_ENTRY ListHead, NextEntry;
@@ -28,6 +29,7 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
     {
         /* Use it */
         Port = Port->ConnectionPort;
+        if (!Port) return;
     }
 
     /* Loop the list */
@@ -40,12 +42,13 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
 
         /* Make sure it matches */
         if ((Message->Request.MessageId == MessageId) &&
-            (Message->Request.CallbackId == CallbackId))
+            (Message->Request.ClientId.UniqueThread == ClientId.UniqueThread) &&
+            (Message->Request.ClientId.UniqueProcess == ClientId.UniqueProcess))
         {
             /* Unlink and free it */
             RemoveEntryList(&Message->Entry);
             InitializeListHead(&Message->Entry);
-            LpcpFreeToPortZone(Message, TRUE);
+            LpcpFreeToPortZone(Message, 1);
             break;
         }
 
@@ -58,25 +61,31 @@ VOID
 NTAPI
 LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
                         IN PLPCP_MESSAGE Message,
-                        IN ULONG LockFlags)
+                        IN ULONG LockHeld)
 {
     PAGED_CODE();
 
     /* Acquire the lock */
-    KeAcquireGuardedMutex(&LpcpLock);
+    if (!LockHeld) KeAcquireGuardedMutex(&LpcpLock);
 
     /* Check if the port we want is the connection port */
     if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT)
     {
         /* Use it */
         Port = Port->ConnectionPort;
+        if (!Port)
+        {
+            /* Release the lock and return */
+            if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
+            return;
+        }
     }
 
     /* Link the message */
     InsertTailList(&Port->LpcDataInfoChainHead, &Message->Entry);
 
     /* Release the lock */
-    KeReleaseGuardedMutex(&LpcpLock);
+    if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
 }
 
 VOID
@@ -119,7 +128,7 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination,
     Destination->ClientViewSize = Origin->ClientViewSize;
 
     /* Copy the Message Data */
-    RtlMoveMemory(Destination + 1,
+    RtlCopyMemory(Destination + 1,
                   Data,
                   ((Destination->u1.Length & 0xFFFF) + 3) &~3);
 }
@@ -149,13 +158,16 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
                          OUT PPORT_MESSAGE ReceiveMessage,
                          IN PLARGE_INTEGER Timeout OPTIONAL)
 {
-    PLPCP_PORT_OBJECT Port, ReceivePort;
+    PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
     NTSTATUS Status;
     PLPCP_MESSAGE Message;
     PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
     ULONG ConnectionInfoLength;
+    //PORT_MESSAGE CapturedReplyMessage;
+    LARGE_INTEGER CapturedTimeout;
+
     PAGED_CODE();
     LPCTRACE(LPC_REPLY_DEBUG,
              "Handle: %lx. Messages: %p/%p. Context: %p\n",
@@ -164,15 +176,47 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
              ReceiveMessage,
              PortContext);
 
-    /* If this is a system thread, then let it page out its stack */
-    if (Thread->SystemThread) WaitMode = UserMode;
+    if (KeGetPreviousMode() == UserMode)
+    {
+        _SEH2_TRY
+        {
+            if (ReplyMessage != NULL)
+            {
+                ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
+                /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
+                ReplyMessage = &CapturedReplyMessage;*/
+            }
+
+            if (Timeout != NULL)
+            {
+                ProbeForReadLargeInteger(Timeout);
+                RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER));
+                Timeout = &CapturedTimeout;
+            }
+
+            if (PortContext != NULL)
+                ProbeForWritePointer(PortContext);
+        }
+        _SEH2_EXCEPT(ExSystemExceptionFilter())
+        {
+            DPRINT1("SEH crash [1]\n");
+            DbgBreakPoint();
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        /* If this is a system thread, then let it page out its stack */
+        if (Thread->SystemThread) WaitMode = UserMode;
+    }
 
     /* Check if caller has a reply message */
     if (ReplyMessage)
     {
         /* Validate its length */
-        if ((ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
-            ReplyMessage->u1.s1.TotalLength)
+        if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+            (ULONG)ReplyMessage->u1.s1.TotalLength)
         {
             /* Fail */
             return STATUS_INVALID_PARAMETER;
@@ -195,8 +239,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     if (ReplyMessage)
     {
         /* Validate its length in respect to the port object */
-        if ((ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
-            (ReplyMessage->u1.s1.TotalLength <= ReplyMessage->u1.s1.DataLength))
+        if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
+            ((ULONG)ReplyMessage->u1.s1.TotalLength <=
+             (ULONG)ReplyMessage->u1.s1.DataLength))
         {
             /* Too large, fail */
             ObDereferenceObject(Port);
@@ -207,8 +252,32 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     /* Check if this is anything but a client port */
     if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CLIENT_PORT)
     {
-        /* Use the connection port */
-        ReceivePort = Port->ConnectionPort;
+        /* Check if this is the connection port */
+        if (Port->ConnectionPort == Port)
+        {
+            /* Use this port */
+            ConnectionPort = ReceivePort = Port;
+            ObReferenceObject(ConnectionPort);
+        }
+        else
+        {
+            /* Acquire the lock */
+            KeAcquireGuardedMutex(&LpcpLock);
+
+            /* Get the port */
+            ConnectionPort = ReceivePort = Port->ConnectionPort;
+            if (!ConnectionPort)
+            {
+                /* Fail */
+                KeReleaseGuardedMutex(&LpcpLock);
+                ObDereferenceObject(Port);
+                return STATUS_PORT_DISCONNECTED;
+            }
+
+            /* Release lock and reference */
+            ObReferenceObject(ConnectionPort);
+            KeReleaseGuardedMutex(&LpcpLock);
+        }
     }
     else
     {
@@ -227,6 +296,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         {
             /* No thread found, fail */
             ObDereferenceObject(Port);
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
             return Status;
         }
 
@@ -235,6 +305,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         if (!Message)
         {
             /* Fail if we couldn't allocate a message */
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
             ObDereferenceObject(WakeupThread);
             ObDereferenceObject(Port);
             return STATUS_NO_MEMORY;
@@ -244,11 +315,14 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         KeAcquireGuardedMutex(&LpcpLock);
 
         /* Make sure this is the reply the thread is waiting for */
-        if (WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId)
+        if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
+            ((LpcpGetMessageFromThread(WakeupThread)) &&
+             (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->
+                                 Request) != LPC_REQUEST)))
         {
             /* It isn't, fail */
-            LpcpFreeToPortZone(Message, TRUE);
-            KeReleaseGuardedMutex(&LpcpLock);
+            LpcpFreeToPortZone(Message, 3);
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
             ObDereferenceObject(WakeupThread);
             ObDereferenceObject(Port);
             return STATUS_REPLY_MESSAGE_MISMATCH;
@@ -261,11 +335,6 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
                         LPC_REPLY,
                         NULL);
 
-        /* Free any data information */
-        LpcpFreeDataInfoMessage(Port,
-                                ReplyMessage->MessageId,
-                                ReplyMessage->CallbackId);
-
         /* Reference the thread while we use it */
         ObReferenceObject(WakeupThread);
         Message->RepliedToThread = WakeupThread;
@@ -292,6 +361,12 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
             Thread->LpcReceivedMsgIdValid = FALSE;
         }
 
+        /* Free any data information */
+        LpcpFreeDataInfoMessage(Port,
+                                ReplyMessage->MessageId,
+                                ReplyMessage->CallbackId,
+                                ReplyMessage->ClientId);
+
         /* Release the lock and release the LPC semaphore to wake up waiters */
         KeReleaseGuardedMutex(&LpcpLock);
         LpcpCompleteWait(&WakeupThread->LpcReplySemaphore);
@@ -319,6 +394,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
 
         /* Release the lock and fail */
         KeReleaseGuardedMutex(&LpcpLock);
+        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
         ObDereferenceObject(Port);
         return STATUS_UNSUCCESSFUL;
     }
@@ -347,74 +423,90 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     Thread->LpcReceivedMessageId = Message->Request.MessageId;
     Thread->LpcReceivedMsgIdValid = TRUE;
 
-    /* Done touching global data, release the lock */
-    KeReleaseGuardedMutex(&LpcpLock);
-
-    /* Check if this was a connection request */
-    if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
-    {
-        /* Get the connection message */
-        ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
-        LPCTRACE(LPC_REPLY_DEBUG,
-                 "Request Messages: %p/%p\n",
-                 Message,
-                 ConnectMessage);
-
-        /* Get its length */
-        ConnectionInfoLength = Message->Request.u1.s1.DataLength -
-                               sizeof(LPCP_CONNECTION_MESSAGE);
-
-        /* Return it as the receive message */
-        *ReceiveMessage = Message->Request;
-
-        /* Clear our stack variable so the message doesn't get freed */
-        Message = NULL;
-
-        /* Setup the receive message */
-        ReceiveMessage->u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
-                                            ConnectionInfoLength;
-        ReceiveMessage->u1.s1.DataLength = ConnectionInfoLength;
-        RtlMoveMemory(ReceiveMessage + 1,
-                      ConnectMessage + 1,
-                      ConnectionInfoLength);
-
-        /* Clear the port context if the caller requested one */
-        if (PortContext) *PortContext = NULL;
-    }
-    else if (Message->Request.u2.s2.Type != LPC_REPLY)
+    _SEH2_TRY
     {
-        /* Otherwise, this is a new message or event */
-        LPCTRACE(LPC_REPLY_DEBUG,
-                 "Non-Reply Messages: %p/%p\n",
-                 &Message->Request,
-                 (&Message->Request) + 1);
-
-        /* Copy it */
-        LpcpMoveMessage(ReceiveMessage,
-                        &Message->Request,
-                        (&Message->Request) + 1,
-                        0,
-                        NULL);
+        /* Check if this was a connection request */
+        if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
+        {
+            /* Get the connection message */
+            ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+            LPCTRACE(LPC_REPLY_DEBUG,
+                     "Request Messages: %p/%p\n",
+                     Message,
+                     ConnectMessage);
 
-        /* Return its context */
-        if (PortContext) *PortContext = Message->PortContext;
+            /* Get its length */
+            ConnectionInfoLength = Message->Request.u1.s1.DataLength -
+                                   sizeof(LPCP_CONNECTION_MESSAGE);
 
-        /* And check if it has data information */
-        if (Message->Request.u2.s2.DataInfoOffset)
-        {
-            /* It does, save it, and don't free the message below */
-            LpcpSaveDataInfoMessage(Port, Message, 1);
+            /* Return it as the receive message */
+            *ReceiveMessage = Message->Request;
+
+            /* Clear our stack variable so the message doesn't get freed */
             Message = NULL;
+
+            /* Setup the receive message */
+            ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(LPCP_MESSAGE) +
+                                                         ConnectionInfoLength);
+            ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength;
+            RtlCopyMemory(ReceiveMessage + 1,
+                          ConnectMessage + 1,
+                          ConnectionInfoLength);
+
+            /* Clear the port context if the caller requested one */
+            if (PortContext) *PortContext = NULL;
+        }
+        else if (LpcpGetMessageType(&Message->Request) != LPC_REPLY)
+        {
+            /* Otherwise, this is a new message or event */
+            LPCTRACE(LPC_REPLY_DEBUG,
+                     "Non-Reply Messages: %p/%p\n",
+                     &Message->Request,
+                     (&Message->Request) + 1);
+
+            /* Copy it */
+            LpcpMoveMessage(ReceiveMessage,
+                            &Message->Request,
+                            (&Message->Request) + 1,
+                            0,
+                            NULL);
+
+            /* Return its context */
+            if (PortContext) *PortContext = Message->PortContext;
+
+            /* And check if it has data information */
+            if (Message->Request.u2.s2.DataInfoOffset)
+            {
+                /* It does, save it, and don't free the message below */
+                LpcpSaveDataInfoMessage(Port, Message, 1);
+                Message = NULL;
+            }
+        }
+        else
+        {
+            /* This is a reply message, should never happen! */
+            ASSERT(FALSE);
         }
     }
-    else
+    _SEH2_EXCEPT(ExSystemExceptionFilter())
     {
-        /* This is a reply message, should never happen! */
-        ASSERT(FALSE);
+        DPRINT1("SEH crash [2]\n");
+        DbgBreakPoint();
+        Status = _SEH2_GetExceptionCode();
     }
+    _SEH2_END;
 
-    /* If we have a message pointer here, free it */
-    if (Message) LpcpFreeToPortZone(Message, FALSE);
+    /* Check if we have a message pointer here */
+    if (Message)
+    {
+        /* Free it and release the lock */
+        LpcpFreeToPortZone(Message, 3);
+    }
+    else
+    {
+        /* Just release the lock */
+        KeReleaseGuardedMutex(&LpcpLock);
+    }
 
 Cleanup:
     /* All done, dereference the port and return the status */
@@ -422,6 +514,7 @@ Cleanup:
              "Port: %p. Status: %p\n",
              Port,
              Status);
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
     ObDereferenceObject(Port);
     return Status;
 }