[ntoskrnl/lpc]
authorAleksey Bragin <aleksey@reactos.org>
Mon, 19 Oct 2009 15:49:29 +0000 (15:49 +0000)
committerAleksey Bragin <aleksey@reactos.org>
Mon, 19 Oct 2009 15:49:29 +0000 (15:49 +0000)
- Implement NtReplyPort based on NtReplyWaitReceivePortEx and LpcReplyPort.
- Implement NtRequestPort based on NtRequestWaitReplyPort and LpcRequestPort.

svn path=/trunk/; revision=43603

reactos/ntoskrnl/lpc/reply.c
reactos/ntoskrnl/lpc/send.c

index 0233ea1..377313e 100644 (file)
@@ -141,10 +141,172 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination,
 NTSTATUS
 NTAPI
 NtReplyPort(IN HANDLE PortHandle,
-            IN PPORT_MESSAGE LpcReply)
+            IN PPORT_MESSAGE ReplyMessage)
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    PLPCP_PORT_OBJECT Port, ConnectionPort = NULL;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    NTSTATUS Status;
+    PLPCP_MESSAGE Message;
+    PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
+    //PORT_MESSAGE CapturedReplyMessage;
+
+    PAGED_CODE();
+    LPCTRACE(LPC_REPLY_DEBUG,
+             "Handle: %lx. Message: %p.\n",
+             PortHandle,
+             ReplyMessage);
+
+    if (KeGetPreviousMode() == UserMode)
+    {
+        _SEH2_TRY
+        {
+            if (ReplyMessage != NULL)
+            {
+                ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
+                /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
+                ReplyMessage = &CapturedReplyMessage;*/
+            }
+        }
+        _SEH2_EXCEPT(ExSystemExceptionFilter())
+        {
+            DPRINT1("SEH crash [1]\n");
+            DbgBreakPoint();
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
+
+    /* Validate its length */
+    if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+        (ULONG)ReplyMessage->u1.s1.TotalLength)
+    {
+        /* Fail */
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Make sure it has a valid ID */
+    if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
+
+    /* Get the Port object */
+    Status = ObReferenceObjectByHandle(PortHandle,
+                                       0,
+                                       LpcPortObjectType,
+                                       PreviousMode,
+                                       (PVOID*)&Port,
+                                       NULL);
+    if (!NT_SUCCESS(Status)) return Status;
+
+    /* Validate its length in respect to the port object */
+    if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)ReplyMessage->u1.s1.TotalLength <=
+        (ULONG)ReplyMessage->u1.s1.DataLength))
+    {
+        /* Too large, fail */
+        ObDereferenceObject(Port);
+        return STATUS_PORT_MESSAGE_TOO_LONG;
+    }
+
+    /* Get the ETHREAD corresponding to it */
+    Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
+                                        NULL,
+                                        &WakeupThread);
+    if (!NT_SUCCESS(Status))
+    {
+        /* No thread found, fail */
+        ObDereferenceObject(Port);
+        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+        return Status;
+    }
+
+    /* Allocate a message from the port zone */
+    Message = LpcpAllocateFromPortZone();
+    if (!Message)
+    {
+        /* Fail if we couldn't allocate a message */
+        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+        ObDereferenceObject(WakeupThread);
+        ObDereferenceObject(Port);
+        return STATUS_NO_MEMORY;
+    }
+
+    /* Keep the lock acquired */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Make sure this is the reply the thread is waiting for */
+    if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
+        ((LpcpGetMessageFromThread(WakeupThread)) &&
+        (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->
+        Request) != LPC_REQUEST)))
+    {
+        /* It isn't, fail */
+        LpcpFreeToPortZone(Message, 3);
+        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+        ObDereferenceObject(WakeupThread);
+        ObDereferenceObject(Port);
+        return STATUS_REPLY_MESSAGE_MISMATCH;
+    }
+
+    /* Copy the message */
+    _SEH2_TRY
+    {
+        LpcpMoveMessage(&Message->Request,
+                        ReplyMessage,
+                        ReplyMessage + 1,
+                        LPC_REPLY,
+                        NULL);
+    }
+    _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+    {
+        /* Fail */
+        LpcpFreeToPortZone(Message, 3);
+        ObDereferenceObject(WakeupThread);
+        ObDereferenceObject(Port);
+        _SEH2_YIELD(return _SEH2_GetExceptionCode());
+    }
+    _SEH2_END;
+
+    /* Reference the thread while we use it */
+    ObReferenceObject(WakeupThread);
+    Message->RepliedToThread = WakeupThread;
+
+    /* Set this as the reply message */
+    WakeupThread->LpcReplyMessageId = 0;
+    WakeupThread->LpcReplyMessage = (PVOID)Message;
+
+    /* Check if we have messages on the reply chain */
+    if (!(WakeupThread->LpcExitThreadCalled) &&
+        !(IsListEmpty(&WakeupThread->LpcReplyChain)))
+    {
+        /* Remove us from it and reinitialize it */
+        RemoveEntryList(&WakeupThread->LpcReplyChain);
+        InitializeListHead(&WakeupThread->LpcReplyChain);
+    }
+
+    /* Check if this is the message the thread had received */
+    if ((Thread->LpcReceivedMsgIdValid) &&
+        (Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
+    {
+        /* Clear this data */
+        Thread->LpcReceivedMessageId = 0;
+        Thread->LpcReceivedMsgIdValid = FALSE;
+    }
+
+    /* Free any data information */
+    LpcpFreeDataInfoMessage(Port,
+                            ReplyMessage->MessageId,
+                            ReplyMessage->CallbackId,
+                            ReplyMessage->ClientId);
+
+    /* Release the lock and release the LPC semaphore to wake up waiters */
+    KeReleaseGuardedMutex(&LpcpLock);
+    LpcpCompleteWait(&WakeupThread->LpcReplySemaphore);
+
+    /* Now we can let go of the thread */
+    ObDereferenceObject(WakeupThread);
+
+    /* Dereference port object */
+    ObDereferenceObject(Port);
+    return Status;
 }
 
 /*
index dad5a35..6980ee4 100644 (file)
@@ -442,10 +442,195 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
 NTSTATUS
 NTAPI
 NtRequestPort(IN HANDLE PortHandle,
-              IN PPORT_MESSAGE LpcMessage)
+              IN PPORT_MESSAGE LpcRequest)
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    NTSTATUS Status;
+    PLPCP_MESSAGE Message;
+    PETHREAD Thread = PsGetCurrentThread();
+
+    PKSEMAPHORE Semaphore;
+    ULONG MessageType;
+    PAGED_CODE();
+    LPCTRACE(LPC_SEND_DEBUG,
+             "Handle: %lx. Message: %p. Type: %lx\n",
+             PortHandle,
+             LpcRequest,
+             LpcpGetMessageType(LpcRequest));
+
+    /* Get the message type */
+    MessageType = LpcRequest->u2.s2.Type | LPC_DATAGRAM;
+
+    /* Can't have data information on this type of call */
+    if (LpcRequest->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
+
+    /* Validate the length */
+    if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+         (ULONG)LpcRequest->u1.s1.TotalLength)
+    {
+        /* Fail */
+        return STATUS_INVALID_PARAMETER;
+    }
+
+    /* Reference the object */
+    Status = ObReferenceObjectByHandle(PortHandle,
+                                       0,
+                                       LpcPortObjectType,
+                                       PreviousMode,
+                                       (PVOID*)&Port,
+                                       NULL);
+    if (!NT_SUCCESS(Status)) return Status;
+
+    /* Validate the message length */
+    if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
+        ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
+    {
+        /* Fail */
+        ObDereferenceObject(Port);
+        return STATUS_PORT_MESSAGE_TOO_LONG;
+    }
+
+    /* Allocate a message from the port zone */
+    Message = LpcpAllocateFromPortZone();
+    if (!Message)
+    {
+        /* Fail if we couldn't allocate a message */
+        ObDereferenceObject(Port);
+        return STATUS_NO_MEMORY;
+    }
+
+    /* No callback, just copy the message */
+    _SEH2_TRY
+    {
+        /* Copy it */
+        LpcpMoveMessage(&Message->Request,
+            LpcRequest,
+            LpcRequest + 1,
+            MessageType,
+            &Thread->Cid);
+    }
+    _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+    {
+        /* Fail */
+        LpcpFreeToPortZone(Message, 0);
+        ObDereferenceObject(Port);
+        _SEH2_YIELD(return _SEH2_GetExceptionCode());
+    }
+    _SEH2_END;
+
+    /* Acquire the LPC lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Right now clear the port context */
+    Message->PortContext = NULL;
+
+    /* Check if this is a not connection port */
+    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
+    {
+        /* We want the connected port */
+        QueuePort = Port->ConnectedPort;
+        if (!QueuePort)
+        {
+            /* We have no connected port, fail */
+            LpcpFreeToPortZone(Message, 3);
+            ObDereferenceObject(Port);
+            return STATUS_PORT_DISCONNECTED;
+        }
+
+        /* Check if this is a communication port */
+        if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
+        {
+            /* Copy the port context and use the connection port */
+            Message->PortContext = QueuePort->PortContext;
+            ConnectionPort = QueuePort = Port->ConnectionPort;
+            if (!ConnectionPort)
+            {
+                /* Fail */
+                LpcpFreeToPortZone(Message, 3);
+                ObDereferenceObject(Port);
+                return STATUS_PORT_DISCONNECTED;
+            }
+        }
+        else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
+            LPCP_COMMUNICATION_PORT)
+        {
+            /* Use the connection port for anything but communication ports */
+            ConnectionPort = QueuePort = Port->ConnectionPort;
+            if (!ConnectionPort)
+            {
+                /* Fail */
+                LpcpFreeToPortZone(Message, 3);
+                ObDereferenceObject(Port);
+                return STATUS_PORT_DISCONNECTED;
+            }
+        }
+
+        /* Reference the connection port if it exists */
+        if (ConnectionPort) ObReferenceObject(ConnectionPort);
+    }
+    else
+    {
+        /* Otherwise, for a connection port, use the same port object */
+        QueuePort = Port;
+    }
+
+    /* Reference QueuePort if we have it */
+    if (QueuePort && ObReferenceObjectSafe(QueuePort))
+    {
+        /* Set sender's port */
+        Message->SenderPort = Port;
+
+        /* Generate the Message ID and set it */
+        Message->Request.MessageId =  LpcpNextMessageId++;
+        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
+        Message->Request.CallbackId = 0;
+
+        /* No Message ID for the thread */
+        PsGetCurrentThread()->LpcReplyMessageId = 0;
+
+        /* Insert the message in our chain */
+        InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
+
+        /* Release the lock and get the semaphore we'll use later */
+        KeEnterCriticalRegion();
+        KeReleaseGuardedMutex(&LpcpLock);
+
+        /* Now release the semaphore */
+        Semaphore = QueuePort->MsgQueue.Semaphore;
+        LpcpCompleteWait(Semaphore);
+
+        /* If this is a waitable port, wake it up */
+        if (QueuePort->Flags & LPCP_WAITABLE_PORT)
+        {
+            /* Wake it */
+            KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
+        }
+
+        KeLeaveCriticalRegion();
+
+        /* Dereference objects */
+        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+        ObDereferenceObject(QueuePort);
+        ObDereferenceObject(Port);
+        LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
+        return STATUS_SUCCESS;
+    }
+
+    Status = STATUS_PORT_DISCONNECTED;
+
+    /* All done with a failure*/
+    LPCTRACE(LPC_SEND_DEBUG,
+             "Port: %p. Status: %p\n",
+             Port,
+             Status);
+
+    /* The wait failed, free the message */
+    if (Message) LpcpFreeToPortZone(Message, 3);
+
+    ObDereferenceObject(Port);
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+    return Status;
 }
 
 /*