NTAPI
NtSecureConnectPort(OUT PHANDLE PortHandle,
IN PUNICODE_STRING PortName,
- IN PSECURITY_QUALITY_OF_SERVICE Qos,
+ IN PSECURITY_QUALITY_OF_SERVICE SecurityQos,
IN OUT PPORT_VIEW ClientView OPTIONAL,
IN PSID ServerSid OPTIONAL,
IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL,
IN OUT PVOID ConnectionInformation OPTIONAL,
IN OUT PULONG ConnectionInformationLength OPTIONAL)
{
+ NTSTATUS Status = STATUS_SUCCESS;
+ KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+ PETHREAD Thread = PsGetCurrentThread();
+ SECURITY_QUALITY_OF_SERVICE CapturedQos;
+ PORT_VIEW CapturedClientView;
+ PSID CapturedServerSid;
ULONG ConnectionInfoLength = 0;
PLPCP_PORT_OBJECT Port, ClientPort;
- KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
- NTSTATUS Status = STATUS_SUCCESS;
- HANDLE Handle;
- PVOID SectionToMap;
PLPCP_MESSAGE Message;
PLPCP_CONNECTION_MESSAGE ConnectMessage;
- PETHREAD Thread = PsGetCurrentThread();
ULONG PortMessageLength;
+ HANDLE Handle;
+ PVOID SectionToMap;
LARGE_INTEGER SectionOffset;
PTOKEN Token;
PTOKEN_USER TokenUserInfo;
+
PAGED_CODE();
LPCTRACE(LPC_CONNECT_DEBUG,
- "Name: %wZ. Qos: %p. Views: %p/%p. Sid: %p\n",
+ "Name: %wZ. SecurityQos: %p. Views: %p/%p. Sid: %p\n",
PortName,
- Qos,
+ SecurityQos,
ClientView,
ServerView,
ServerSid);
- /* Validate client view */
- if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
+ /* Check if the call comes from user mode */
+ if (PreviousMode != KernelMode)
{
- /* Fail */
- return STATUS_INVALID_PARAMETER;
- }
+ /* Enter SEH for probing the parameters */
+ _SEH2_TRY
+ {
+ /* Probe the PortHandle */
+ ProbeForWriteHandle(PortHandle);
- /* Validate server view */
- if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
- {
- /* Fail */
- return STATUS_INVALID_PARAMETER;
- }
+ /* Probe and capture the QoS */
+ ProbeForRead(SecurityQos, sizeof(*SecurityQos), sizeof(ULONG));
+ CapturedQos = *(volatile SECURITY_QUALITY_OF_SERVICE*)SecurityQos;
+ /* NOTE: Do not care about CapturedQos.Length */
+
+ /* The following parameters are optional */
+
+ /* Capture the client view */
+ if (ClientView)
+ {
+ ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
+ CapturedClientView = *(volatile PORT_VIEW*)ClientView;
+
+ /* Validate the size of the client view */
+ if (CapturedClientView.Length != sizeof(CapturedClientView))
+ {
+ /* Invalid size */
+ _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
+ }
+
+ }
+
+ /* Capture the server view */
+ if (ServerView)
+ {
+ ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
+
+ /* Validate the size of the server view */
+ if (((volatile REMOTE_PORT_VIEW*)ServerView)->Length != sizeof(*ServerView))
+ {
+ /* Invalid size */
+ _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
+ }
+ }
+
+ if (MaxMessageLength)
+ ProbeForWriteUlong(MaxMessageLength);
+
+ /* Capture connection information length */
+ if (ConnectionInformationLength)
+ {
+ ProbeForWriteUlong(ConnectionInformationLength);
+ ConnectionInfoLength = *(volatile ULONG*)ConnectionInformationLength;
+ }
- /* Check if caller sent connection information length */
- if (ConnectionInformationLength)
+ /* Probe the ConnectionInformation */
+ if (ConnectionInformation)
+ ProbeForWrite(ConnectionInformation, ConnectionInfoLength, sizeof(ULONG));
+
+ CapturedServerSid = ServerSid;
+ if (ServerSid != NULL)
+ {
+ /* Capture it */
+ Status = SepCaptureSid(ServerSid,
+ PreviousMode,
+ PagedPool,
+ TRUE,
+ &CapturedServerSid);
+ if (!NT_SUCCESS(Status))
+ {
+ DPRINT1("Failed to capture ServerSid!\n");
+ _SEH2_YIELD(return Status);
+ }
+ }
+ }
+ _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+ {
+ /* There was an exception, return the exception code */
+ _SEH2_YIELD(return _SEH2_GetExceptionCode());
+ }
+ _SEH2_END;
+ }
+ else
{
- /* Retrieve the input length */
- ConnectionInfoLength = *ConnectionInformationLength;
+ CapturedQos = *SecurityQos;
+ /* NOTE: Do not care about CapturedQos.Length */
+
+ /* The following parameters are optional */
+
+ /* Capture the client view */
+ if (ClientView)
+ {
+ /* Validate the size of the client view */
+ if (ClientView->Length != sizeof(*ClientView))
+ {
+ /* Invalid size */
+ return STATUS_INVALID_PARAMETER;
+ }
+ CapturedClientView = *ClientView;
+ }
+
+ /* Capture the server view */
+ if (ServerView)
+ {
+ /* Validate the size of the server view */
+ if (ServerView->Length != sizeof(*ServerView))
+ {
+ /* Invalid size */
+ return STATUS_INVALID_PARAMETER;
+ }
+ }
+
+ /* Capture connection information length */
+ if (ConnectionInformationLength)
+ ConnectionInfoLength = *ConnectionInformationLength;
+
+ CapturedServerSid = ServerSid;
}
/* Get the port */
LpcPortObjectType,
PreviousMode,
NULL,
- (PVOID *)&Port);
- if (!NT_SUCCESS(Status)) return Status;
+ (PVOID*)&Port);
+ if (!NT_SUCCESS(Status))
+ {
+ DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status);
+
+ if (CapturedServerSid != ServerSid)
+ SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
+
+ return Status;
+ }
/* This has to be a connection port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
{
/* It isn't, so fail */
ObDereferenceObject(Port);
+
+ if (CapturedServerSid != ServerSid)
+ SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
+
return STATUS_INVALID_PORT_HANDLE;
}
- /* Check if we have a SID */
+ /* Check if we have a (captured) SID */
if (ServerSid)
{
/* Make sure that we have a server */
{
/* Get its token and query user information */
Token = PsReferencePrimaryToken(Port->ServerProcess);
- //Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
- // FIXME: Need SeQueryInformationToken
- Status = STATUS_SUCCESS;
- TokenUserInfo = ExAllocatePool(PagedPool, sizeof(TOKEN_USER));
- TokenUserInfo->User.Sid = ServerSid;
+ Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
PsDereferencePrimaryToken(Token);
/* Check for success */
if (NT_SUCCESS(Status))
{
/* Compare the SIDs */
- if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid))
+ if (!RtlEqualSid(CapturedServerSid, TokenUserInfo->User.Sid))
{
/* Fail */
Status = STATUS_SERVER_SID_MISMATCH;
}
/* Free token information */
- ExFreePool(TokenUserInfo);
+ ExFreePoolWithTag(TokenUserInfo, TAG_SE);
}
}
else
Status = STATUS_SERVER_SID_MISMATCH;
}
+ /* Finally release the captured SID, we don't need it anymore */
+ if (CapturedServerSid != ServerSid)
+ SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
+
/* Check if SID failed */
if (!NT_SUCCESS(Status))
{
sizeof(LPCP_PORT_OBJECT),
0,
0,
- (PVOID *)&ClientPort);
+ (PVOID*)&ClientPort);
if (!NT_SUCCESS(Status))
{
/* Failed, dereference the server port and return */
return Status;
}
- /* Setup the client port */
+ /*
+ * Setup the client port -- From now on, dereferencing the client port
+ * will automatically dereference the connection port too.
+ */
RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
ClientPort->Flags = LPCP_CLIENT_PORT;
ClientPort->ConnectionPort = Port;
ClientPort->MaxMessageLength = Port->MaxMessageLength;
- ClientPort->SecurityQos = *Qos;
+ ClientPort->SecurityQos = CapturedQos;
InitializeListHead(&ClientPort->LpcReplyChainHead);
InitializeListHead(&ClientPort->LpcDataInfoChainHead);
/* Check if we have dynamic security */
- if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
+ if (CapturedQos.ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
{
/* Remember that */
ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
{
/* Create our own client security */
Status = SeCreateClientSecurity(Thread,
- Qos,
+ &CapturedQos,
FALSE,
&ClientPort->StaticSecurity);
if (!NT_SUCCESS(Status))
if (ClientView)
{
/* Get the section handle */
- Status = ObReferenceObjectByHandle(ClientView->SectionHandle,
+ Status = ObReferenceObjectByHandle(CapturedClientView.SectionHandle,
SECTION_MAP_READ |
SECTION_MAP_WRITE,
MmSectionObjectType,
if (!NT_SUCCESS(Status))
{
/* Fail */
- ObDereferenceObject(Port);
+ ObDereferenceObject(ClientPort);
return Status;
}
/* Set the section offset */
- SectionOffset.QuadPart = ClientView->SectionOffset;
+ SectionOffset.QuadPart = CapturedClientView.SectionOffset;
/* Map it */
Status = MmMapViewOfSection(SectionToMap,
0,
0,
&SectionOffset,
- &ClientView->ViewSize,
+ &CapturedClientView.ViewSize,
ViewUnmap,
0,
PAGE_READWRITE);
/* Update the offset */
- ClientView->SectionOffset = SectionOffset.LowPart;
+ CapturedClientView.SectionOffset = SectionOffset.LowPart;
/* Check for failure */
if (!NT_SUCCESS(Status))
{
/* Fail */
ObDereferenceObject(SectionToMap);
- ObDereferenceObject(Port);
+ ObDereferenceObject(ClientPort);
return Status;
}
/* Update the base */
- ClientView->ViewBase = ClientPort->ClientSectionBase;
+ CapturedClientView.ViewBase = ClientPort->ClientSectionBase;
/* Reference and remember the process */
ClientPort->MappingProcess = PsGetCurrentProcess();
if (ClientView)
{
/* Set the view size */
- Message->Request.ClientViewSize = ClientView->ViewSize;
+ Message->Request.ClientViewSize = CapturedClientView.ViewSize;
/* Copy the client view and clear the server view */
RtlCopyMemory(&ConnectMessage->ClientView,
- ClientView,
- sizeof(PORT_VIEW));
+ &CapturedClientView,
+ sizeof(CapturedClientView));
RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
}
else
ConnectMessage->SectionToMap = SectionToMap;
/* Set the data for the connection request message */
- Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
+ Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
sizeof(LPCP_CONNECTION_MESSAGE);
Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
Message->Request.u1.s1.DataLength;
/* Check if we have connection information */
if (ConnectionInformation)
{
- /* Copy it in */
- RtlCopyMemory(ConnectMessage + 1,
- ConnectionInformation,
- ConnectionInfoLength);
+ _SEH2_TRY
+ {
+ /* Copy it in */
+ RtlCopyMemory(ConnectMessage + 1,
+ ConnectionInformation,
+ ConnectionInfoLength);
+ }
+ _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+ {
+ /* Cleanup and return the exception code */
+
+ /* Free the message we have */
+ LpcpFreeToPortZone(Message, 0);
+
+ /* Dereference other objects */
+ if (SectionToMap) ObDereferenceObject(SectionToMap);
+ ObDereferenceObject(ClientPort);
+
+ /* Return status */
+ _SEH2_YIELD(return _SEH2_GetExceptionCode());
+ }
+ _SEH2_END;
}
+ /* Reset the status code */
+ Status = STATUS_SUCCESS;
+
/* Acquire the port lock */
KeAcquireGuardedMutex(&LpcpLock);
InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
Thread->LpcReplyMessage = Message;
- /* Now we can finally reference the client port and link it*/
+ /* Now we can finally reference the client port and link it */
ObReferenceObject(ClientPort);
ConnectMessage->ClientPort = ClientPort;
Status);
/* If this is a waitable port, set the event */
- if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
- 1,
- FALSE);
+ if (Port->Flags & LPCP_WAITABLE_PORT)
+ KeSetEvent(&Port->WaitEvent, 1, FALSE);
/* Release the queue semaphore and leave the critical region */
LpcpCompleteWait(Port->MsgQueue.Semaphore);
KeLeaveCriticalRegion();
- /* Now wait for a reply */
+ /* Now wait for a reply and set 'Status' */
LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
}
+ /* Now, always free the connection message */
+ SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
+
/* Check for failure */
- if (!NT_SUCCESS(Status)) goto Cleanup;
+ if (!NT_SUCCESS(Status))
+ {
+ /* Check if the semaphore got signaled in the meantime */
+ if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
+ {
+ /* Wait on it */
+ KeWaitForSingleObject(&Thread->LpcReplySemaphore,
+ WrExecutive,
+ KernelMode,
+ FALSE,
+ NULL);
+ }
- /* Free the connection message */
- SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
+ goto Failure;
+ }
/* Check if we got a message back */
if (Message)
sizeof(LPCP_CONNECTION_MESSAGE);
}
- /* Check if we had connection information */
+ /* Check if the caller had connection information */
if (ConnectionInformation)
{
- /* Check if we had a length pointer */
- if (ConnectionInformationLength)
+ _SEH2_TRY
{
- /* Return the length */
- *ConnectionInformationLength = ConnectionInfoLength;
+ /* Return the connection information length if needed */
+ if (ConnectionInformationLength)
+ *ConnectionInformationLength = ConnectionInfoLength;
+
+ /* Return the connection information */
+ RtlCopyMemory(ConnectionInformation,
+ ConnectMessage + 1,
+ ConnectionInfoLength);
}
-
- /* Return the connection information */
- RtlCopyMemory(ConnectionInformation,
- ConnectMessage + 1,
- ConnectionInfoLength );
+ _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+ {
+ /* Cleanup and return the exception code */
+ Status = _SEH2_GetExceptionCode();
+ _SEH2_YIELD(goto Failure);
+ }
+ _SEH2_END;
}
/* Make sure we had a connected port */
NULL,
PORT_ALL_ACCESS,
0,
- (PVOID *)NULL,
+ NULL,
&Handle);
if (NT_SUCCESS(Status))
{
- /* Return the handle */
- *PortHandle = Handle;
LPCTRACE(LPC_CONNECT_DEBUG,
"Handle: %p. Length: %lx\n",
Handle,
PortMessageLength);
- /* Check if maximum length was requested */
- if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
-
- /* Check if we had a client view */
- if (ClientView)
+ _SEH2_TRY
{
- /* Copy it back */
- RtlCopyMemory(ClientView,
- &ConnectMessage->ClientView,
- sizeof(PORT_VIEW));
+ /* Return the handle */
+ *PortHandle = Handle;
+
+ /* Check if maximum length was requested */
+ if (MaxMessageLength)
+ *MaxMessageLength = PortMessageLength;
+
+ /* Check if we had a client view */
+ if (ClientView)
+ {
+ /* Copy it back */
+ RtlCopyMemory(ClientView,
+ &ConnectMessage->ClientView,
+ sizeof(*ClientView));
+ }
+
+ /* Check if we had a server view */
+ if (ServerView)
+ {
+ /* Copy it back */
+ RtlCopyMemory(ServerView,
+ &ConnectMessage->ServerView,
+ sizeof(*ServerView));
+ }
}
-
- /* Check if we had a server view */
- if (ServerView)
+ _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
- /* Copy it back */
- RtlCopyMemory(ServerView,
- &ConnectMessage->ServerView,
- sizeof(REMOTE_PORT_VIEW));
+ /* An exception happened, close the opened handle */
+ ObCloseHandle(Handle, PreviousMode);
+ Status = _SEH2_GetExceptionCode();
}
+ _SEH2_END;
}
}
else
else
{
/* No reply message, fail */
- if (SectionToMap) ObDereferenceObject(SectionToMap);
- ObDereferenceObject(ClientPort);
Status = STATUS_PORT_CONNECTION_REFUSED;
+ goto Failure;
}
- /* Return status */
ObDereferenceObject(Port);
- return Status;
-
-Cleanup:
- /* We failed, free the message */
- SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
- /* Check if the semaphore got signaled */
- if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
- {
- /* Wait on it */
- KeWaitForSingleObject(&Thread->LpcReplySemaphore,
- WrExecutive,
- KernelMode,
- FALSE,
- NULL);
- }
+ /* Return status */
+ return Status;
+Failure:
/* Check if we had a message and free it */
if (Message) LpcpFreeToPortZone(Message, 0);
/* Dereference other objects */
if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort);
+ ObDereferenceObject(Port);
/* Return status */
- ObDereferenceObject(Port);
return Status;
}
NTAPI
NtConnectPort(OUT PHANDLE PortHandle,
IN PUNICODE_STRING PortName,
- IN PSECURITY_QUALITY_OF_SERVICE Qos,
- IN PPORT_VIEW ClientView,
- IN PREMOTE_PORT_VIEW ServerView,
- OUT PULONG MaxMessageLength,
- IN PVOID ConnectionInformation,
- OUT PULONG ConnectionInformationLength)
+ IN PSECURITY_QUALITY_OF_SERVICE SecurityQos,
+ IN OUT PPORT_VIEW ClientView OPTIONAL,
+ IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL,
+ OUT PULONG MaxMessageLength OPTIONAL,
+ IN OUT PVOID ConnectionInformation OPTIONAL,
+ IN OUT PULONG ConnectionInformationLength OPTIONAL)
{
/* Call the newer API */
return NtSecureConnectPort(PortHandle,
PortName,
- Qos,
+ SecurityQos,
ClientView,
NULL,
ServerView,