[STORPORT] Fix x64 build
[reactos.git] / ntoskrnl / lpc / port.c
index 5a6205f..7ca93a1 100644 (file)
@@ -60,6 +60,7 @@ LpcInitSystem(VOID)
                        NULL,
                        &LpcPortObjectType);
 
+    /* Create the Waitable Port Object Type */
     RtlInitUnicodeString(&Name, L"WaitablePort");
     ObjectTypeInitializer.PoolType = NonPagedPool;
     ObjectTypeInitializer.DefaultNonPagedPoolCharge += sizeof(LPCP_PORT_OBJECT);
@@ -84,6 +85,40 @@ LpcInitSystem(VOID)
     return TRUE;
 }
 
+BOOLEAN
+NTAPI
+LpcpValidateClientPort(
+    PETHREAD ClientThread,
+    PLPCP_PORT_OBJECT Port)
+{
+    PLPCP_PORT_OBJECT ThreadPort;
+
+    /* Get the thread's port */
+    ThreadPort = LpcpGetPortFromThread(ClientThread);
+    if (ThreadPort == NULL)
+    {
+        return FALSE;
+    }
+
+    /* Check if the port matches directly */
+    if ((Port == ThreadPort) ||
+        (Port == ThreadPort->ConnectionPort) ||
+        (Port == ThreadPort->ConnectedPort))
+    {
+        return TRUE;
+    }
+
+    /* Check if this is a communication port and the connection port matches */
+    if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_COMMUNICATION_PORT) &&
+        (Port->ConnectionPort == ThreadPort))
+    {
+        return TRUE;
+    }
+
+    return FALSE;
+}
+
+
 /* PUBLIC FUNCTIONS **********************************************************/
 
 NTSTATUS
@@ -91,8 +126,150 @@ NTAPI
 NtImpersonateClientOfPort(IN HANDLE PortHandle,
                           IN PPORT_MESSAGE ClientMessage)
 {
-    UNIMPLEMENTED;
-    return STATUS_NOT_IMPLEMENTED;
+    NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    CLIENT_ID ClientId;
+    ULONG MessageId;
+    PLPCP_PORT_OBJECT Port = NULL, ConnectedPort = NULL;
+    PETHREAD ClientThread = NULL;
+    SECURITY_CLIENT_CONTEXT ClientContext;
+
+    PAGED_CODE();
+
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
+    {
+        _SEH2_TRY
+        {
+            ProbeForRead(ClientMessage, sizeof(*ClientMessage), sizeof(PVOID));
+            ClientId  = ((volatile PORT_MESSAGE*)ClientMessage)->ClientId;
+            MessageId = ((volatile PORT_MESSAGE*)ClientMessage)->MessageId;
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        ClientId  = ClientMessage->ClientId;
+        MessageId = ClientMessage->MessageId;
+    }
+
+    /* Reference the port handle */
+    Status = ObReferenceObjectByHandle(PortHandle,
+                                       PORT_ALL_ACCESS,
+                                       LpcPortObjectType,
+                                       PreviousMode,
+                                       (PVOID*)&Port,
+                                       NULL);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to reference port handle: 0x%ls\n", Status);
+        return Status;
+    }
+
+    /* Make sure this is a connection port */
+    if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
+    {
+        /* It isn't, fail */
+        DPRINT1("Port is not a communication port\n");
+        Status = STATUS_INVALID_PORT_HANDLE;
+        goto Cleanup;
+    }
+
+    /* Look up the client thread */
+    Status = PsLookupProcessThreadByCid(&ClientId, NULL, &ClientThread);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to lookup client thread: 0x%ls\n", Status);
+        goto Cleanup;
+    }
+
+    /* Acquire the lock */
+    KeAcquireGuardedMutex(&LpcpLock);
+
+    /* Get the connected port and try to reference it */
+    ConnectedPort = Port->ConnectedPort;
+    if ((ConnectedPort == NULL) || !ObReferenceObjectSafe(ConnectedPort))
+    {
+        DPRINT1("Failed to reference the connected port\n");
+        ConnectedPort = NULL;
+        Status = STATUS_PORT_DISCONNECTED;
+        goto CleanupWithLock;
+    }
+
+    /* Check for no-impersonation flag */
+    if ((ULONG_PTR)ClientThread->LpcReplyMessage & LPCP_THREAD_FLAG_NO_IMPERSONATION)
+    {
+        DPRINT1("Reply message has no impersonation flag set\n");
+        Status = STATUS_ACCESS_DENIED;
+        goto CleanupWithLock;
+    }
+
+    /* Check for message id mismatch */
+    if ((ClientThread->LpcReplyMessageId != MessageId) || (MessageId == 0))
+    {
+        DPRINT1("LpcReplyMessageId mismatch: 0x%lx/0x%lx.\n",
+                ClientThread->LpcReplyMessageId, MessageId);
+        Status = STATUS_REPLY_MESSAGE_MISMATCH;
+        goto CleanupWithLock;
+    }
+
+    /* Validate the port */
+    if (!LpcpValidateClientPort(ClientThread, Port))
+    {
+        DPRINT1("LpcpValidateClientPort failed\n");
+        Status = STATUS_REPLY_MESSAGE_MISMATCH;
+        goto CleanupWithLock;
+    }
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+
+    /* Check if security is static */
+    if (!(ConnectedPort->Flags & LPCP_SECURITY_DYNAMIC))
+    {
+        /* Use the static security for impersonation */
+        Status = SeImpersonateClientEx(&ConnectedPort->StaticSecurity, NULL);
+        goto Cleanup;
+    }
+
+    /* Create new dynamic security */
+    Status = SeCreateClientSecurity(ClientThread,
+                                    &ConnectedPort->SecurityQos,
+                                    FALSE,
+                                    &ClientContext);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("SeCreateClientSecurity failed\n");
+        goto Cleanup;
+    }
+
+    /* Use dynamic security for impersonation */
+    Status = SeImpersonateClientEx(&ClientContext, NULL);
+
+    /* Get rid of the security context */
+    SeDeleteClientSecurity(&ClientContext);
+
+Cleanup:
+
+    if (ConnectedPort != NULL)
+        ObDereferenceObject(ConnectedPort);
+
+    if (ClientThread != NULL)
+        ObDereferenceObject(ClientThread);
+
+    ObDereferenceObject(Port);
+
+    return Status;
+
+CleanupWithLock:
+
+    /* Release the lock */
+    KeReleaseGuardedMutex(&LpcpLock);
+    goto Cleanup;
 }
 
 NTSTATUS