/* INCLUDES ******************************************************************/
#include <ntoskrnl.h>
-#include "lpc.h"
#define NDEBUG
-#include <internal/debug.h>
+#include <debug.h>
/* PRIVATE FUNCTIONS *********************************************************/
LpcExitThread(IN PETHREAD Thread)
{
PLPCP_MESSAGE Message;
+ ASSERT(Thread == PsGetCurrentThread());
/* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock);
Thread->LpcReplyMessageId = 0;
/* Check if there's a reply message */
- Message = Thread->LpcReplyMessage;
+ Message = LpcpGetMessageFromThread(Thread);
if (Message)
{
/* FIXME: TODO */
- KEBUGCHECK(0);
+ ASSERT(FALSE);
}
/* Release the lock */
PLPCP_CONNECTION_MESSAGE ConnectMessage;
PLPCP_PORT_OBJECT ClientPort = NULL;
PETHREAD Thread = NULL;
- BOOLEAN LockHeld = Flags & 1;
+ BOOLEAN LockHeld = Flags & 1, ReleaseLock = Flags & 2;
PAGED_CODE();
LPCTRACE(LPC_CLOSE_DEBUG, "Message: %p. Flags: %lx\n", Message, Flags);
ExFreeToPagedLookasideList(&LpcpMessagesLookaside, Message);
/* Reacquire the lock if needed */
- if ((LockHeld) && !(Flags & 2)) KeAcquireGuardedMutex(&LpcpLock);
+ if ((LockHeld) && !(ReleaseLock)) KeAcquireGuardedMutex(&LpcpLock);
}
VOID
PLIST_ENTRY ListHead, NextEntry;
PETHREAD Thread;
PLPCP_MESSAGE Message;
+ PLPCP_PORT_OBJECT ConnectionPort = NULL;
PLPCP_CONNECTION_MESSAGE ConnectMessage;
+ PAGED_CODE();
LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags);
/* Hold the lock */
KeAcquireGuardedMutex(&LpcpLock);
- /* Disconnect the port to which this port is connected */
- if (Port->ConnectedPort) Port->ConnectedPort->ConnectedPort = NULL;
+ /* Check if we have a connected port */
+ if (((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_UNCONNECTED_PORT) &&
+ (Port->ConnectedPort))
+ {
+ /* Disconnect it */
+ Port->ConnectedPort->ConnectedPort = NULL;
+ ConnectionPort = Port->ConnectedPort->ConnectionPort;
+ if (ConnectionPort)
+ {
+ /* Clear connection port */
+ Port->ConnectedPort->ConnectionPort = NULL;
+ }
+ }
/* Check if this is a connection port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT)
/* Walk all the threads waiting and signal them */
ListHead = &Port->LpcReplyChainHead;
NextEntry = ListHead->Flink;
- while (NextEntry != ListHead)
+ while ((NextEntry) && (NextEntry != ListHead))
{
/* Get the Thread */
Thread = CONTAINING_RECORD(NextEntry, ETHREAD, LpcReplyChain);
/* Check if someone is waiting */
if (!KeReadStateSemaphore(&Thread->LpcReplySemaphore))
{
- /* Get the message and check if it's a connection request */
- Message = Thread->LpcReplyMessage;
- if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
+ /* Get the message */
+ Message = LpcpGetMessageFromThread(Thread);
+ if (Message)
{
- /* Get the connection message */
- ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
-
- /* Check if it had a section */
- if (ConnectMessage->SectionToMap)
+ /* Check if it's a connection request */
+ if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
{
- /* Dereference it */
- ObDereferenceObject(ConnectMessage->SectionToMap);
+ /* Get the connection message */
+ ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
+
+ /* Check if it had a section */
+ if (ConnectMessage->SectionToMap)
+ {
+ /* Dereference it */
+ ObDereferenceObject(ConnectMessage->SectionToMap);
+ }
}
- }
- /* Clear the reply message */
- Thread->LpcReplyMessage = NULL;
+ /* Clear the reply message */
+ Thread->LpcReplyMessage = NULL;
- /* And remove the message from the port zone */
- LpcpFreeToPortZone(Message, TRUE);
- }
+ /* And remove the message from the port zone */
+ LpcpFreeToPortZone(Message, 1);
+ NextEntry = Port->LpcReplyChainHead.Flink;
+ }
- /* Release the semaphore and reset message id count */
- Thread->LpcReplyMessageId = 0;
- LpcpCompleteWait(&Thread->LpcReplySemaphore);
+ /* Release the semaphore and reset message id count */
+ Thread->LpcReplyMessageId = 0;
+ KeReleaseSemaphore(&Thread->LpcReplySemaphore, 0, 1, FALSE);
+ }
}
/* Reinitialize the list head */
InitializeListHead(&Port->LpcReplyChainHead);
/* Loop queued messages */
- ListHead = &Port->MsgQueue.ReceiveHead;
- NextEntry = ListHead->Flink;
- while (ListHead != NextEntry)
+ while ((Port->MsgQueue.ReceiveHead.Flink) &&
+ !(IsListEmpty(&Port->MsgQueue.ReceiveHead)))
{
/* Get the message */
- Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
- NextEntry = NextEntry->Flink;
+ Message = CONTAINING_RECORD(Port->MsgQueue.ReceiveHead.Flink,
+ LPCP_MESSAGE,
+ Entry);
/* Free and reinitialize it's list head */
+ RemoveEntryList(&Message->Entry);
InitializeListHead(&Message->Entry);
/* Remove it from the port zone */
- LpcpFreeToPortZone(Message, TRUE);
+ LpcpFreeToPortZone(Message, 1);
}
- /* Reinitialize the message queue list head */
- InitializeListHead(&Port->MsgQueue.ReceiveHead);
-
/* Release the lock */
KeReleaseGuardedMutex(&LpcpLock);
+ /* Dereference the connection port */
+ if (ConnectionPort) ObDereferenceObject(ConnectionPort);
+
/* Check if we have to free the port entirely */
if (Destroy)
{
/* Setup the client died message */
ClientDiedMsg.h.u1.s1.TotalLength = sizeof(ClientDiedMsg);
ClientDiedMsg.h.u1.s1.DataLength = sizeof(ClientDiedMsg.CreateTime);
- ClientDiedMsg.h.u2.ZeroInit = LPC_PORT_CLOSED;
+ ClientDiedMsg.h.u2.ZeroInit = 0;
+ ClientDiedMsg.h.u2.s2.Type = LPC_PORT_CLOSED;
ClientDiedMsg.CreateTime = PsGetCurrentProcess()->CreateTime;
/* Send it */
/* Destroy the port queue */
LpcpDestroyPortQueue(Port, TRUE);
- /* Check if we had a client view */
- if (Port->ClientSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(),
- Port->ClientSectionBase);
+ /* Check if we had views */
+ if ((Port->ClientSectionBase) || (Port->ServerSectionBase))
+ {
+ /* Check if we had a client view */
+ if (Port->ClientSectionBase)
+ {
+ /* Unmap it */
+ MmUnmapViewOfSection(Port->MappingProcess,
+ Port->ClientSectionBase);
+ }
- /* Check for a server view */
- if (Port->ServerSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(),
- Port->ServerSectionBase);
+ /* Check for a server view */
+ if (Port->ServerSectionBase)
+ {
+ /* Unmap it */
+ MmUnmapViewOfSection(Port->MappingProcess,
+ Port->ServerSectionBase);
+ }
+
+ /* Dereference the mapping process */
+ ObDereferenceObject(Port->MappingProcess);
+ Port->MappingProcess = NULL;
+ }
+
+ /* Acquire the lock */
+ KeAcquireGuardedMutex(&LpcpLock);
/* Get the connection port */
ConnectionPort = Port->ConnectionPort;
/* Get the PID */
Pid = PsGetCurrentProcessId();
- /* Acquire the lock */
- KeAcquireGuardedMutex(&LpcpLock);
-
/* Loop the data lists */
ListHead = &ConnectionPort->LpcDataInfoChainHead;
NextEntry = ListHead->Flink;
Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
NextEntry = NextEntry->Flink;
- /* Check if the PID matches */
- if (Message->Request.ClientId.UniqueProcess == Pid)
+ /* Check if this is the connection port */
+ if (Port == ConnectionPort)
+ {
+ /* Free queued messages */
+ RemoveEntryList(&Message->Entry);
+ InitializeListHead(&Message->Entry);
+ LpcpFreeToPortZone(Message, 1);
+
+ /* Restart at the head */
+ NextEntry = ListHead->Flink;
+ }
+ else if ((Message->Request.ClientId.UniqueProcess == Pid) &&
+ ((Message->SenderPort == Port) ||
+ (Message->SenderPort == Port->ConnectedPort) ||
+ (Message->SenderPort == ConnectionPort)))
{
/* Remove it */
RemoveEntryList(&Message->Entry);
- LpcpFreeToPortZone(Message, TRUE);
+ InitializeListHead(&Message->Entry);
+ LpcpFreeToPortZone(Message, 1);
+
+ /* Restart at the head */
+ NextEntry = ListHead->Flink;
}
}
/* Dereference the object unless it's the same port */
if (ConnectionPort != Port) ObDereferenceObject(ConnectionPort);
}
+ else
+ {
+ /* Release the lock */
+ KeReleaseGuardedMutex(&LpcpLock);
+ }
/* Check if this is a connection port with a server process*/
if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) &&