- Fix multiple LPC race conditions.
authorAlex Ionescu <aionescu@gmail.com>
Sun, 21 Jan 2007 17:21:42 +0000 (17:21 +0000)
committerAlex Ionescu <aionescu@gmail.com>
Sun, 21 Jan 2007 17:21:42 +0000 (17:21 +0000)
- Improve LpcpFreeToPortZone calls for optimizing lock release.
- Use RtlCopyMemory instead of RtlMoveMemory to optimize data transfer speed.
- Always hold a reference to the connection port associated to the LPC port and properly handle this reference in all the LPC code.
- Hold a reference to the process that mapped a server/client view, and use this field when freeing memory in case we're called out-of-process.
- Fix a lot of list parsing loops and code to handle the case when the list is now empty.
- Validate more fields and data in the code.
- There are still some LPC bugs at system shutdown.

svn path=/trunk/; revision=25557

reactos/ntoskrnl/include/internal/lpc_x.h
reactos/ntoskrnl/ke/gate.c
reactos/ntoskrnl/lpc/close.c
reactos/ntoskrnl/lpc/complete.c
reactos/ntoskrnl/lpc/connect.c
reactos/ntoskrnl/lpc/create.c
reactos/ntoskrnl/lpc/listen.c
reactos/ntoskrnl/lpc/port.c
reactos/ntoskrnl/lpc/reply.c
reactos/ntoskrnl/lpc/send.c

index 0ba3b0a..a1cc008 100644 (file)
@@ -45,7 +45,7 @@
         {                                                   \\r
             /* It's still signaled, so wait on it */        \\r
             KeWaitForSingleObject(s,                        \\r
-                                  Executive,                \\r
+                                  WrExecutive,              \\r
                                   KernelMode,               \\r
                                   FALSE,                    \\r
                                   NULL);                    \\r
@@ -73,7 +73,7 @@
         {                                                   \\r
             /* It's still signaled, so wait on it */        \\r
             KeWaitForSingleObject(s,                        \\r
-                                  Executive,                \\r
+                                  WrExecutive,              \\r
                                   KernelMode,               \\r
                                   FALSE,                    \\r
                                   NULL);                    \\r
index f6ace42..3f82d63 100644 (file)
@@ -137,7 +137,6 @@ KeSignalGateBoostPriority(IN PKGATE Gate)
     KIRQL OldIrql;
     ASSERT_GATE(Gate);
     ASSERT_IRQL_LESS_OR_EQUAL(DISPATCH_LEVEL);
-    ASSERT(FALSE);
 
     /* Start entry loop */
     for (;;)
index bf42824..62a0ff7 100644 (file)
@@ -19,6 +19,7 @@ NTAPI
 LpcExitThread(IN PETHREAD Thread)
 {
     PLPCP_MESSAGE Message;
+    ASSERT(Thread == PsGetCurrentThread());
 
     /* Acquire the lock */
     KeAcquireGuardedMutex(&LpcpLock);
@@ -54,7 +55,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message,
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
     PLPCP_PORT_OBJECT ClientPort = NULL;
     PETHREAD Thread = NULL;
-    BOOLEAN LockHeld = Flags & 1;
+    BOOLEAN LockHeld = Flags & 1, ReleaseLock = Flags & 2;
     PAGED_CODE();
     LPCTRACE(LPC_CLOSE_DEBUG, "Message: %p. Flags: %lx\n", Message, Flags);
 
@@ -99,7 +100,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message,
     ExFreeToPagedLookasideList(&LpcpMessagesLookaside, Message);
 
     /* Reacquire the lock if needed */
-    if ((LockHeld) && !(Flags & 2)) KeAcquireGuardedMutex(&LpcpLock);
+    if ((LockHeld) && !(ReleaseLock)) KeAcquireGuardedMutex(&LpcpLock);
 }
 
 VOID
@@ -110,14 +111,27 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
     PLIST_ENTRY ListHead, NextEntry;
     PETHREAD Thread;
     PLPCP_MESSAGE Message;
+    PLPCP_PORT_OBJECT ConnectionPort = NULL;
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
+    PAGED_CODE();
     LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags);
 
     /* Hold the lock */
     KeAcquireGuardedMutex(&LpcpLock);
 
-    /* Disconnect the port to which this port is connected */
-    if (Port->ConnectedPort) Port->ConnectedPort->ConnectedPort = NULL;
+    /* Check if we have a connected port */
+    if (((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_UNCONNECTED_PORT) &&
+        (Port->ConnectedPort))
+    {
+        /* Disconnect it */
+        Port->ConnectedPort->ConnectedPort = NULL;
+        if (Port->ConnectedPort->ConnectionPort)
+        {
+            /* Save and clear connection port */
+            ConnectionPort = Port->ConnectedPort->ConnectionPort;
+            Port->ConnectedPort->ConnectionPort = NULL;
+        }
+    }
 
     /* Check if this is a connection port */
     if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT)
@@ -129,7 +143,7 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
     /* Walk all the threads waiting and signal them */
     ListHead = &Port->LpcReplyChainHead;
     NextEntry = ListHead->Flink;
-    while (NextEntry != ListHead)
+    while ((NextEntry) && (NextEntry != ListHead))
     {
         /* Get the Thread */
         Thread = CONTAINING_RECORD(NextEntry, ETHREAD, LpcReplyChain);
@@ -147,58 +161,64 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
         /* Check if someone is waiting */
         if (!KeReadStateSemaphore(&Thread->LpcReplySemaphore))
         {
-            /* Get the message and check if it's a connection request */
+            /* Get the message */
             Message = Thread->LpcReplyMessage;
-            if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
+            if (Message)
             {
-                /* Get the connection message */
-                ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
-
-                /* Check if it had a section */
-                if (ConnectMessage->SectionToMap)
+                /* Check if it's a connection request */
+                if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
                 {
-                    /* Dereference it */
-                    ObDereferenceObject(ConnectMessage->SectionToMap);
+                    /* Get the connection message */
+                    ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+
+                    /* Check if it had a section */
+                    if (ConnectMessage->SectionToMap)
+                    {
+                        /* Dereference it */
+                        ObDereferenceObject(ConnectMessage->SectionToMap);
+                    }
                 }
-            }
 
-            /* Clear the reply message */
-            Thread->LpcReplyMessage = NULL;
+                /* Clear the reply message */
+                Thread->LpcReplyMessage = NULL;
 
-            /* And remove the message from the port zone */
-            LpcpFreeToPortZone(Message, TRUE);
-        }
+                /* And remove the message from the port zone */
+                LpcpFreeToPortZone(Message, 1);
+                NextEntry = Port->LpcReplyChainHead.Flink;
+            }
 
-        /* Release the semaphore and reset message id count */
-        Thread->LpcReplyMessageId = 0;
-        LpcpCompleteWait(&Thread->LpcReplySemaphore);
+            /* Release the semaphore and reset message id count */
+            Thread->LpcReplyMessageId = 0;
+            KeReleaseSemaphore(&Thread->LpcReplySemaphore, 0, 1, FALSE);
+        }
     }
 
     /* Reinitialize the list head */
     InitializeListHead(&Port->LpcReplyChainHead);
 
     /* Loop queued messages */
-    ListHead = &Port->MsgQueue.ReceiveHead;
-    NextEntry =  ListHead->Flink;
-    while ((NextEntry) && (ListHead != NextEntry))
+    while ((Port->MsgQueue.ReceiveHead.Flink) &&
+           !(IsListEmpty (&Port->MsgQueue.ReceiveHead)))
     {
         /* Get the message */
-        Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
-        NextEntry = NextEntry->Flink;
+        Message = CONTAINING_RECORD(Port->MsgQueue.ReceiveHead.Flink,
+                                    LPCP_MESSAGE,
+                                    Entry);
 
         /* Free and reinitialize it's list head */
+        RemoveEntryList(&Message->Entry);
         InitializeListHead(&Message->Entry);
 
         /* Remove it from the port zone */
-        LpcpFreeToPortZone(Message, TRUE);
+        LpcpFreeToPortZone(Message, 1);
     }
 
-    /* Reinitialize the message queue list head */
-    InitializeListHead(&Port->MsgQueue.ReceiveHead);
-
     /* Release the lock */
     KeReleaseGuardedMutex(&LpcpLock);
 
+    /* Dereference the connection port */
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+
     /* Check if we have to free the port entirely */
     if (Destroy)
     {
@@ -334,13 +354,32 @@ LpcpDeletePort(IN PVOID ObjectBody)
     /* Destroy the port queue */
     LpcpDestroyPortQueue(Port, TRUE);
 
-    /* Check if we had a client view */
-    if (Port->ClientSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(),
-                                                      Port->ClientSectionBase);
+    /* Check if we had views */
+    if ((Port->ClientSectionBase) || (Port->ServerSectionBase))
+    {
+        /* Check if we had a client view */
+        if (Port->ClientSectionBase)
+        {
+            /* Unmap it */
+            MmUnmapViewOfSection(Port->MappingProcess,
+                                 Port->ClientSectionBase);
+        }
 
-    /* Check for a server view */
-    if (Port->ServerSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(),
-                                                      Port->ServerSectionBase);
+        /* Check for a server view */
+        if (Port->ServerSectionBase)
+        {
+            /* Unmap it */
+            MmUnmapViewOfSection(Port->MappingProcess,
+                                 Port->ServerSectionBase);
+        }
+
+        /* Dereference the mapping process */
+        ObDereferenceObject(Port->MappingProcess);
+        Port->MappingProcess = NULL;
+    }
+
+    /* Acquire the lock */
+    KeAcquireGuardedMutex(&LpcpLock);
 
     /* Get the connection port */
     ConnectionPort = Port->ConnectionPort;
@@ -349,9 +388,6 @@ LpcpDeletePort(IN PVOID ObjectBody)
         /* Get the PID */
         Pid = PsGetCurrentProcessId();
 
-        /* Acquire the lock */
-        KeAcquireGuardedMutex(&LpcpLock);
-
         /* Loop the data lists */
         ListHead = &ConnectionPort->LpcDataInfoChainHead;
         NextEntry = ListHead->Flink;
@@ -361,12 +397,29 @@ LpcpDeletePort(IN PVOID ObjectBody)
             Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
             NextEntry = NextEntry->Flink;
 
-            /* Check if the PID matches */
-            if (Message->Request.ClientId.UniqueProcess == Pid)
+            /* Check if this is the connection port */
+            if (Port == ConnectionPort)
+            {
+                /* Free queued messages */
+                RemoveEntryList(&Message->Entry);
+                InitializeListHead(&Message->Entry);
+                LpcpFreeToPortZone(Message, 1);
+
+                /* Restart at the head */
+                NextEntry = ListHead->Flink;
+            }
+            else if ((Message->Request.ClientId.UniqueProcess == Pid) &&
+                     ((Message->SenderPort == Port) ||
+                      (Message->SenderPort == Port->ConnectedPort) ||
+                      (Message->SenderPort == ConnectionPort)))
             {
                 /* Remove it */
                 RemoveEntryList(&Message->Entry);
-                LpcpFreeToPortZone(Message, TRUE);
+                InitializeListHead(&Message->Entry);
+                LpcpFreeToPortZone(Message, 1);
+
+                /* Restart at the head */
+                NextEntry = ListHead->Flink;
             }
         }
 
@@ -376,6 +429,11 @@ LpcpDeletePort(IN PVOID ObjectBody)
         /* Dereference the object unless it's the same port */
         if (ConnectionPort != Port) ObDereferenceObject(ConnectionPort);
     }
+    else
+    {
+        /* Release the lock */
+        KeReleaseGuardedMutex(&LpcpLock);
+    }
 
     /* Check if this is a connection port with a server process*/
     if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) &&
index eeda2e4..33c6886 100644 (file)
@@ -145,7 +145,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     Message->Request.ClientId = ReplyMessage->ClientId;
     Message->Request.MessageId = ReplyMessage->MessageId;
     Message->Request.ClientViewSize = 0;
-    RtlMoveMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength);
+    RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength);
 
     /* At this point, if the caller refused the connection, go to cleanup */
     if (!AcceptConnection) goto Cleanup;
@@ -213,6 +213,10 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
             /* Set the view base */
             ConnectMessage->ClientView.ViewRemoteBase = ServerPort->
                                                         ClientSectionBase;
+
+            /* Save and reference the mapping process */
+            ServerPort->MappingProcess = PsGetCurrentProcess();
+            ObReferenceObject(ServerPort->MappingProcess);
         }
         else
         {
@@ -351,10 +355,10 @@ NtCompleteConnectPort(IN HANDLE PortHandle)
     /* Make sure it has a reply message */
     if (!Thread->LpcReplyMessage)
     {
-        /* It doesn't, fail */
+        /* It doesn't, quit */
         KeReleaseGuardedMutex(&LpcpLock);
         ObDereferenceObject(Port);
-        return STATUS_PORT_DISCONNECTED;
+        return STATUS_SUCCESS;
     }
 
     /* Clear the client thread and wake it up */
index 73a57ff..da0dcde 100644 (file)
@@ -21,6 +21,7 @@ LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
                IN PETHREAD CurrentThread)
 {
     PVOID SectionToMap;
+    PLPCP_MESSAGE ReplyMessage;
 
     /* Acquire the LPC lock */
     KeAcquireGuardedMutex(&LpcpLock);
@@ -34,10 +35,19 @@ LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
     }
 
     /* Check if there's a reply message */
-    if (CurrentThread->LpcReplyMessage)
+    ReplyMessage = CurrentThread->LpcReplyMessage;
+    if (ReplyMessage)
     {
         /* Get the message */
-        *Message = CurrentThread->LpcReplyMessage;
+        *Message = ReplyMessage;
+
+        /* Check if it's got messages */
+        if (!IsListEmpty(&ReplyMessage->Entry))
+        {
+            /* Clear the list */
+            RemoveEntryList(&ReplyMessage->Entry);
+            InitializeListHead(&ReplyMessage->Entry);
+        }
 
         /* Clear message data */
         CurrentThread->LpcReceivedMessageId = 0;
@@ -124,7 +134,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
     Status = ObReferenceObjectByName(PortName,
                                      0,
                                      NULL,
-                                     PORT_ALL_ACCESS,
+                                     PORT_CONNECT,
                                      LpcPortObjectType,
                                      PreviousMode,
                                      NULL,
@@ -286,6 +296,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
 
         /* Update the base */
         ClientView->ViewBase = Port->ClientSectionBase;
+
+        /* Reference and remember the process */
+        ClientPort->MappingProcess = PsGetCurrentProcess();
+        ObReferenceObject(ClientPort->MappingProcess);
     }
     else
     {
@@ -321,7 +335,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
         Message->Request.ClientViewSize = ClientView->ViewSize;
 
         /* Copy the client view and clear the server view */
-        RtlMoveMemory(&ConnectMessage->ClientView,
+        RtlCopyMemory(&ConnectMessage->ClientView,
                       ClientView,
                       sizeof(PORT_VIEW));
         RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
@@ -348,7 +362,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
     if (ConnectionInformation)
     {
         /* Copy it in */
-        RtlMoveMemory(ConnectMessage + 1,
+        RtlCopyMemory(ConnectMessage + 1,
                       ConnectionInformation,
                       ConnectionInfoLength);
     }
@@ -360,51 +374,63 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
     if (Port->Flags & LPCP_NAME_DELETED)
     {
         /* Fail the request */
-        KeReleaseGuardedMutex(&LpcpLock);
         Status = STATUS_OBJECT_NAME_NOT_FOUND;
-        goto Cleanup;
     }
+    else
+    {
+        /* Associate no thread yet */
+        Message->RepliedToThread = NULL;
 
-    /* Associate no thread yet */
-    Message->RepliedToThread = NULL;
+        /* Generate the Message ID and set it */
+        Message->Request.MessageId =  LpcpNextMessageId++;
+        if (!LpcpNextMessageId) LpcpNextMessageId = 1;
+        Thread->LpcReplyMessageId = Message->Request.MessageId;
 
-    /* Generate the Message ID and set it */
-    Message->Request.MessageId =  LpcpNextMessageId++;
-    if (!LpcpNextMessageId) LpcpNextMessageId = 1;
-    Thread->LpcReplyMessageId = Message->Request.MessageId;
+        /* Insert the message into the queue and thread chain */
+        InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
+        InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
+        Thread->LpcReplyMessage = Message;
 
-    /* Insert the message into the queue and thread chain */
-    InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
-    InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
-    Thread->LpcReplyMessage = Message;
+        /* Now we can finally reference the client port and link it*/
+        ObReferenceObject(ClientPort);
+        ConnectMessage->ClientPort = ClientPort;
 
-    /* Now we can finally reference the client port and link it*/
-    ObReferenceObject(ClientPort);
-    ConnectMessage->ClientPort = ClientPort;
+        /* Enter a critical region */
+        KeEnterCriticalRegion();
+    }
+
+    /* Add another reference to the port */
+    ObReferenceObject(Port);
 
     /* Release the lock */
     KeReleaseGuardedMutex(&LpcpLock);
-    LPCTRACE(LPC_CONNECT_DEBUG,
-             "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
-             Message,
-             ConnectMessage,
-             Port,
-             ClientPort,
-             Status);
-
-    /* If this is a waitable port, set the event */
-    if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
-                                                     1,
-                                                     FALSE);
 
-    /* Release the queue semaphore */
-    LpcpCompleteWait(Port->MsgQueue.Semaphore);
-
-    /* Now wait for a reply */
-    LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
+    /* Check for success */
+    if (NT_SUCCESS(Status))
+    {
+        LPCTRACE(LPC_CONNECT_DEBUG,
+                 "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
+                 Message,
+                 ConnectMessage,
+                 Port,
+                 ClientPort,
+                 Status);
+
+        /* If this is a waitable port, set the event */
+        if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
+                                                         1,
+                                                         FALSE);
+
+        /* Release the queue semaphore and leave the critical region */
+        LpcpCompleteWait(Port->MsgQueue.Semaphore);
+        KeLeaveCriticalRegion();
+
+        /* Now wait for a reply */
+        LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
+    }
 
-    /* Check if our wait ended in success */
-    if (Status != STATUS_SUCCESS) goto Cleanup;
+    /* Check for failure */
+    if (!NT_SUCCESS(Status)) goto Cleanup;
 
     /* Free the connection message */
     SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
@@ -432,7 +458,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
             }
 
             /* Return the connection information */
-            RtlMoveMemory(ConnectionInformation,
+            RtlCopyMemory(ConnectionInformation,
                           ConnectMessage + 1,
                           ConnectionInfoLength );
         }
@@ -466,7 +492,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
                 if (ClientView)
                 {
                     /* Copy it back */
-                    RtlMoveMemory(ClientView,
+                    RtlCopyMemory(ClientView,
                                   &ConnectMessage->ClientView,
                                   sizeof(PORT_VIEW));
                 }
@@ -475,7 +501,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
                 if (ServerView)
                 {
                     /* Copy it back */
-                    RtlMoveMemory(ServerView,
+                    RtlCopyMemory(ServerView,
                                   &ConnectMessage->ServerView,
                                   sizeof(REMOTE_PORT_VIEW));
                 }
@@ -486,8 +512,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
             /* No connection port, we failed */
             if (SectionToMap) ObDereferenceObject(SectionToMap);
 
+            /* Acquire the lock */
+            KeAcquireGuardedMutex(&LpcpLock);
+
             /* Check if it's because the name got deleted */
-            if (Port->Flags & LPCP_NAME_DELETED)
+            if (!(ClientPort->ConnectionPort) ||
+                (Port->Flags & LPCP_NAME_DELETED))
             {
                 /* Set the correct status */
                 Status = STATUS_OBJECT_NAME_NOT_FOUND;
@@ -498,19 +528,27 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
                 Status = STATUS_PORT_CONNECTION_REFUSED;
             }
 
+            /* Release the lock */
+            KeReleaseGuardedMutex(&LpcpLock);
+
             /* Kill the port */
             ObDereferenceObject(ClientPort);
         }
 
         /* Free the message */
-        LpcpFreeToPortZone(Message, FALSE);
-        return Status;
+        LpcpFreeToPortZone(Message, 0);
+    }
+    else
+    {
+        /* No reply message, fail */
+        if (SectionToMap) ObDereferenceObject(SectionToMap);
+        ObDereferenceObject(ClientPort);
+        Status = STATUS_PORT_CONNECTION_REFUSED;
     }
 
-    /* No reply message, fail */
-    if (SectionToMap) ObDereferenceObject(SectionToMap);
-    ObDereferenceObject(ClientPort);
-    return STATUS_PORT_CONNECTION_REFUSED;
+    /* Return status */
+    ObDereferenceObject(Port);
+    return Status;
 
 Cleanup:
     /* We failed, free the message */
@@ -521,20 +559,21 @@ Cleanup:
     {
         /* Wait on it */
         KeWaitForSingleObject(&Thread->LpcReplySemaphore,
+                              WrExecutive,
                               KernelMode,
-                              Executive,
                               FALSE,
                               NULL);
     }
 
     /* Check if we had a message and free it */
-    if (Message) LpcpFreeToPortZone(Message, FALSE);
+    if (Message) LpcpFreeToPortZone(Message, 0);
 
     /* Dereference other objects */
     if (SectionToMap) ObDereferenceObject(SectionToMap);
     ObDereferenceObject(ClientPort);
 
     /* Return status */
+    ObDereferenceObject(Port);
     return Status;
 }
 
index a0478d5..f98df57 100644 (file)
@@ -19,6 +19,7 @@ NTAPI
 LpcpInitializePortQueue(IN PLPCP_PORT_OBJECT Port)
 {
     PLPCP_NONPAGED_PORT_QUEUE MessageQueue;
+    PAGED_CODE();
 
     /* Allocate the queue */
     MessageQueue = ExAllocatePoolWithTag(NonPagedPool,
@@ -48,6 +49,7 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
     PLPCP_PORT_OBJECT Port;
+    PAGED_CODE();
     LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName);
 
     /* Create the Object */
index 8bc84c3..1d4e94c 100644 (file)
@@ -22,30 +22,30 @@ NTAPI
 NtListenPort(IN HANDLE PortHandle,
              OUT PPORT_MESSAGE ConnectMessage)
 {
-        NTSTATUS Status;
-        PAGED_CODE();
-        LPCTRACE(LPC_LISTEN_DEBUG, "Handle: %lx\n", PortHandle);
-
-        /* Wait forever for a connection request. */
-        for (;;)
+    NTSTATUS Status;
+    PAGED_CODE();
+    LPCTRACE(LPC_LISTEN_DEBUG, "Handle: %lx\n", PortHandle);
+
+    /* Wait forever for a connection request. */
+    for (;;)
+    {
+        /* Do the wait */
+        Status = NtReplyWaitReceivePort(PortHandle,
+                                        NULL,
+                                        NULL,
+                                        ConnectMessage);
+
+        /* Accept only LPC_CONNECTION_REQUEST requests. */
+        if ((Status != STATUS_SUCCESS) ||
+            (LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
         {
-            /* Do the wait */
-            Status = NtReplyWaitReceivePort(PortHandle,
-                                            NULL,
-                                            NULL,
-                                            ConnectMessage);
-
-            /* Accept only LPC_CONNECTION_REQUEST requests. */
-            if ((Status != STATUS_SUCCESS) ||
-                (LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
-            {
-                /* Break out */
-                break;
-            }
+            /* Break out */
+            break;
         }
+    }
 
-        /* Return status */
-        return Status;
+    /* Return status */
+    return Status;
 }
 
 
index b085e91..92bcb4d 100644 (file)
@@ -18,7 +18,7 @@ POBJECT_TYPE LpcPortObjectType;
 ULONG LpcpMaxMessageSize;
 PAGED_LOOKASIDE_LIST LpcpMessagesLookaside;
 KGUARDED_MUTEX LpcpLock;
-ULONG LpcpTraceLevel = 0;
+ULONG LpcpTraceLevel = LPC_CLOSE_DEBUG;
 ULONG LpcpNextMessageId = 1, LpcpNextCallbackId = 1;
 
 static GENERIC_MAPPING LpcpPortMapping = 
@@ -54,7 +54,6 @@ LpcpInitSystem(VOID)
     ObjectTypeInitializer.CloseProcedure = LpcpClosePort;
     ObjectTypeInitializer.DeleteProcedure = LpcpDeletePort;
     ObjectTypeInitializer.ValidAccessMask = PORT_ALL_ACCESS;
-    ObjectTypeInitializer.MaintainTypeList = TRUE;
     ObCreateObjectType(&Name,
                        &ObjectTypeInitializer,
                        NULL,
index 751564b..ac421fb 100644 (file)
@@ -18,7 +18,8 @@ VOID
 NTAPI
 LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
                         IN ULONG MessageId,
-                        IN ULONG CallbackId)
+                        IN ULONG CallbackId,
+                        IN CLIENT_ID ClientId)
 {
     PLPCP_MESSAGE Message;
     PLIST_ENTRY ListHead, NextEntry;
@@ -28,6 +29,7 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
     {
         /* Use it */
         Port = Port->ConnectionPort;
+        if (!Port) return;
     }
 
     /* Loop the list */
@@ -40,12 +42,13 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
 
         /* Make sure it matches */
         if ((Message->Request.MessageId == MessageId) &&
-            (Message->Request.CallbackId == CallbackId))
+            (Message->Request.ClientId.UniqueThread == ClientId.UniqueThread) &&
+            (Message->Request.ClientId.UniqueProcess == ClientId.UniqueProcess))
         {
             /* Unlink and free it */
             RemoveEntryList(&Message->Entry);
             InitializeListHead(&Message->Entry);
-            LpcpFreeToPortZone(Message, TRUE);
+            LpcpFreeToPortZone(Message, 1);
             break;
         }
 
@@ -58,25 +61,31 @@ VOID
 NTAPI
 LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
                         IN PLPCP_MESSAGE Message,
-                        IN ULONG LockFlags)
+                        IN ULONG LockHeld)
 {
     PAGED_CODE();
 
     /* Acquire the lock */
-    KeAcquireGuardedMutex(&LpcpLock);
+    if (!LockHeld) KeAcquireGuardedMutex(&LpcpLock);
 
     /* Check if the port we want is the connection port */
     if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT)
     {
         /* Use it */
         Port = Port->ConnectionPort;
+        if (!Port)
+        {
+            /* Release the lock and return */
+            if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
+            return;
+        }
     }
 
     /* Link the message */
     InsertTailList(&Port->LpcDataInfoChainHead, &Message->Entry);
 
     /* Release the lock */
-    KeReleaseGuardedMutex(&LpcpLock);
+    if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
 }
 
 VOID
@@ -119,7 +128,7 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination,
     Destination->ClientViewSize = Origin->ClientViewSize;
 
     /* Copy the Message Data */
-    RtlMoveMemory(Destination + 1,
+    RtlCopyMemory(Destination + 1,
                   Data,
                   ((Destination->u1.Length & 0xFFFF) + 3) &~3);
 }
@@ -149,7 +158,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
                          OUT PPORT_MESSAGE ReceiveMessage,
                          IN PLARGE_INTEGER Timeout OPTIONAL)
 {
-    PLPCP_PORT_OBJECT Port, ReceivePort;
+    PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
     NTSTATUS Status;
     PLPCP_MESSAGE Message;
@@ -207,8 +216,32 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     /* Check if this is anything but a client port */
     if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CLIENT_PORT)
     {
-        /* Use the connection port */
-        ReceivePort = Port->ConnectionPort;
+        /* Check if this is the connection port */
+        if (Port->ConnectionPort == Port)
+        {
+            /* Use this port */
+            ConnectionPort = ReceivePort = Port;
+            ObReferenceObject(ConnectionPort);
+        }
+        else
+        {
+            /* Acquire the lock */
+            KeAcquireGuardedMutex(&LpcpLock);
+
+            /* Get the port */
+            ConnectionPort = ReceivePort = Port->ConnectionPort;
+            if (!ConnectionPort)
+            {
+                /* Fail */
+                KeReleaseGuardedMutex(&LpcpLock);
+                ObDereferenceObject(Port);
+                return STATUS_PORT_DISCONNECTED;
+            }
+
+            /* Release lock and reference */
+            ObReferenceObject(Port);
+            KeReleaseGuardedMutex(&LpcpLock);
+        }
     }
     else
     {
@@ -227,6 +260,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         {
             /* No thread found, fail */
             ObDereferenceObject(Port);
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
             return Status;
         }
 
@@ -235,6 +269,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         if (!Message)
         {
             /* Fail if we couldn't allocate a message */
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
             ObDereferenceObject(WakeupThread);
             ObDereferenceObject(Port);
             return STATUS_NO_MEMORY;
@@ -244,11 +279,16 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         KeAcquireGuardedMutex(&LpcpLock);
 
         /* Make sure this is the reply the thread is waiting for */
-        if (WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId)
+        if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId))// ||
+#if 0
+            ((WakeupThread->LpcReplyMessage) &&
+            (LpcpGetMessageType(&((PLPCP_MESSAGE)WakeupThread->
+                                LpcReplyMessage)->Request) != LPC_REQUEST)))
+#endif
         {
             /* It isn't, fail */
-            LpcpFreeToPortZone(Message, TRUE);
-            KeReleaseGuardedMutex(&LpcpLock);
+            LpcpFreeToPortZone(Message, 3);
+            if (ConnectionPort) ObDereferenceObject(ConnectionPort);
             ObDereferenceObject(WakeupThread);
             ObDereferenceObject(Port);
             return STATUS_REPLY_MESSAGE_MISMATCH;
@@ -261,11 +301,6 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
                         LPC_REPLY,
                         NULL);
 
-        /* Free any data information */
-        LpcpFreeDataInfoMessage(Port,
-                                ReplyMessage->MessageId,
-                                ReplyMessage->CallbackId);
-
         /* Reference the thread while we use it */
         ObReferenceObject(WakeupThread);
         Message->RepliedToThread = WakeupThread;
@@ -292,6 +327,12 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
             Thread->LpcReceivedMsgIdValid = FALSE;
         }
 
+        /* Free any data information */
+        LpcpFreeDataInfoMessage(Port,
+                                ReplyMessage->MessageId,
+                                ReplyMessage->CallbackId,
+                                ReplyMessage->ClientId);
+
         /* Release the lock and release the LPC semaphore to wake up waiters */
         KeReleaseGuardedMutex(&LpcpLock);
         LpcpCompleteWait(&WakeupThread->LpcReplySemaphore);
@@ -319,6 +360,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
 
         /* Release the lock and fail */
         KeReleaseGuardedMutex(&LpcpLock);
+        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
         ObDereferenceObject(Port);
         return STATUS_UNSUCCESSFUL;
     }
@@ -347,9 +389,6 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
     Thread->LpcReceivedMessageId = Message->Request.MessageId;
     Thread->LpcReceivedMsgIdValid = TRUE;
 
-    /* Done touching global data, release the lock */
-    KeReleaseGuardedMutex(&LpcpLock);
-
     /* Check if this was a connection request */
     if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
     {
@@ -374,7 +413,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         ReceiveMessage->u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
                                             ConnectionInfoLength;
         ReceiveMessage->u1.s1.DataLength = ConnectionInfoLength;
-        RtlMoveMemory(ReceiveMessage + 1,
+        RtlCopyMemory(ReceiveMessage + 1,
                       ConnectMessage + 1,
                       ConnectionInfoLength);
 
@@ -413,8 +452,17 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
         ASSERT(FALSE);
     }
 
-    /* If we have a message pointer here, free it */
-    if (Message) LpcpFreeToPortZone(Message, FALSE);
+    /* Check if we have a message pointer here */
+    if (Message)
+    {
+        /* Free it and release the lock */
+        LpcpFreeToPortZone(Message, 3);
+    }
+    else
+    {
+        /* Just release the lock */
+        KeReleaseGuardedMutex(&LpcpLock);
+    }
 
 Cleanup:
     /* All done, dereference the port and return the status */
@@ -422,6 +470,7 @@ Cleanup:
              "Port: %p. Status: %p\n",
              Port,
              Status);
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
     ObDereferenceObject(Port);
     return Status;
 }
index 3658ced..525267c 100644 (file)
@@ -22,7 +22,7 @@ NTAPI
 LpcRequestPort(IN PVOID PortObject,
                IN PPORT_MESSAGE LpcMessage)
 {
-    PLPCP_PORT_OBJECT Port = (PLPCP_PORT_OBJECT)PortObject, QueuePort;
+    PLPCP_PORT_OBJECT Port = PortObject, QueuePort, ConnectionPort = NULL;
     ULONG MessageType;
     PLPCP_MESSAGE Message;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
@@ -42,7 +42,7 @@ LpcRequestPort(IN PVOID PortObject,
             return STATUS_INVALID_PARAMETER;
         }
 
-        /* Mark this as a kernel-mode message only if we really came from there */
+        /* Mark this as a kernel-mode message only if we really came from it */
         if ((PreviousMode == KernelMode) &&
             (LpcMessage->u2.s2.Type & LPC_KERNELMODE_MESSAGE))
         {
@@ -72,6 +72,7 @@ LpcRequestPort(IN PVOID PortObject,
     if (!Message) return STATUS_NO_MEMORY;
 
     /* Clear the context */
+    Message->RepliedToThread = NULL;
     Message->PortContext = NULL;
 
     /* Copy the message */
@@ -96,13 +97,28 @@ LpcRequestPort(IN PVOID PortObject,
             {
                 /* Then copy the context */
                 Message->PortContext = QueuePort->PortContext;
-                QueuePort = Port->ConnectionPort;
+                ConnectionPort = QueuePort = Port->ConnectionPort;
+                if (!ConnectionPort)
+                {
+                    /* Fail */
+                    LpcpFreeToPortZone(Message, 3);
+                    return STATUS_PORT_DISCONNECTED;
+                }
             }
             else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
             {
                 /* Any other kind of port, use the connection port */
-                QueuePort = Port->ConnectionPort;
+                ConnectionPort = QueuePort = Port->ConnectionPort;
+                if (!ConnectionPort)
+                {
+                    /* Fail */
+                    LpcpFreeToPortZone(Message, 3);
+                    return STATUS_PORT_DISCONNECTED;
+                }
             }
+
+            /* If we have a connection port, reference it */
+            if (ConnectionPort) ObReferenceObject(ConnectionPort);
         }
     }
     else
@@ -126,6 +142,7 @@ LpcRequestPort(IN PVOID PortObject,
         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
 
         /* Release the lock and release the semaphore */
+        KeEnterCriticalRegion();
         KeReleaseGuardedMutex(&LpcpLock);
         LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
 
@@ -137,13 +154,15 @@ LpcRequestPort(IN PVOID PortObject,
         }
 
         /* We're done */
+        KeLeaveCriticalRegion();
         LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
+        if (ConnectionPort) ObDereferenceObject(ConnectionPort);
         return STATUS_SUCCESS;
     }
 
     /* If we got here, then free the message and fail */
-    LpcpFreeToPortZone(Message, TRUE);
-    KeReleaseGuardedMutex(&LpcpLock);
+    LpcpFreeToPortZone(Message, 3);
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
     return STATUS_PORT_DISCONNECTED;
 }
 
@@ -181,7 +200,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
                        IN PPORT_MESSAGE LpcRequest,
                        IN OUT PPORT_MESSAGE LpcReply)
 {
-    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort;
+    PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
     PLPCP_MESSAGE Message;
@@ -286,8 +305,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             if (!QueuePort)
             {
                 /* We have no connected port, fail */
-                LpcpFreeToPortZone(Message, TRUE);
-                KeReleaseGuardedMutex(&LpcpLock);
+                LpcpFreeToPortZone(Message, 3);
                 ObDereferenceObject(Port);
                 return STATUS_PORT_DISCONNECTED;
             }
@@ -299,15 +317,32 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
             {
                 /* Copy the port context and use the connection port */
-                Message->PortContext = ReplyPort->PortContext;
-                QueuePort = Port->ConnectionPort;
+                Message->PortContext = QueuePort->PortContext;
+                ConnectionPort = QueuePort = Port->ConnectionPort;
+                if (!ConnectionPort)
+                {
+                    /* Fail */
+                    LpcpFreeToPortZone(Message, 3);
+                    ObDereferenceObject(Port);
+                    return STATUS_PORT_DISCONNECTED;
+                }
             }
             else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
                       LPCP_COMMUNICATION_PORT)
             {
                 /* Use the connection port for anything but communication ports */
-                QueuePort = Port->ConnectionPort;
+                ConnectionPort = QueuePort = Port->ConnectionPort;
+                if (!ConnectionPort)
+                {
+                    /* Fail */
+                    LpcpFreeToPortZone(Message, 3);
+                    ObDereferenceObject(Port);
+                    return STATUS_PORT_DISCONNECTED;
+                }
             }
+
+            /* Reference the connection port if it exists */
+            if (ConnectionPort) ObReferenceObject(ConnectionPort);
         }
         else
         {
@@ -317,6 +352,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
 
         /* No reply thread */
         Message->RepliedToThread = NULL;
+        Message->SenderPort = Port;
 
         /* Generate the Message ID and set it */
         Message->Request.MessageId =  LpcpNextMessageId++;
@@ -330,8 +366,10 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
         /* Insert the message in our chain */
         InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
         InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
+        Thread->LpcWaitingOnPort = Port;
 
         /* Release the lock and get the semaphore we'll use later */
+        KeEnterCriticalRegion();
         KeReleaseGuardedMutex(&LpcpLock);
         Semaphore = QueuePort->MsgQueue.Semaphore;
 
@@ -345,6 +383,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
 
     /* Now release the semaphore */
     LpcpCompleteWait(Semaphore);
+    KeLeaveCriticalRegion();
 
     /* And let's wait for the reply */
     LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);
@@ -396,7 +435,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
             else
             {
                 /* Otherwise, just free it */
-                LpcpFreeToPortZone(Message, FALSE);
+                LpcpFreeToPortZone(Message, 0);
             }
         }
         else
@@ -407,10 +446,8 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
     }
     else
     {
-        /* The wait failed, free the message while holding the lock */
-        KeAcquireGuardedMutex(&LpcpLock);
-        LpcpFreeToPortZone(Message, TRUE);
-        KeReleaseGuardedMutex(&LpcpLock);
+        /* The wait failed, free the message */
+        if (Message) LpcpFreeToPortZone(Message, 0);
     }
 
     /* All done */
@@ -419,6 +456,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
              Port,
              Status);
     ObDereferenceObject(Port);
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
     return Status;
 }