- Use the actual previous mode (spotted by Thomas)
[reactos.git] / reactos / ntoskrnl / lpc / connect.c
index 9bb82a7..3ef7025 100644 (file)
 #define NDEBUG
 #include <internal/debug.h>
 
-/* GLOBALS *******************************************************************/
-
-#define TAG_LPC_CONNECT_MESSAGE   TAG('L', 'P', 'C', 'C')
-
 /* FUNCTIONS *****************************************************************/
 
 /**********************************************************************
@@ -254,26 +250,77 @@ NtConnectPort (PHANDLE                            UnsafeConnectedPortHandle,
   LPC_SECTION_WRITE WriteMap;
   LPC_SECTION_READ ReadMap;
   ULONG MaximumMessageSize;
-  PVOID ConnectData;
-  ULONG ConnectDataLength;
+  PVOID ConnectData = NULL;
+  ULONG ConnectDataLength = 0;
   PSECTION_OBJECT SectionObject;
   LARGE_INTEGER SectionOffset;
   PEPORT ConnectedPort;
-  NTSTATUS Status;
+  KPROCESSOR_MODE PreviousMode;
+  NTSTATUS Status = STATUS_SUCCESS;
   PEPORT NamedPort;
+  
+  PreviousMode = ExGetPreviousMode();
+  
+  if (PreviousMode != KernelMode)
+    {
+      _SEH_TRY
+        {
+          ProbeForWrite(UnsafeConnectedPortHandle,
+                        sizeof(HANDLE),
+                        sizeof(ULONG));
+          if (UnsafeMaximumMessageSize != NULL)
+            {
+              ProbeForWrite(UnsafeMaximumMessageSize,
+                            sizeof(ULONG),
+                            sizeof(ULONG));
+            }
+        }
+      _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+      _SEH_END;
+      
+      if (!NT_SUCCESS(Status))
+        {
+          return Status;
+        }
+    }
 
   /*
    * Copy in write map and partially validate.
    */
   if (UnsafeWriteMap != NULL)
     {
-      Status = MmCopyFromCaller(&WriteMap,
-                               UnsafeWriteMap,
-                               sizeof(LPC_SECTION_WRITE));
-      if (!NT_SUCCESS(Status))
-       {
-         return(Status);
-       }
+      if (PreviousMode != KernelMode)
+        {
+          _SEH_TRY
+            {
+              ProbeForWrite(UnsafeWriteMap,
+                            sizeof(LPC_SECTION_WRITE),
+                            1);
+              RtlCopyMemory(&WriteMap,
+                            UnsafeWriteMap,
+                            sizeof(LPC_SECTION_WRITE));
+            }
+          _SEH_HANDLE
+            {
+              Status = _SEH_GetExceptionCode();
+            }
+          _SEH_END;
+
+          if (!NT_SUCCESS(Status))
+            {
+              return Status;
+            }
+        }
+      else
+        {
+          RtlCopyMemory(&WriteMap,
+                        UnsafeWriteMap,
+                        sizeof(LPC_SECTION_WRITE));
+        }
+
       if (WriteMap.Length != sizeof(LPC_SECTION_WRITE))
        {
          return(STATUS_INVALID_PARAMETER_4);
@@ -288,41 +335,71 @@ NtConnectPort (PHANDLE                            UnsafeConnectedPortHandle,
   /*
    * Handle connection data.
    */
-  if (UnsafeConnectData == NULL)
-    {
-      ConnectDataLength = 0;
-      ConnectData = NULL;
-    }
-  else
-    {
-      if (ExGetPreviousMode() == KernelMode)
-       {
-         ConnectDataLength = *UnsafeConnectDataLength;
-         ConnectData = UnsafeConnectData;
-       }
+  if (UnsafeConnectData)
+    {
+      if (PreviousMode != KernelMode)
+        {
+          _SEH_TRY
+            {
+              ProbeForRead(UnsafeConnectDataLength,
+                           sizeof(ULONG),
+                           1);
+              ConnectDataLength = *UnsafeConnectDataLength;
+            }
+          _SEH_HANDLE
+            {
+              Status = _SEH_GetExceptionCode();
+            }
+          _SEH_END;
+
+          if (!NT_SUCCESS(Status))
+            {
+              return Status;
+            }
+        }
       else
-       {
-         Status = MmCopyFromCaller(&ConnectDataLength,
-                                   UnsafeConnectDataLength,
-                                   sizeof(ULONG));
-         if (!NT_SUCCESS(Status))
-           {
-             return(Status);
-           }
-         ConnectData = ExAllocatePool(NonPagedPool, ConnectDataLength);
-         if (ConnectData == NULL && ConnectDataLength != 0)
-           {
-             return(STATUS_NO_MEMORY);
-           }
-         Status = MmCopyFromCaller(ConnectData,
-                                   UnsafeConnectData,
-                                   ConnectDataLength);
-         if (!NT_SUCCESS(Status))
-           {
-             ExFreePool(ConnectData);
-             return(Status);
-           }
-       }
+        {
+          ConnectDataLength = *UnsafeConnectDataLength;
+        }
+
+      if (ConnectDataLength != 0)
+        {
+          ConnectData = ExAllocatePool(NonPagedPool, ConnectDataLength);
+          if (ConnectData == NULL)
+            {
+              return(STATUS_NO_MEMORY);
+            }
+
+          if (PreviousMode != KernelMode)
+            {
+              _SEH_TRY
+                {
+                  ProbeForWrite(UnsafeConnectData,
+                                ConnectDataLength,
+                                1);
+                  RtlCopyMemory(ConnectData,
+                                UnsafeConnectData,
+                                ConnectDataLength);
+                }
+              _SEH_HANDLE
+                {
+                  Status = _SEH_GetExceptionCode();
+                }
+              _SEH_END;
+
+              if (!NT_SUCCESS(Status))
+                {
+                  ExFreePool(ConnectData);
+                  return Status;
+                }
+            }
+          else
+            {
+              RtlCopyMemory(ConnectData,
+                            UnsafeConnectData,
+                            ConnectDataLength);
+            }
+        }
     }
 
   /*
@@ -333,7 +410,7 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
                                     NULL,
                                     PORT_ALL_ACCESS,  /* DesiredAccess */
                                     LpcPortObjectType,
-                                    UserMode,
+                                    PreviousMode,
                                     NULL,
                                     (PVOID*)&NamedPort);
   if (!NT_SUCCESS(Status))
@@ -353,7 +430,7 @@ NtConnectPort (PHANDLE                              UnsafeConnectedPortHandle,
       Status = ObReferenceObjectByHandle(WriteMap.SectionHandle,
                                         SECTION_MAP_READ | SECTION_MAP_WRITE,
                                         MmSectionObjectType,
-                                        UserMode,
+                                        PreviousMode,
                                         (PVOID*)&SectionObject,
                                         NULL);
       if (!NT_SUCCESS(Status))
@@ -391,16 +468,30 @@ NtConnectPort (PHANDLE                            UnsafeConnectedPortHandle,
       /* FIXME: Again, check what NT does here. */
       if (UnsafeConnectDataLength != NULL)
        {
-         if (ExGetPreviousMode() != KernelMode)
+         if (PreviousMode != KernelMode)
            {
-             MmCopyToCaller(UnsafeConnectData,
-                            ConnectData,
-                            ConnectDataLength);
-             ExFreePool(ConnectData);
+              _SEH_TRY
+                {
+                  RtlCopyMemory(UnsafeConnectData,
+                                ConnectData,
+                                ConnectDataLength);
+                  *UnsafeConnectDataLength = ConnectDataLength;
+                }
+              _SEH_HANDLE
+                {
+                  Status = _SEH_GetExceptionCode();
+                }
+              _SEH_END;
+           }
+         else
+           {
+               RtlCopyMemory(UnsafeConnectData,
+                             ConnectData,
+                             ConnectDataLength);
+               *UnsafeConnectDataLength = ConnectDataLength;
            }
-         MmCopyToCaller(UnsafeConnectDataLength,
-                        &ConnectDataLength,
-                        sizeof(ULONG));
+
+          ExFreePool(ConnectData);
        }
       return(Status);
     }
@@ -419,28 +510,52 @@ NtConnectPort (PHANDLE                            UnsafeConnectedPortHandle,
   /*
    * Copy the data back to the caller.
    */
-  if (ExGetPreviousMode() != KernelMode)
+
+  if (UnsafeConnectDataLength != NULL)
     {
-      if (UnsafeConnectDataLength != NULL)
+      if (PreviousMode != KernelMode)
        {
-         Status = MmCopyToCaller(UnsafeConnectDataLength,
-                                 &ConnectDataLength,
-                                 sizeof(ULONG));
+          _SEH_TRY
+            {
+              *UnsafeConnectDataLength = ConnectDataLength;
+              
+              if (ConnectData != NULL)
+                {
+                  RtlCopyMemory(UnsafeConnectData,
+                                ConnectData,
+                                ConnectDataLength);
+                }
+            }
+          _SEH_HANDLE
+            {
+              Status = _SEH_GetExceptionCode();
+            }
+          _SEH_END;
+
          if (!NT_SUCCESS(Status))
            {
-             return(Status);
+              if (ConnectData != NULL)
+                {
+                  ExFreePool(ConnectData);
+                }
+              return(Status);
            }
        }
-      if (UnsafeConnectData != NULL && ConnectData != NULL)
+      else
+        {
+          *UnsafeConnectDataLength = ConnectDataLength;
+          
+          if (ConnectData != NULL)
+            {
+              RtlCopyMemory(UnsafeConnectData,
+                            ConnectData,
+                            ConnectDataLength);
+            }
+        }
+
+      if (ConnectData != NULL)
        {
-         Status = MmCopyToCaller(UnsafeConnectData,
-                                     ConnectData,
-                                     ConnectDataLength);
          ExFreePool(ConnectData);
-         if (!NT_SUCCESS(Status))
-           {
-             return(Status);
-           }
        }
     }
   Status = ObInsertObject(ConnectedPort,
@@ -453,42 +568,65 @@ NtConnectPort (PHANDLE                            UnsafeConnectedPortHandle,
     {
       return(Status);
     }
-  Status = MmCopyToCaller(UnsafeConnectedPortHandle,
-                         &ConnectedPortHandle,
-                         sizeof(HANDLE));
-  if (!NT_SUCCESS(Status))
-    {
-      return(Status);
-    }
-  if (UnsafeWriteMap != NULL)
-    {
-      Status = MmCopyToCaller(UnsafeWriteMap,
-                             &WriteMap,
-                             sizeof(LPC_SECTION_WRITE));
-      if (!NT_SUCCESS(Status))
-       {
-         return(Status);
-       }
-    }
-  if (UnsafeReadMap != NULL)
-    {
-      Status = MmCopyToCaller(UnsafeReadMap,
-                             &ReadMap,
-                             sizeof(LPC_SECTION_READ));
+
+  if (PreviousMode != KernelMode)
+    {
+      _SEH_TRY
+        {
+          *UnsafeConnectedPortHandle = ConnectedPortHandle;
+          
+          if (UnsafeWriteMap != NULL)
+            {
+              RtlCopyMemory(UnsafeWriteMap,
+                            &WriteMap,
+                            sizeof(LPC_SECTION_WRITE));
+            }
+
+          if (UnsafeReadMap != NULL)
+            {
+              RtlCopyMemory(UnsafeReadMap,
+                            &ReadMap,
+                            sizeof(LPC_SECTION_READ));
+            }
+
+          if (UnsafeMaximumMessageSize != NULL)
+            {
+              *UnsafeMaximumMessageSize = MaximumMessageSize;
+            }
+        }
+      _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+      _SEH_END;
+      
       if (!NT_SUCCESS(Status))
-       {
-         return(Status);
-       }
+        {
+          return Status;
+        }
     }
-  if (UnsafeMaximumMessageSize != NULL)
+  else
     {
-      Status = MmCopyToCaller(UnsafeMaximumMessageSize,
-                             &MaximumMessageSize,
-                             sizeof(ULONG));
-      if (!NT_SUCCESS(Status))
-       {
-         return(Status);
-       }
+      *UnsafeConnectedPortHandle = ConnectedPortHandle;
+      
+      if (UnsafeWriteMap != NULL)
+        {
+          RtlCopyMemory(UnsafeWriteMap,
+                        &WriteMap,
+                        sizeof(LPC_SECTION_WRITE));
+        }
+
+      if (UnsafeReadMap != NULL)
+        {
+          RtlCopyMemory(UnsafeReadMap,
+                        &ReadMap,
+                        sizeof(LPC_SECTION_READ));
+        }
+
+      if (UnsafeMaximumMessageSize != NULL)
+        {
+          *UnsafeMaximumMessageSize = MaximumMessageSize;
+        }
     }
 
   /*
@@ -531,6 +669,7 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
   PEPORT_CONNECT_REQUEST_MESSAGE CRequest;
   PEPORT_CONNECT_REPLY_MESSAGE CReply;
   ULONG Size;
+  KPROCESSOR_MODE PreviousMode = ExGetPreviousMode();
 
   Size = sizeof(EPORT_CONNECT_REPLY_MESSAGE);
   if (LpcMessage)
@@ -547,7 +686,7 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
   Status = ObReferenceObjectByHandle(NamedPortHandle,
                                     PORT_ALL_ACCESS,
                                     LpcPortObjectType,
-                                    UserMode,
+                                    PreviousMode,
                                     (PVOID*)&NamedPort,
                                     NULL);
   if (!NT_SUCCESS(Status))
@@ -561,10 +700,10 @@ NtAcceptConnectPort (PHANDLE                      ServerPortHandle,
    */
   if (AcceptIt)
     {
-      Status = ObCreateObject(ExGetPreviousMode(),
+      Status = ObCreateObject(PreviousMode,
                              LpcPortObjectType,
                              NULL,
-                             ExGetPreviousMode(),
+                             PreviousMode,
                              NULL,
                              sizeof(EPORT),
                              0,
@@ -651,7 +790,7 @@ NtAcceptConnectPort (PHANDLE                        ServerPortHandle,
       Status = ObReferenceObjectByHandle(WriteMap->SectionHandle,
                                         SECTION_MAP_READ | SECTION_MAP_WRITE,
                                         MmSectionObjectType,
-                                        UserMode,
+                                        PreviousMode,
                                         (PVOID*)&SectionObject,
                                         NULL);
       if (!NT_SUCCESS(Status))