ULONG MessageType;
PLPCP_MESSAGE Message;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+ PETHREAD Thread = PsGetCurrentThread();
+
PAGED_CODE();
+
LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", Port, LpcMessage);
/* Check if this is a non-datagram message */
/* Can't have data information on this type of call */
if (LpcMessage->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
- /* Validate message sizes */
+ /* Validate the message length */
if (((ULONG)LpcMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)LpcMessage->u1.s1.TotalLength <= (ULONG)LpcMessage->u1.s1.DataLength))
{
LpcMessage,
LpcMessage + 1,
MessageType,
- &PsGetCurrentThread()->Cid);
+ &Thread->Cid);
/* Acquire the LPC lock */
KeAcquireGuardedMutex(&LpcpLock);
if (!ConnectionPort)
{
/* Fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
return STATUS_PORT_DISCONNECTED;
}
}
if (!ConnectionPort)
{
/* Fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
return STATUS_PORT_DISCONNECTED;
}
}
if (QueuePort)
{
/* Generate the Message ID and set it */
- Message->Request.MessageId = LpcpNextMessageId++;
+ Message->Request.MessageId = LpcpNextMessageId++;
if (!LpcpNextMessageId) LpcpNextMessageId = 1;
Message->Request.CallbackId = 0;
/* No Message ID for the thread */
- PsGetCurrentThread()->LpcReplyMessageId = 0;
+ Thread->LpcReplyMessageId = 0;
/* Insert the message in our chain */
InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
- /* Release the lock and release the semaphore */
+ /* Release the lock and the semaphore */
KeEnterCriticalRegion();
KeReleaseGuardedMutex(&LpcpLock);
LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
}
- /* We're done */
KeLeaveCriticalRegion();
+
+ /* We're done */
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
return STATUS_SUCCESS;
}
/* If we got here, then free the message and fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return STATUS_PORT_DISCONNECTED;
}
{
/* No type */
case 0:
-
+
/* Assume LPC request */
MessageType = LPC_REQUEST;
break;
-
+
/* LPC request callback */
case LPC_REQUEST:
-
+
/* This is a callback */
Callback = TRUE;
break;
-
+
/* Anything else */
case LPC_CLIENT_DIED:
case LPC_PORT_CLOSED:
case LPC_EXCEPTION:
case LPC_DEBUG_EVENT:
case LPC_ERROR_EVENT:
-
+
/* Nothing to do */
break;
-
+
default:
-
+
/* Invalid message type */
return STATUS_INVALID_PARAMETER;
}
-
+
/* Set the request type */
LpcRequest->u2.s2.Type = MessageType;
if (!QueuePort)
{
/* We have no connected port, fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
return STATUS_PORT_DISCONNECTED;
}
if (!ConnectionPort)
{
/* Fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
return STATUS_PORT_DISCONNECTED;
}
}
if (!ConnectionPort)
{
/* Fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
return STATUS_PORT_DISCONNECTED;
}
}
Message->SenderPort = Port;
/* Generate the Message ID and set it */
- Message->Request.MessageId = LpcpNextMessageId++;
+ Message->Request.MessageId = LpcpNextMessageId++;
if (!LpcpNextMessageId) LpcpNextMessageId = 1;
Message->Request.CallbackId = 0;
(&Message->Request) + 1,
0,
NULL);
-
+
/* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock);
-
+
/* Check if we replied to a thread */
if (Message->RepliedToThread)
{
Message->RepliedToThread = NULL;
}
-
/* Free the message */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
}
else
{
/* All done */
LPCTRACE(LPC_SEND_DEBUG,
- "Port: %p. Status: %p\n",
+ "Port: %p. Status: %d\n",
Port,
Status);
NtRequestPort(IN HANDLE PortHandle,
IN PPORT_MESSAGE LpcRequest)
{
- PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
- KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
NTSTATUS Status;
+ PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
+ ULONG MessageType;
PLPCP_MESSAGE Message;
+ KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
PETHREAD Thread = PsGetCurrentThread();
- PKSEMAPHORE Semaphore;
- ULONG MessageType;
PAGED_CODE();
+
LPCTRACE(LPC_SEND_DEBUG,
- "Handle: %lx. Message: %p. Type: %lx\n",
+ "Handle: %p. Message: %p. Type: %lx\n",
PortHandle,
LpcRequest,
LpcpGetMessageType(LpcRequest));
{
/* Copy it */
LpcpMoveMessage(&Message->Request,
- LpcRequest,
- LpcRequest + 1,
- MessageType,
- &Thread->Cid);
+ LpcRequest,
+ LpcRequest + 1,
+ MessageType,
+ &Thread->Cid);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
if (!QueuePort)
{
/* We have no connected port, fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
if (!ConnectionPort)
{
/* Fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
if (!ConnectionPort)
{
/* Fail */
- LpcpFreeToPortZone(Message, 3);
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
Message->SenderPort = Port;
/* Generate the Message ID and set it */
- Message->Request.MessageId = LpcpNextMessageId++;
+ Message->Request.MessageId = LpcpNextMessageId++;
if (!LpcpNextMessageId) LpcpNextMessageId = 1;
Message->Request.CallbackId = 0;
/* No Message ID for the thread */
- PsGetCurrentThread()->LpcReplyMessageId = 0;
+ Thread->LpcReplyMessageId = 0;
/* Insert the message in our chain */
InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
- /* Release the lock and get the semaphore we'll use later */
+ /* Release the lock and the semaphore */
KeEnterCriticalRegion();
KeReleaseGuardedMutex(&LpcpLock);
-
- /* Now release the semaphore */
- Semaphore = QueuePort->MsgQueue.Semaphore;
- LpcpCompleteWait(Semaphore);
+ LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
/* If this is a waitable port, wake it up */
if (QueuePort->Flags & LPCP_WAITABLE_PORT)
/* All done with a failure*/
LPCTRACE(LPC_SEND_DEBUG,
- "Port: %p. Status: %p\n",
+ "Port: %p. Status: %d\n",
Port,
Status);
/* The wait failed, free the message */
- if (Message) LpcpFreeToPortZone(Message, 3);
+ if (Message) LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
ObDereferenceObject(Port);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return Status;
}
+NTSTATUS
+NTAPI
+LpcpVerifyMessageDataInfo(
+ _In_ PPORT_MESSAGE Message,
+ _Out_ PULONG NumberOfDataEntries)
+{
+ PLPCP_DATA_INFO DataInfo;
+ PUCHAR EndOfEntries;
+
+ /* Check if we have no data info at all */
+ if (Message->u2.s2.DataInfoOffset == 0)
+ {
+ *NumberOfDataEntries = 0;
+ return STATUS_SUCCESS;
+ }
+
+ /* Make sure the data info structure is within the message */
+ if (((ULONG)Message->u1.s1.TotalLength <
+ sizeof(PORT_MESSAGE) + sizeof(LPCP_DATA_INFO)) ||
+ ((ULONG)Message->u2.s2.DataInfoOffset < sizeof(PORT_MESSAGE)) ||
+ ((ULONG)Message->u2.s2.DataInfoOffset >
+ ((ULONG)Message->u1.s1.TotalLength - sizeof(LPCP_DATA_INFO))))
+ {
+ return STATUS_INVALID_PARAMETER;
+ }
+
+ /* Get a pointer to the data info */
+ DataInfo = LpcpGetDataInfoFromMessage(Message);
+
+ /* Make sure the full data info with all entries is within the message */
+ EndOfEntries = (PUCHAR)&DataInfo->Entries[DataInfo->NumberOfEntries];
+ if ((EndOfEntries > ((PUCHAR)Message + (ULONG)Message->u1.s1.TotalLength)) ||
+ (EndOfEntries < (PUCHAR)Message))
+ {
+ return STATUS_INVALID_PARAMETER;
+ }
+
+ *NumberOfDataEntries = DataInfo->NumberOfEntries;
+ return STATUS_SUCCESS;
+}
+
/*
* @implemented
*/
IN PPORT_MESSAGE LpcRequest,
IN OUT PPORT_MESSAGE LpcReply)
{
+ PORT_MESSAGE LocalLpcRequest;
+ ULONG NumberOfDataEntries;
PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
NTSTATUS Status;
BOOLEAN Callback;
PKSEMAPHORE Semaphore;
ULONG MessageType;
+ PLPCP_DATA_INFO DataInfo;
PAGED_CODE();
LPCTRACE(LPC_SEND_DEBUG,
- "Handle: %lx. Messages: %p/%p. Type: %lx\n",
+ "Handle: %p. Messages: %p/%p. Type: %lx\n",
PortHandle,
LpcRequest,
LpcReply,
/* Check if the thread is dying */
if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;
+ /* Check for user mode access */
+ if (PreviousMode != KernelMode)
+ {
+ _SEH2_TRY
+ {
+ /* Probe the full request message and copy the base structure */
+ ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
+ ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG));
+ LocalLpcRequest = *LpcRequest;
+
+ /* Probe the reply message for write */
+ ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG));
+
+ /* Make sure the data entries in the request message are valid */
+ Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
+ if (!NT_SUCCESS(Status))
+ {
+ DPRINT1("LpcpVerifyMessageDataInfo failed\n");
+ return Status;
+ }
+ }
+ _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+ {
+ DPRINT1("Got exception\n");
+ return _SEH2_GetExceptionCode();
+ }
+ _SEH2_END;
+ }
+ else
+ {
+ LocalLpcRequest = *LpcRequest;
+ Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries);
+ if (!NT_SUCCESS(Status))
+ {
+ DPRINT1("LpcpVerifyMessageDataInfo failed\n");
+ return Status;
+ }
+ }
+
/* Check if this is an LPC Request */
- if (LpcpGetMessageType(LpcRequest) == LPC_REQUEST)
+ if (LpcpGetMessageType(&LocalLpcRequest) == LPC_REQUEST)
{
/* Then it's a callback */
Callback = TRUE;
}
- else if (LpcpGetMessageType(LpcRequest))
+ else if (LpcpGetMessageType(&LocalLpcRequest))
{
/* This is a not kernel-mode message */
+ DPRINT1("Not a kernel-mode message!\n");
return STATUS_INVALID_PARAMETER;
}
else
{
/* This is a kernel-mode message without a callback */
- LpcRequest->u2.s2.Type |= LPC_REQUEST;
+ LocalLpcRequest.u2.s2.Type |= LPC_REQUEST;
Callback = FALSE;
}
/* Get the message type */
- MessageType = LpcRequest->u2.s2.Type;
+ MessageType = LocalLpcRequest.u2.s2.Type;
+
+ /* Due to the above probe, we know that TotalLength is positive */
+ NT_ASSERT(LocalLpcRequest.u1.s1.TotalLength >= 0);
/* Validate the length */
- if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
- (ULONG)LpcRequest->u1.s1.TotalLength)
+ if ((((ULONG)(USHORT)LocalLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
+ (ULONG)LocalLpcRequest.u1.s1.TotalLength))
{
/* Fail */
+ DPRINT1("Invalid message length: %u, %u\n",
+ LocalLpcRequest.u1.s1.DataLength,
+ LocalLpcRequest.u1.s1.TotalLength);
return STATUS_INVALID_PARAMETER;
}
if (!NT_SUCCESS(Status)) return Status;
/* Validate the message length */
- if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
- ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
+ if (((ULONG)LocalLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) ||
+ ((ULONG)LocalLpcRequest.u1.s1.TotalLength <= (ULONG)LocalLpcRequest.u1.s1.DataLength))
{
/* Fail */
+ DPRINT1("Invalid message length: %u, %u\n",
+ LocalLpcRequest.u1.s1.DataLength,
+ LocalLpcRequest.u1.s1.TotalLength);
ObDereferenceObject(Port);
return STATUS_PORT_MESSAGE_TOO_LONG;
}
if (!Message)
{
/* Fail if we couldn't allocate a message */
+ DPRINT1("Failed to allocate a message!\n");
ObDereferenceObject(Port);
return STATUS_NO_MEMORY;
}
/* No callback, just copy the message */
_SEH2_TRY
{
+ /* Check if we have data info entries */
+ if (LpcRequest->u2.s2.DataInfoOffset != 0)
+ {
+ /* Get the data info and check if the number of entries matches
+ what we expect */
+ DataInfo = LpcpGetDataInfoFromMessage(LpcRequest);
+ if (DataInfo->NumberOfEntries != NumberOfDataEntries)
+ {
+ LpcpFreeToPortZone(Message, 0);
+ ObDereferenceObject(Port);
+ DPRINT1("NumberOfEntries has changed: %u, %u\n",
+ DataInfo->NumberOfEntries, NumberOfDataEntries);
+ return STATUS_INVALID_PARAMETER;
+ }
+ }
+
/* Copy it */
LpcpMoveMessage(&Message->Request,
LpcRequest,
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Fail */
+ DPRINT1("Got exception!\n");
LpcpFreeToPortZone(Message, 0);
ObDereferenceObject(Port);
_SEH2_YIELD(return _SEH2_GetExceptionCode());
if (!QueuePort)
{
/* We have no connected port, fail */
- LpcpFreeToPortZone(Message, 3);
+ DPRINT1("No connected port\n");
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
/* This will be the rundown port */
ReplyPort = QueuePort;
- /* Check if this is a communication port */
+ /* Check if this is a client port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
{
- /* Copy the port context and use the connection port */
+ /* Copy the port context */
Message->PortContext = QueuePort->PortContext;
- ConnectionPort = QueuePort = Port->ConnectionPort;
- if (!ConnectionPort)
- {
- /* Fail */
- LpcpFreeToPortZone(Message, 3);
- ObDereferenceObject(Port);
- return STATUS_PORT_DISCONNECTED;
- }
}
- else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
- LPCP_COMMUNICATION_PORT)
+
+ if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
{
/* Use the connection port for anything but communication ports */
ConnectionPort = QueuePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
- LpcpFreeToPortZone(Message, 3);
+ DPRINT1("No connection port\n");
+ LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
Message->SenderPort = Port;
/* Generate the Message ID and set it */
- Message->Request.MessageId = LpcpNextMessageId++;
+ Message->Request.MessageId = LpcpNextMessageId++;
if (!LpcpNextMessageId) LpcpNextMessageId = 1;
Message->Request.CallbackId = 0;
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
+ DPRINT1("Got exception!\n");
Status = _SEH2_GetExceptionCode();
}
_SEH2_END;
/* All done */
LPCTRACE(LPC_SEND_DEBUG,
- "Port: %p. Status: %p\n",
+ "Port: %p. Status: %d\n",
Port,
Status);
ObDereferenceObject(Port);