[STORPORT] Fix x64 build
[reactos.git] / ntoskrnl / lpc / complete.c
index 0f9f7f2..d33a38d 100644 (file)
@@ -41,22 +41,23 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
                     IN PVOID PortContext OPTIONAL,
                     IN PPORT_MESSAGE ReplyMessage,
                     IN BOOLEAN AcceptConnection,
-                    IN PPORT_VIEW ServerView,
-                    IN PREMOTE_PORT_VIEW ClientView)
+                    IN OUT PPORT_VIEW ServerView OPTIONAL,
+                    OUT PREMOTE_PORT_VIEW ClientView OPTIONAL)
 {
-    PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
-    PVOID ClientSectionToMap = NULL;
-    HANDLE Handle;
-    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     NTSTATUS Status;
+    KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    PORT_VIEW CapturedServerView;
+    PORT_MESSAGE CapturedReplyMessage;
     ULONG ConnectionInfoLength;
-    PLPCP_MESSAGE Message;
+    PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
     PLPCP_CONNECTION_MESSAGE ConnectMessage;
+    PLPCP_MESSAGE Message;
+    PVOID ClientSectionToMap = NULL;
+    HANDLE Handle;
     PEPROCESS ClientProcess;
     PETHREAD ClientThread;
     LARGE_INTEGER SectionOffset;
-    CLIENT_ID ClientId;
-    ULONG MessageId;
+
     PAGED_CODE();
     LPCTRACE(LPC_COMPLETE_DEBUG,
              "Context: %p. Message: %p. Accept: %lx. Views: %p/%p\n",
@@ -69,41 +70,42 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     /* Check if the call comes from user mode */
     if (PreviousMode != KernelMode)
     {
-        /* Enter SEH for probing the parameters */
         _SEH2_TRY
         {
+            /* Probe the PortHandle */
             ProbeForWriteHandle(PortHandle);
 
             /* Probe the basic ReplyMessage structure */
-            ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
-
-            /* Grab some values */
-            ClientId = ReplyMessage->ClientId;
-            MessageId = ReplyMessage->MessageId;
-            ConnectionInfoLength = ReplyMessage->u1.s1.DataLength;
+            ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
+            CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
+            ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
 
             /* Probe the connection info */
             ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1);
 
             /* The following parameters are optional */
-            if (ServerView != NULL)
+
+            /* Capture the server view */
+            if (ServerView)
             {
-                ProbeForWrite(ServerView, sizeof(PORT_VIEW), sizeof(ULONG));
+                ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
+                CapturedServerView = *(volatile PORT_VIEW*)ServerView;
 
                 /* Validate the size of the server view */
-                if (ServerView->Length != sizeof(PORT_VIEW))
+                if (CapturedServerView.Length != sizeof(CapturedServerView))
                 {
                     /* Invalid size */
                     _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
                 }
             }
 
-            if (ClientView != NULL)
+            /* Capture the client view */
+            if (ClientView)
             {
-                ProbeForWrite(ClientView, sizeof(REMOTE_PORT_VIEW), sizeof(ULONG));
+                ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
 
                 /* Validate the size of the client view */
-                if (ClientView->Length != sizeof(REMOTE_PORT_VIEW))
+                if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView))
                 {
                     /* Invalid size */
                     _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
@@ -119,28 +121,35 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     }
     else
     {
-        /* Grab some values */
-        ClientId = ReplyMessage->ClientId;
-        MessageId = ReplyMessage->MessageId;
-        ConnectionInfoLength = ReplyMessage->u1.s1.DataLength;
+        CapturedReplyMessage = *ReplyMessage;
+        ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
 
-        /* Validate the size of the server view */
-        if ((ServerView) && (ServerView->Length != sizeof(PORT_VIEW)))
+        /* Capture the server view */
+        if (ServerView)
         {
-            /* Invalid size */
-            return STATUS_INVALID_PARAMETER;
+            /* Validate the size of the server view */
+            if (ServerView->Length != sizeof(*ServerView))
+            {
+                /* Invalid size */
+                return STATUS_INVALID_PARAMETER;
+            }
+            CapturedServerView = *ServerView;
         }
 
-        /* Validate the size of the client view */
-        if ((ClientView) && (ClientView->Length != sizeof(REMOTE_PORT_VIEW)))
+        /* Capture the client view */
+        if (ClientView)
         {
-            /* Invalid size */
-            return STATUS_INVALID_PARAMETER;
+            /* Validate the size of the client view */
+            if (ClientView->Length != sizeof(*ClientView))
+            {
+                /* Invalid size */
+                return STATUS_INVALID_PARAMETER;
+            }
         }
     }
 
     /* Get the client process and thread */
-    Status = PsLookupProcessThreadByCid(&ClientId,
+    Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
                                         &ClientProcess,
                                         &ClientThread);
     if (!NT_SUCCESS(Status)) return Status;
@@ -150,8 +159,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
 
     /* Make sure that the client wants a reply, and this is the right one */
     if (!(LpcpGetMessageFromThread(ClientThread)) ||
-        !(MessageId) ||
-        (ClientThread->LpcReplyMessageId != MessageId))
+        !(CapturedReplyMessage.MessageId) ||
+        (ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId))
     {
         /* Not the reply asked for, or no reply wanted, fail */
         KeReleaseGuardedMutex(&LpcpLock);
@@ -202,8 +211,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     /* Setup the reply message */
     Message->Request.u2.s2.Type = LPC_REPLY;
     Message->Request.u2.s2.DataInfoOffset = 0;
-    Message->Request.ClientId ClientId;
-    Message->Request.MessageId = MessageId;
+    Message->Request.ClientId  = CapturedReplyMessage.ClientId;
+    Message->Request.MessageId = CapturedReplyMessage.MessageId;
     Message->Request.ClientViewSize = 0;
 
     _SEH2_TRY
@@ -309,6 +318,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     if (ServerView)
     {
         /* FIXME: TODO */
+        UNREFERENCED_PARAMETER(CapturedServerView);
         ASSERT(FALSE);
     }
 
@@ -346,7 +356,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
         /* Cleanup and return the exception code */
-        ObCloseHandle(Handle, UserMode);
+        ObCloseHandle(Handle, PreviousMode);
         ObDereferenceObject(ServerPort);
         Status = _SEH2_GetExceptionCode();
         _SEH2_YIELD(goto Cleanup);
@@ -384,8 +394,12 @@ Cleanup:
     /* Check if we got here while still having a client thread */
     if (ClientThread)
     {
-        /* FIXME: Complex cleanup code */
-        ASSERT(FALSE);
+        KeAcquireGuardedMutex(&LpcpLock);
+        ClientThread->LpcReplyMessage = Message;
+        LpcpPrepareToWakeClient(ClientThread);
+        KeReleaseGuardedMutex(&LpcpLock);
+        LpcpCompleteWait(&ClientThread->LpcReplySemaphore);
+        ObDereferenceObject(ClientThread);
     }
 
     /* Dereference the client port if we have one, and the process */
@@ -407,9 +421,10 @@ NTAPI
 NtCompleteConnectPort(IN HANDLE PortHandle)
 {
     NTSTATUS Status;
-    PLPCP_PORT_OBJECT Port;
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
+    PLPCP_PORT_OBJECT Port;
     PETHREAD Thread;
+
     PAGED_CODE();
     LPCTRACE(LPC_COMPLETE_DEBUG, "Handle: %p\n", PortHandle);
 
@@ -462,7 +477,7 @@ NtCompleteConnectPort(IN HANDLE PortHandle)
     KeReleaseGuardedMutex(&LpcpLock);
     LpcpCompleteWait(&Thread->LpcReplySemaphore);
 
-    /* Dereference the Thread and Port  and return */
+    /* Dereference the Thread and Port and return */
     ObDereferenceObject(Port);
     ObDereferenceObject(Thread);
     LPCTRACE(LPC_COMPLETE_DEBUG, "Port: %p. Thread: %p\n", Port, Thread);