[STORPORT] Fix x64 build
[reactos.git] / ntoskrnl / lpc / complete.c
index 5134757..d33a38d 100644 (file)
@@ -41,20 +41,23 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
                     IN PVOID PortContext OPTIONAL,
                     IN PPORT_MESSAGE ReplyMessage,
                     IN BOOLEAN AcceptConnection,
-                    IN PPORT_VIEW ServerView,
-                    IN PREMOTE_PORT_VIEW ClientView)
+                    IN OUT PPORT_VIEW ServerView OPTIONAL,
+                    OUT PREMOTE_PORT_VIEW ClientView OPTIONAL)
 {
-    PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
-    PVOID ClientSectionToMap = NULL;
-    HANDLE Handle;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    PORT_VIEW CapturedServerView;
+    PORT_MESSAGE CapturedReplyMessage;
     ULONG ConnectionInfoLength;
-    PLPCP_MESSAGE Message;
+    PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
+    PLPCP_MESSAGE Message;
+    PVOID ClientSectionToMap = NULL;
+    HANDLE Handle;
     PEPROCESS ClientProcess;
     PETHREAD ClientThread;
     LARGE_INTEGER SectionOffset;
+
     PAGED_CODE();
     LPCTRACE(LPC_COMPLETE_DEBUG,
              "Context: %p. Message: %p. Accept: %lx. Views: %p/%p\n",
@@ -64,22 +67,89 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
              ClientView,
              ServerView);
 
-    /* Validate the size of the server view */
-    if ((ServerView) && (ServerView->Length != sizeof(PORT_VIEW)))
+    /* Check if the call comes from user mode */
+    if (PreviousMode != KernelMode)
     {
-        /* Invalid size */
-        return STATUS_INVALID_PARAMETER;
+        _SEH2_TRY
+        {
+            /* Probe the PortHandle */
+            ProbeForWriteHandle(PortHandle);
+
+            /* Probe the basic ReplyMessage structure */
+            ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
+            CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
+            ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
+
+            /* Probe the connection info */
+            ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1);
+
+            /* The following parameters are optional */
+
+            /* Capture the server view */
+            if (ServerView)
+            {
+                ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
+                CapturedServerView = *(volatile PORT_VIEW*)ServerView;
+
+                /* Validate the size of the server view */
+                if (CapturedServerView.Length != sizeof(CapturedServerView))
+                {
+                    /* Invalid size */
+                    _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
+                }
+            }
+
+            /* Capture the client view */
+            if (ClientView)
+            {
+                ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
+
+                /* Validate the size of the client view */
+                if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView))
+                {
+                    /* Invalid size */
+                    _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
+                }
+            }
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            /* There was an exception, return the exception code */
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
     }
-
-    /* Validate the size of the client view */
-    if ((ClientView) && (ClientView->Length != sizeof(REMOTE_PORT_VIEW)))
+    else
     {
-        /* Invalid size */
-        return STATUS_INVALID_PARAMETER;
+        CapturedReplyMessage = *ReplyMessage;
+        ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
+
+        /* Capture the server view */
+        if (ServerView)
+        {
+            /* Validate the size of the server view */
+            if (ServerView->Length != sizeof(*ServerView))
+            {
+                /* Invalid size */
+                return STATUS_INVALID_PARAMETER;
+            }
+            CapturedServerView = *ServerView;
+        }
+
+        /* Capture the client view */
+        if (ClientView)
+        {
+            /* Validate the size of the client view */
+            if (ClientView->Length != sizeof(*ClientView))
+            {
+                /* Invalid size */
+                return STATUS_INVALID_PARAMETER;
+            }
+        }
     }
 
     /* Get the client process and thread */
-    Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
+    Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
                                         &ClientProcess,
                                         &ClientThread);
     if (!NT_SUCCESS(Status)) return Status;
@@ -89,8 +159,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
 
     /* Make sure that the client wants a reply, and this is the right one */
     if (!(LpcpGetMessageFromThread(ClientThread)) ||
-        !(ReplyMessage->MessageId) ||
-        (ClientThread->LpcReplyMessageId != ReplyMessage->MessageId))
+        !(CapturedReplyMessage.MessageId) ||
+        (ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId))
     {
         /* Not the reply asked for, or no reply wanted, fail */
         KeReleaseGuardedMutex(&LpcpLock);
@@ -125,8 +195,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     ConnectMessage->ClientPort = NULL;
     KeReleaseGuardedMutex(&LpcpLock);
 
-    /* Get the connection information length */
-    ConnectionInfoLength = ReplyMessage->u1.s1.DataLength;
+    /* Check the connection information length */
     if (ConnectionInfoLength > ConnectionPort->MaxConnectionInfoLength)
     {
         /* Normalize it since it's too large */
@@ -142,16 +211,26 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     /* Setup the reply message */
     Message->Request.u2.s2.Type = LPC_REPLY;
     Message->Request.u2.s2.DataInfoOffset = 0;
-    Message->Request.ClientId = ReplyMessage->ClientId;
-    Message->Request.MessageId = ReplyMessage->MessageId;
+    Message->Request.ClientId  = CapturedReplyMessage.ClientId;
+    Message->Request.MessageId = CapturedReplyMessage.MessageId;
     Message->Request.ClientViewSize = 0;
-    RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength);
+
+    _SEH2_TRY
+    {
+        RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength);
+    }
+    _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+    {
+        Status = _SEH2_GetExceptionCode();
+        _SEH2_YIELD(goto Cleanup);
+    }
+    _SEH2_END;
 
     /* At this point, if the caller refused the connection, go to cleanup */
     if (!AcceptConnection)
     {
         DPRINT1("LPC connection was refused\n");
-        goto Cleanup;   
+        goto Cleanup;
     }
 
     /* Otherwise, create the actual port */
@@ -239,6 +318,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     if (ServerView)
     {
         /* FIXME: TODO */
+        UNREFERENCED_PARAMETER(CapturedServerView);
         ASSERT(FALSE);
     }
 
@@ -259,16 +339,30 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
         goto Cleanup;
     }
 
-    /* Check if the caller gave a client view */
-    if (ClientView)
+    /* Enter SEH to write back the results */
+    _SEH2_TRY
     {
-        /* Fill it out */
-        ClientView->ViewBase = ConnectMessage->ClientView.ViewRemoteBase;
-        ClientView->ViewSize = ConnectMessage->ClientView.ViewSize;
+        /* Check if the caller gave a client view */
+        if (ClientView)
+        {
+            /* Fill it out */
+            ClientView->ViewBase = ConnectMessage->ClientView.ViewRemoteBase;
+            ClientView->ViewSize = ConnectMessage->ClientView.ViewSize;
+        }
+
+        /* Return the handle to user mode */
+        *PortHandle = Handle;
     }
+    _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+    {
+        /* Cleanup and return the exception code */
+        ObCloseHandle(Handle, PreviousMode);
+        ObDereferenceObject(ServerPort);
+        Status = _SEH2_GetExceptionCode();
+        _SEH2_YIELD(goto Cleanup);
+    }
+    _SEH2_END;
 
-    /* Return the handle to user mode */
-    *PortHandle = Handle;
     LPCTRACE(LPC_COMPLETE_DEBUG,
              "Handle: %p. Messages: %p/%p. Ports: %p/%p/%p\n",
              Handle,
@@ -300,8 +394,12 @@ Cleanup:
     /* Check if we got here while still having a client thread */
     if (ClientThread)
     {
-        /* FIXME: Complex cleanup code */
-        ASSERT(FALSE);
+        KeAcquireGuardedMutex(&LpcpLock);
+        ClientThread->LpcReplyMessage = Message;
+        LpcpPrepareToWakeClient(ClientThread);
+        KeReleaseGuardedMutex(&LpcpLock);
+        LpcpCompleteWait(&ClientThread->LpcReplySemaphore);
+        ObDereferenceObject(ClientThread);
     }
 
     /* Dereference the client port if we have one, and the process */
@@ -323,9 +421,10 @@ NTAPI
 NtCompleteConnectPort(IN HANDLE PortHandle)
 {
     NTSTATUS Status;
-    PLPCP_PORT_OBJECT Port;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    PLPCP_PORT_OBJECT Port;
     PETHREAD Thread;
+
     PAGED_CODE();
     LPCTRACE(LPC_COMPLETE_DEBUG, "Handle: %p\n", PortHandle);
 
@@ -378,7 +477,7 @@ NtCompleteConnectPort(IN HANDLE PortHandle)
     KeReleaseGuardedMutex(&LpcpLock);
     LpcpCompleteWait(&Thread->LpcReplySemaphore);
 
-    /* Dereference the Thread and Port  and return */
+    /* Dereference the Thread and Port and return */
     ObDereferenceObject(Port);
     ObDereferenceObject(Thread);
     LPCTRACE(LPC_COMPLETE_DEBUG, "Port: %p. Thread: %p\n", Port, Thread);