[ntoskrnl]
[reactos.git] / reactos / ntoskrnl / lpc / close.c
index af9803d..1d37b59 100644 (file)
@@ -9,9 +9,8 @@
 /* INCLUDES ******************************************************************/
 
 #include <ntoskrnl.h>
-#include "lpc.h"
 #define NDEBUG
-#include <internal/debug.h>
+#include <debug.h>
 
 /* PRIVATE FUNCTIONS *********************************************************/
 
@@ -20,6 +19,7 @@ NTAPI
 LpcExitThread(IN PETHREAD Thread)
 {
     PLPCP_MESSAGE Message;
+    ASSERT(Thread == PsGetCurrentThread());
 
     /* Acquire the lock */
     KeAcquireGuardedMutex(&LpcpLock);
@@ -36,11 +36,11 @@ LpcExitThread(IN PETHREAD Thread)
     Thread->LpcReplyMessageId = 0;
 
     /* Check if there's a reply message */
-    Message = Thread->LpcReplyMessage;
+    Message = LpcpGetMessageFromThread(Thread);
     if (Message)
     {
         /* FIXME: TODO */
-        KEBUGCHECK(0);
+        ASSERT(FALSE);
     }
 
     /* Release the lock */
@@ -55,7 +55,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message,
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
     PLPCP_PORT_OBJECT ClientPort = NULL;
     PETHREAD Thread = NULL;
-    BOOLEAN LockHeld = Flags & 1;
+    BOOLEAN LockHeld = Flags & 1, ReleaseLock = Flags & 2;
     PAGED_CODE();
     LPCTRACE(LPC_CLOSE_DEBUG, "Message: %p. Flags: %lx\n", Message, Flags);
 
@@ -100,7 +100,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message,
     ExFreeToPagedLookasideList(&LpcpMessagesLookaside, Message);
 
     /* Reacquire the lock if needed */
-    if ((LockHeld) && !(Flags & 2)) KeAcquireGuardedMutex(&LpcpLock);
+    if ((LockHeld) && !(ReleaseLock)) KeAcquireGuardedMutex(&LpcpLock);
 }
 
 VOID
@@ -111,14 +111,27 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
     PLIST_ENTRY ListHead, NextEntry;
     PETHREAD Thread;
     PLPCP_MESSAGE Message;
+    PLPCP_PORT_OBJECT ConnectionPort = NULL;
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
+    PAGED_CODE();
     LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags);
 
     /* Hold the lock */
     KeAcquireGuardedMutex(&LpcpLock);
 
-    /* Disconnect the port to which this port is connected */
-    if (Port->ConnectedPort) Port->ConnectedPort->ConnectedPort = NULL;
+    /* Check if we have a connected port */
+    if (((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_UNCONNECTED_PORT) &&
+        (Port->ConnectedPort))
+    {
+        /* Disconnect it */
+        Port->ConnectedPort->ConnectedPort = NULL;
+        ConnectionPort = Port->ConnectedPort->ConnectionPort;
+        if (ConnectionPort)
+        {
+            /* Clear connection port */
+            Port->ConnectedPort->ConnectionPort = NULL;
+        }
+    }
 
     /* Check if this is a connection port */
     if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT)
@@ -130,7 +143,7 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
     /* Walk all the threads waiting and signal them */
     ListHead = &Port->LpcReplyChainHead;
     NextEntry = ListHead->Flink;
-    while (NextEntry != ListHead)
+    while ((NextEntry) && (NextEntry != ListHead))
     {
         /* Get the Thread */
         Thread = CONTAINING_RECORD(NextEntry, ETHREAD, LpcReplyChain);
@@ -148,58 +161,64 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
         /* Check if someone is waiting */
         if (!KeReadStateSemaphore(&Thread->LpcReplySemaphore))
         {
-            /* Get the message and check if it's a connection request */
-            Message = Thread->LpcReplyMessage;
-            if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
+            /* Get the message */
+            Message = LpcpGetMessageFromThread(Thread);
+            if (Message)
             {
-                /* Get the connection message */
-                ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
-
-                /* Check if it had a section */
-                if (ConnectMessage->SectionToMap)
+                /* Check if it's a connection request */
+                if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
                 {
-                    /* Dereference it */
-                    ObDereferenceObject(ConnectMessage->SectionToMap);
+                    /* Get the connection message */
+                    ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+
+                    /* Check if it had a section */
+                    if (ConnectMessage->SectionToMap)
+                    {
+                        /* Dereference it */
+                        ObDereferenceObject(ConnectMessage->SectionToMap);
+                    }
                 }
-            }
 
-            /* Clear the reply message */
-            Thread->LpcReplyMessage = NULL;
+                /* Clear the reply message */
+                Thread->LpcReplyMessage = NULL;
 
-            /* And remove the message from the port zone */
-            LpcpFreeToPortZone(Message, TRUE);
-        }
+                /* And remove the message from the port zone */
+                LpcpFreeToPortZone(Message, 1);
+                NextEntry = Port->LpcReplyChainHead.Flink;
+            }
 
-        /* Release the semaphore and reset message id count */
-        Thread->LpcReplyMessageId = 0;
-        LpcpCompleteWait(&Thread->LpcReplySemaphore);
+            /* Release the semaphore and reset message id count */
+            Thread->LpcReplyMessageId = 0;
+            KeReleaseSemaphore(&Thread->LpcReplySemaphore, 0, 1, FALSE);
+        }
     }
 
     /* Reinitialize the list head */
     InitializeListHead(&Port->LpcReplyChainHead);
 
     /* Loop queued messages */
-    ListHead = &Port->MsgQueue.ReceiveHead;
-    NextEntry =  ListHead->Flink;
-    while (ListHead != NextEntry)
+    while ((Port->MsgQueue.ReceiveHead.Flink) &&
+           !(IsListEmpty(&Port->MsgQueue.ReceiveHead)))
     {
         /* Get the message */
-        Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
-        NextEntry = NextEntry->Flink;
+        Message = CONTAINING_RECORD(Port->MsgQueue.ReceiveHead.Flink,
+                                    LPCP_MESSAGE,
+                                    Entry);
 
         /* Free and reinitialize it's list head */
+        RemoveEntryList(&Message->Entry);
         InitializeListHead(&Message->Entry);
 
         /* Remove it from the port zone */
-        LpcpFreeToPortZone(Message, TRUE);
+        LpcpFreeToPortZone(Message, 1);
     }
 
-    /* Reinitialize the message queue list head */
-    InitializeListHead(&Port->MsgQueue.ReceiveHead);
-
     /* Release the lock */
     KeReleaseGuardedMutex(&LpcpLock);
 
+    /* Dereference the connection port */
+    if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+
     /* Check if we have to free the port entirely */
     if (Destroy)
     {
@@ -317,7 +336,8 @@ LpcpDeletePort(IN PVOID ObjectBody)
         /* Setup the client died message */
         ClientDiedMsg.h.u1.s1.TotalLength = sizeof(ClientDiedMsg);
         ClientDiedMsg.h.u1.s1.DataLength = sizeof(ClientDiedMsg.CreateTime);
-        ClientDiedMsg.h.u2.ZeroInit = LPC_PORT_CLOSED;
+        ClientDiedMsg.h.u2.ZeroInit = 0;
+        ClientDiedMsg.h.u2.s2.Type = LPC_PORT_CLOSED;
         ClientDiedMsg.CreateTime = PsGetCurrentProcess()->CreateTime;
 
         /* Send it */
@@ -335,13 +355,32 @@ LpcpDeletePort(IN PVOID ObjectBody)
     /* Destroy the port queue */
     LpcpDestroyPortQueue(Port, TRUE);
 
-    /* Check if we had a client view */
-    if (Port->ClientSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(),
-                                                      Port->ClientSectionBase);
+    /* Check if we had views */
+    if ((Port->ClientSectionBase) || (Port->ServerSectionBase))
+    {
+        /* Check if we had a client view */
+        if (Port->ClientSectionBase)
+        {
+            /* Unmap it */
+            MmUnmapViewOfSection(Port->MappingProcess,
+                                 Port->ClientSectionBase);
+        }
 
-    /* Check for a server view */
-    if (Port->ServerSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(),
-                                                      Port->ServerSectionBase);
+        /* Check for a server view */
+        if (Port->ServerSectionBase)
+        {
+            /* Unmap it */
+            MmUnmapViewOfSection(Port->MappingProcess,
+                                 Port->ServerSectionBase);
+        }
+
+        /* Dereference the mapping process */
+        ObDereferenceObject(Port->MappingProcess);
+        Port->MappingProcess = NULL;
+    }
+
+    /* Acquire the lock */
+    KeAcquireGuardedMutex(&LpcpLock);
 
     /* Get the connection port */
     ConnectionPort = Port->ConnectionPort;
@@ -350,9 +389,6 @@ LpcpDeletePort(IN PVOID ObjectBody)
         /* Get the PID */
         Pid = PsGetCurrentProcessId();
 
-        /* Acquire the lock */
-        KeAcquireGuardedMutex(&LpcpLock);
-
         /* Loop the data lists */
         ListHead = &ConnectionPort->LpcDataInfoChainHead;
         NextEntry = ListHead->Flink;
@@ -362,12 +398,29 @@ LpcpDeletePort(IN PVOID ObjectBody)
             Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
             NextEntry = NextEntry->Flink;
 
-            /* Check if the PID matches */
-            if (Message->Request.ClientId.UniqueProcess == Pid)
+            /* Check if this is the connection port */
+            if (Port == ConnectionPort)
+            {
+                /* Free queued messages */
+                RemoveEntryList(&Message->Entry);
+                InitializeListHead(&Message->Entry);
+                LpcpFreeToPortZone(Message, 1);
+
+                /* Restart at the head */
+                NextEntry = ListHead->Flink;
+            }
+            else if ((Message->Request.ClientId.UniqueProcess == Pid) &&
+                     ((Message->SenderPort == Port) ||
+                      (Message->SenderPort == Port->ConnectedPort) ||
+                      (Message->SenderPort == ConnectionPort)))
             {
                 /* Remove it */
                 RemoveEntryList(&Message->Entry);
-                LpcpFreeToPortZone(Message, TRUE);
+                InitializeListHead(&Message->Entry);
+                LpcpFreeToPortZone(Message, 1);
+
+                /* Restart at the head */
+                NextEntry = ListHead->Flink;
             }
         }
 
@@ -377,6 +430,11 @@ LpcpDeletePort(IN PVOID ObjectBody)
         /* Dereference the object unless it's the same port */
         if (ConnectionPort != Port) ObDereferenceObject(ConnectionPort);
     }
+    else
+    {
+        /* Release the lock */
+        KeReleaseGuardedMutex(&LpcpLock);
+    }
 
     /* Check if this is a connection port with a server process*/
     if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) &&