[NTDLL/CSRSRV]
[reactos.git] / dll / ntdll / csr / connect.c
index dc6bb7c..63d14d3 100644 (file)
@@ -6,20 +6,19 @@
  * PROGRAMMER:      Alex Ionescu (alex@relsoft.net)
  */
 
-/* INCLUDES *****************************************************************/
+/* INCLUDES *******************************************************************/
 
 #include <ntdll.h>
 #define NDEBUG
 #include <debug.h>
 
-/* GLOBALS *******************************************************************/
+/* GLOBALS ********************************************************************/
 
 HANDLE CsrApiPort;
 HANDLE CsrProcessId;
 HANDLE CsrPortHeap;
 ULONG_PTR CsrPortMemoryDelta;
 BOOLEAN InsideCsrProcess = FALSE;
-BOOLEAN UsingOldCsr = TRUE;
 
 typedef NTSTATUS
 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
@@ -28,9 +27,8 @@ typedef NTSTATUS
 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
 
 #define UNICODE_PATH_SEP L"\\"
-#define CSR_PORT_NAME L"ApiPort"
 
-/* FUNCTIONS *****************************************************************/
+/* FUNCTIONS ******************************************************************/
 
 /*
  * @implemented
@@ -47,131 +45,133 @@ CsrGetProcessId(VOID)
  */
 NTSTATUS 
 NTAPI
-CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,
-                    PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
-                    CSR_API_NUMBER ApiNumber,
-                    ULONG RequestLength)
+CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
+                    IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
+                    IN CSR_API_NUMBER ApiNumber,
+                    IN ULONG DataLength)
 {
     NTSTATUS Status;
     ULONG PointerCount;
-    PULONG_PTR Pointers;
-    ULONG_PTR CurrentPointer;
-    DPRINT("CsrClientCallServer\n");
+    PULONG_PTR OffsetPointer;
 
-    /* Fill out the Port Message Header */
+    /* Fill out the Port Message Header. */
     ApiMessage->Header.u2.ZeroInit = 0;
-    ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);
-    ApiMessage->Header.u1.s1.TotalLength = RequestLength;
+    ApiMessage->Header.u1.s1.TotalLength =
+        FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
+    ApiMessage->Header.u1.s1.DataLength =
+        ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
 
-    /* Fill out the CSR Header */
-    ApiMessage->Type = ApiNumber;
-    //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR
+    /* Fill out the CSR Header. */
+    ApiMessage->ApiNumber = ApiNumber;
     ApiMessage->CsrCaptureData = NULL;
 
-    DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n", 
+    DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
            ApiNumber,
            ApiMessage->Header.u1.s1.DataLength,
            ApiMessage->Header.u1.s1.TotalLength);
                 
-    /* Check if we are already inside a CSR Server */
+    /* Check if we are already inside a CSR Server. */
     if (!InsideCsrProcess)
     {
-        /* Check if we got a a Capture Buffer */
+        /* Check if we got a Capture Buffer. */
         if (CaptureBuffer)
         {
-            /* We have to convert from our local view to the remote view */
-            ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
-                                                 CsrPortMemoryDelta);
-
-            /* Lock the buffer */
-            CaptureBuffer->BufferEnd = 0;
-
-            /* Get the pointer information */
-            PointerCount = CaptureBuffer->PointerCount;
-            Pointers = CaptureBuffer->PointerArray;
-
-            /* Loop through every pointer and convert it */
-            DPRINT("PointerCount: %lx\n", PointerCount);
+            /*
+             * We have to convert from our local (client) view
+             * to the remote (server) view.
+             */
+            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+                ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
+
+            /* Lock the buffer. */
+            CaptureBuffer->BufferEnd = NULL;
+
+            /*
+             * Each client pointer inside the CSR message is converted into
+             * a server pointer, and each pointer to these message pointers
+             * is converted into an offset.
+             */
+            PointerCount  = CaptureBuffer->PointerCount;
+            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
             while (PointerCount--)
             {
-                /* Get this pointer and check if it's valid */
-                DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n",
-                        &Pointers, Pointers, *Pointers);
-                if ((CurrentPointer = *Pointers++))
+                if (*OffsetPointer != 0)
                 {
-                    /* Update it */
-                    DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
-                    *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
-                    Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
-                    DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
+                    *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
+                    *OffsetPointer -= (ULONG_PTR)ApiMessage;
                 }
+                ++OffsetPointer;
             }
         }
 
-        /* Send the LPC Message */
+        /* Send the LPC Message. */
         Status = NtRequestWaitReplyPort(CsrApiPort,
                                         &ApiMessage->Header,
                                         &ApiMessage->Header);
 
-        /* Check if we got a a Capture Buffer */
+        /* Check if we got a Capture Buffer. */
         if (CaptureBuffer)
         {
-            /* We have to convert from the remote view to our remote view */
-            DPRINT("Reconverting CaptureBuffer\n");
-            ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
-                                                 ApiMessage->CsrCaptureData -
-                                                 CsrPortMemoryDelta);
-
-            /* Get the pointer information */
-            PointerCount = CaptureBuffer->PointerCount;
-            Pointers = CaptureBuffer->PointerArray;
-
-            /* Loop through every pointer and convert it */
+            /*
+             * We have to convert back from the remote (server) view
+             * to our local (client) view.
+             */
+            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+                ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
+
+            /*
+             * Convert back the offsets into pointers to CSR message
+             * pointers, and convert back these message server pointers
+             * into client pointers.
+             */
+            PointerCount  = CaptureBuffer->PointerCount;
+            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
             while (PointerCount--)
             {
-                /* Get this pointer and check if it's valid */
-                if ((CurrentPointer = *Pointers++))
+                if (*OffsetPointer != 0)
                 {
-                    /* Update it */
-                    CurrentPointer += (ULONG_PTR)ApiMessage;
-                    Pointers[-1] = CurrentPointer;
-                    *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
+                    *OffsetPointer += (ULONG_PTR)ApiMessage;
+                    *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
                 }
+                ++OffsetPointer;
             }
         }
 
-        /* Check for success */
+        /* Check for success. */
         if (!NT_SUCCESS(Status))
         {
-            /* We failed. Overwrite the return value with the failure */
+            /* We failed. Overwrite the return value with the failure. */
             DPRINT1("LPC Failed: %lx\n", Status);
             ApiMessage->Status = Status;
         }
     }
     else
     {
-        /* This is a server-to-server call. Save our CID and do a direct call */
-        DbgBreakPoint();
+        /* This is a server-to-server call. Save our CID and do a direct call. */
+        DPRINT1("Next gen server-to-server call\n");
+
+        /* We check this equality inside CsrValidateMessageBuffer. */
         ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
+
         Status = CsrServerApiRoutine(&ApiMessage->Header,
                                      &ApiMessage->Header);
-       
-        /* Check for success */
+
+        /* Check for success. */
         if (!NT_SUCCESS(Status))
         {
-            /* We failed. Overwrite the return value with the failure */
+            /* We failed. Overwrite the return value with the failure. */
             ApiMessage->Status = Status;
         }
     }
 
-    /* Return the CSR Result */
+    /* Return the CSR Result. */
     DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
     return ApiMessage->Status;
 }
 
 NTSTATUS
 NTAPI
-CsrConnectToServer(IN PWSTR ObjectDirectory)
+CsrpConnectToServer(IN PWSTR ObjectDirectory)
 {
     ULONG PortNameLength;
     UNICODE_STRING PortName;
@@ -203,7 +203,7 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
     PortName.MaximumLength = PortNameLength;
 
     /* Allocate a buffer for it */
-    PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
+    PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
     if (PortName.Buffer == NULL)
     {
         return STATUS_INSUFFICIENT_RESOURCES;
@@ -307,6 +307,8 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
                                 0);
     if (CsrPortHeap == NULL)
     {
+        /* Failure */
+        DPRINT1("Couldn't create heap for CSR port\n");
         NtClose(CsrApiPort);
         CsrApiPort = NULL;
         return STATUS_INSUFFICIENT_RESOURCES;
@@ -321,24 +323,24 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
  */
 NTSTATUS
 NTAPI
-CsrClientConnectToServer(PWSTR ObjectDirectory,
-                         ULONG ServerId,
-                         PVOID ConnectionInfo,
-                         PULONG ConnectionInfoSize,
-                         PBOOLEAN ServerToServerCall)
+CsrClientConnectToServer(IN PWSTR ObjectDirectory,
+                         IN ULONG ServerId,
+                         IN PVOID ConnectionInfo,
+                         IN OUT PULONG ConnectionInfoSize,
+                         OUT PBOOLEAN ServerToServerCall)
 {
     NTSTATUS Status;
     PIMAGE_NT_HEADERS NtHeader;
     UNICODE_STRING CsrSrvName;
     HANDLE hCsrSrv;
     ANSI_STRING CsrServerRoutineName;
+    CSR_API_MESSAGE ApiMessage;
+    PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
     PCSR_CAPTURE_BUFFER CaptureBuffer;
-    CSR_API_MESSAGE RosApiMessage;
-    CSR_API_MESSAGE2 ApiMessage;
-    PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;
 
-    /* Validate the Connection Info */
     DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
+
+    /* Validate the Connection Info */
     if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
     {
         DPRINT1("Connection info given, but no length\n");
@@ -366,10 +368,10 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
     InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
 
     /* Now we can check if we are inside or not */
-    if (InsideCsrProcess && !UsingOldCsr)
+    if (InsideCsrProcess)
     {
         /* We're inside, so let's find csrsrv */
-        DbgBreakPoint();
+        DPRINT1("Next-GEN CSRSS support\n");
         RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
         Status = LdrGetDllHandle(NULL,
                                  NULL,
@@ -401,7 +403,7 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
         if (!CsrApiPort)
         {
             /* No, set it up now */
-            if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))
+            if (!NT_SUCCESS(Status = CsrpConnectToServer(ObjectDirectory)))
             {
                 /* Failed */
                 DPRINT1("Failure to connect to CSR\n");
@@ -414,38 +416,34 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
         ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
 
         /* Setup a buffer for the connection info */
-        CaptureBuffer = CsrAllocateCaptureBuffer(1,
-                                                 ClientConnect->ConnectionInfoSize);
+        CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
         if (CaptureBuffer == NULL)
         {
             return STATUS_INSUFFICIENT_RESOURCES;
         }
 
-        /* Allocate a pointer for the connection info*/
-        CsrAllocateMessagePointer(CaptureBuffer,
-                                  ClientConnect->ConnectionInfoSize,
-                                  &ClientConnect->ConnectionInfo);
-
-        /* Copy the data into the buffer */
-        RtlMoveMemory(ClientConnect->ConnectionInfo,
-                      ConnectionInfo,
-                      ClientConnect->ConnectionInfoSize);
+        /* Capture the connection info data */
+        CsrCaptureMessageBuffer(CaptureBuffer,
+                                ConnectionInfo,
+                                ClientConnect->ConnectionInfoSize,
+                                &ClientConnect->ConnectionInfo);
 
         /* Return the allocated length */
         *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
 
         /* Call CSR */
-#if 0
         Status = CsrClientCallServer(&ApiMessage,
                                      CaptureBuffer,
-                                     CSR_MAKE_OPCODE(CsrSrvClientConnect,
-                                                     CSR_SRV_DLL),
+                                     CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
                                      sizeof(CSR_CLIENT_CONNECT));
-#endif
-        Status = CsrClientCallServer(&RosApiMessage,
-                                     NULL,
-                                     MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),
-                                     sizeof(CSR_API_MESSAGE));
+
+        /* Copy the updated connection info data back into the user buffer */
+        RtlMoveMemory(ConnectionInfo,
+                      ClientConnect->ConnectionInfo,
+                      *ConnectionInfoSize);
+
+        /* Free the capture buffer */
+        CsrFreeCaptureBuffer(CaptureBuffer);
     }
     else
     {
@@ -456,6 +454,7 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
     /* Let the caller know if this was server to server */
     DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
     if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
+
     return Status;
 }