- Use _SEH2_YIELD when returning from an exception instead of returning outside the...
[reactos.git] / reactos / ntoskrnl / lpc / send.c
index 1dd0a1d..dad5a35 100644 (file)
@@ -167,16 +167,273 @@ LpcRequestPort(IN PVOID PortObject,
 }
 
 /*
-* @unimplemented
+* @implemented
 */
 NTSTATUS
 NTAPI
-LpcRequestWaitReplyPort(IN PVOID Port,
-                        IN PPORT_MESSAGE LpcMessageRequest,
-                        OUT PPORT_MESSAGE LpcMessageReply)
+LpcRequestWaitReplyPort(IN PVOID PortObject,
+                        IN PPORT_MESSAGE LpcRequest,
+                        OUT PPORT_MESSAGE LpcReply)
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    NTSTATUS Status = STATUS_SUCCESS;
+    PLPCP_MESSAGE Message;
+    PETHREAD Thread = PsGetCurrentThread();
+    BOOLEAN Callback = FALSE;
+    PKSEMAPHORE Semaphore;
+    ULONG MessageType;
+    PAGED_CODE();
+
+    Port = (PLPCP_PORT_OBJECT)PortObject;
+
+    LPCTRACE(LPC_SEND_DEBUG,
+             "Port: %p. Messages: %p/%p. Type: %lx\n",
+             Port,
+             LpcRequest,
+             LpcReply,
+             LpcpGetMessageType(LpcRequest));
+
+    /* Check if the thread is dying */
+    if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;
+
+    /* Check if this is an LPC Request */
+    MessageType = LpcpGetMessageType(LpcRequest);
+    switch (MessageType)
+    {
+        /* No type */
+        case 0:
+            
+            /* Assume LPC request */
+            MessageType = LPC_REQUEST;
+            break;
+        
+        /* LPC request callback */
+        case LPC_REQUEST:
+            
+            /* This is a callback */
+            Callback = TRUE;
+            break;
+        
+        /* Anything else */
+        case LPC_CLIENT_DIED:
+        case LPC_PORT_CLOSED:
+        case LPC_EXCEPTION:
+        case LPC_DEBUG_EVENT:
+        case LPC_ERROR_EVENT:
+            
+            /* Nothing to do */
+            break;
+            
+        default:
+            
+            /* Invalid message type */
+            return STATUS_INVALID_PARAMETER;
+    }
+    
+    /* Set the request type */
+    LpcRequest->u2.s2.Type = MessageType;
+
+    /* Validate the message length */
+    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
+    {
+        /* Fail */
+        return STATUS_PORT_MESSAGE_TOO_LONG;
+    }
+
+    /* Allocate a message from the port zone */
+    Message = LpcpAllocateFromPortZone();
+    if (!Message)
+    {
+        /* Fail if we couldn't allocate a message */
+        return STATUS_NO_MEMORY;
+    }
+
+    /* Check if this is a callback */
+    if (Callback)
+    {
+        /* FIXME: TODO */
+        Semaphore = NULL; // we'd use the Thread Semaphore here
+        ASSERT(FALSE);
+        return STATUS_NOT_IMPLEMENTED;
+    }
+    else
+    {
+        /* No callback, just copy the message */
+        LpcpMoveMessage(&Message->Request,
+                        LpcRequest,
+                        LpcRequest + 1,
+                        0,
+                        &Thread->Cid);
+
+        /* Acquire the LPC lock */
+        KeAcquireGuardedMutex(&LpcpLock);
+
+        /* Right now clear the port context */
+        Message->PortContext = NULL;
+
+        /* Check if this is a not connection port */
+        if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
+        {
+            /* We want the connected port */
+            QueuePort = Port->ConnectedPort;
+            if (!QueuePort)
+            {
+                /* We have no connected port, fail */
+                LpcpFreeToPortZone(Message, 3);
+                return STATUS_PORT_DISCONNECTED;
+            }
+
+            /* This will be the rundown port */
+            ReplyPort = QueuePort;
+
+            /* Check if this is a communication port */
+            if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
+            {
+                /* Copy the port context and use the connection port */
+                Message->PortContext = QueuePort->PortContext;
+                ConnectionPort = QueuePort = Port->ConnectionPort;
+                if (!ConnectionPort)
+                {
+                    /* Fail */
+                    LpcpFreeToPortZone(Message, 3);
+                    return STATUS_PORT_DISCONNECTED;
+                }
+            }
+            else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
+                      LPCP_COMMUNICATION_PORT)
+            {
+                /* Use the connection port for anything but communication ports */
+                ConnectionPort = QueuePort = Port->ConnectionPort;
+                if (!ConnectionPort)
+                {
+                    /* Fail */
+                    LpcpFreeToPortZone(Message, 3);
+                    return STATUS_PORT_DISCONNECTED;
+                }
+            }
+
+            /* Reference the connection port if it exists */
+            if (ConnectionPort) ObReferenceObject(ConnectionPort);
+        }
+        else
+        {
+            /* Otherwise, for a connection port, use the same port object */
+            QueuePort = ReplyPort = Port;
+        }
+
+        /* No reply thread */
+        Message->RepliedToThread = NULL;
+        Message->SenderPort = Port;
+
+        /* Generate the Message ID and set it */
+        Message->Request.MessageId =  LpcpNextMessageId++;
+        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
+        Message->Request.CallbackId = 0;
+
+        /* Set the message ID for our thread now */
+        Thread->LpcReplyMessageId = Message->Request.MessageId;
+        Thread->LpcReplyMessage = NULL;
+
+        /* Insert the message in our chain */
+        InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
+        InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
+        LpcpSetPortToThread(Thread, Port);
+
+        /* Release the lock and get the semaphore we'll use later */
+        KeEnterCriticalRegion();
+        KeReleaseGuardedMutex(&LpcpLock);
+        Semaphore = QueuePort->MsgQueue.Semaphore;
+
+        /* If this is a waitable port, wake it up */
+        if (QueuePort->Flags & LPCP_WAITABLE_PORT)
+        {
+            /* Wake it */
+            KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
+        }
+    }
+
+    /* Now release the semaphore */
+    LpcpCompleteWait(Semaphore);
+    KeLeaveCriticalRegion();
+
+    /* And let's wait for the reply */
+    LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);
+
+    /* Acquire the LPC lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Get the LPC Message and clear our thread's reply data */
+    Message = LpcpGetMessageFromThread(Thread);
+    Thread->LpcReplyMessage = NULL;
+    Thread->LpcReplyMessageId = 0;
+
+    /* Check if we have anything on the reply chain*/
+    if (!IsListEmpty(&Thread->LpcReplyChain))
+    {
+        /* Remove this thread and reinitialize the list */
+        RemoveEntryList(&Thread->LpcReplyChain);
+        InitializeListHead(&Thread->LpcReplyChain);
+    }
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+
+    /* Check if we got a reply */
+    if (Status == STATUS_SUCCESS)
+    {
+        /* Check if we have a valid message */
+        if (Message)
+        {
+            LPCTRACE(LPC_SEND_DEBUG,
+                     "Reply Messages: %p/%p\n",
+                     &Message->Request,
+                     (&Message->Request) + 1);
+
+            /* Move the message */
+            LpcpMoveMessage(LpcReply,
+                            &Message->Request,
+                            (&Message->Request) + 1,
+                            0,
+                            NULL);
+            
+            /* Acquire the lock */
+            KeAcquireGuardedMutex(&LpcpLock);
+            
+            /* Check if we replied to a thread */
+            if (Message->RepliedToThread)
+            {
+                /* Dereference */
+                ObDereferenceObject(Message->RepliedToThread);
+                Message->RepliedToThread = NULL;
+            }
+
+
+            /* Free the message */
+            LpcpFreeToPortZone(Message, 3);
+        }
+        else
+        {
+            /* We don't have a reply */
+            Status = STATUS_LPC_REPLY_LOST;
+        }
+    }
+    else
+    {
+        /* The wait failed, free the message */
+        if (Message) LpcpFreeToPortZone(Message, 0);
+    }
+
+    /* All done */
+    LPCTRACE(LPC_SEND_DEBUG,
+             "Port: %p. Status: %p\n",
+             Port,
+             Status);
+
+    /* Dereference the connection port */
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+    return Status;
 }
 
 /*
@@ -285,26 +542,23 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     else
     {
         /* No callback, just copy the message */
-        _SEH_TRY
+        _SEH2_TRY
         {
+            /* Copy it */
             LpcpMoveMessage(&Message->Request,
                             LpcRequest,
                             LpcRequest + 1,
                             MessageType,
                             &Thread->Cid);
         }
-        _SEH_HANDLE
-        {
-            Status = _SEH_GetExceptionCode();
-        }
-        _SEH_END;
-
-        if (!NT_SUCCESS(Status))
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
         {
+            /* Fail */
             LpcpFreeToPortZone(Message, 0);
             ObDereferenceObject(Port);
-            return Status;
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
         }
+        _SEH2_END;
 
         /* Acquire the LPC lock */
         KeAcquireGuardedMutex(&LpcpLock);
@@ -434,7 +688,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
                      (&Message->Request) + 1);
 
             /* Move the message */
-            _SEH_TRY
+            _SEH2_TRY
             {
                 LpcpMoveMessage(LpcReply,
                                 &Message->Request,
@@ -442,11 +696,11 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
                                 0,
                                 NULL);
             }
-            _SEH_HANDLE
+            _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
             {
-                Status = _SEH_GetExceptionCode();
+                Status = _SEH2_GetExceptionCode();
             }
-            _SEH_END;
+            _SEH2_END;
 
             /* Check if this is an LPC request with data information */
             if ((LpcpGetMessageType(&Message->Request) == LPC_REQUEST) &&