[AVIFIL32]
[reactos.git] / reactos / dll / ntdll / csr / connect.c
index dc6bb7c..c2d2ac8 100644 (file)
@@ -1,25 +1,24 @@
 /*
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
- * FILE:            lib/ntdll/csr/connect.c
+ * FILE:            dll/ntdll/csr/connect.c
  * PURPOSE:         Routines for connecting and calling CSR
  * PROGRAMMER:      Alex Ionescu (alex@relsoft.net)
  */
 
-/* INCLUDES *****************************************************************/
+/* INCLUDES *******************************************************************/
 
 #include <ntdll.h>
 #define NDEBUG
 #include <debug.h>
 
-/* GLOBALS *******************************************************************/
+/* GLOBALS ********************************************************************/
 
 HANDLE CsrApiPort;
 HANDLE CsrProcessId;
 HANDLE CsrPortHeap;
 ULONG_PTR CsrPortMemoryDelta;
 BOOLEAN InsideCsrProcess = FALSE;
-BOOLEAN UsingOldCsr = TRUE;
 
 typedef NTSTATUS
 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
@@ -28,163 +27,25 @@ typedef NTSTATUS
 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
 
 #define UNICODE_PATH_SEP L"\\"
-#define CSR_PORT_NAME L"ApiPort"
 
-/* FUNCTIONS *****************************************************************/
-
-/*
- * @implemented
- */
-HANDLE
-NTAPI
-CsrGetProcessId(VOID)
-{
-    return CsrProcessId;
-}
-
-/*
- * @implemented
- */
-NTSTATUS 
-NTAPI
-CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,
-                    PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
-                    CSR_API_NUMBER ApiNumber,
-                    ULONG RequestLength)
-{
-    NTSTATUS Status;
-    ULONG PointerCount;
-    PULONG_PTR Pointers;
-    ULONG_PTR CurrentPointer;
-    DPRINT("CsrClientCallServer\n");
-
-    /* Fill out the Port Message Header */
-    ApiMessage->Header.u2.ZeroInit = 0;
-    ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);
-    ApiMessage->Header.u1.s1.TotalLength = RequestLength;
-
-    /* Fill out the CSR Header */
-    ApiMessage->Type = ApiNumber;
-    //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR
-    ApiMessage->CsrCaptureData = NULL;
-
-    DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n", 
-           ApiNumber,
-           ApiMessage->Header.u1.s1.DataLength,
-           ApiMessage->Header.u1.s1.TotalLength);
-                
-    /* Check if we are already inside a CSR Server */
-    if (!InsideCsrProcess)
-    {
-        /* Check if we got a a Capture Buffer */
-        if (CaptureBuffer)
-        {
-            /* We have to convert from our local view to the remote view */
-            ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
-                                                 CsrPortMemoryDelta);
-
-            /* Lock the buffer */
-            CaptureBuffer->BufferEnd = 0;
-
-            /* Get the pointer information */
-            PointerCount = CaptureBuffer->PointerCount;
-            Pointers = CaptureBuffer->PointerArray;
-
-            /* Loop through every pointer and convert it */
-            DPRINT("PointerCount: %lx\n", PointerCount);
-            while (PointerCount--)
-            {
-                /* Get this pointer and check if it's valid */
-                DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n",
-                        &Pointers, Pointers, *Pointers);
-                if ((CurrentPointer = *Pointers++))
-                {
-                    /* Update it */
-                    DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
-                    *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
-                    Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
-                    DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
-                }
-            }
-        }
-
-        /* Send the LPC Message */
-        Status = NtRequestWaitReplyPort(CsrApiPort,
-                                        &ApiMessage->Header,
-                                        &ApiMessage->Header);
-
-        /* Check if we got a a Capture Buffer */
-        if (CaptureBuffer)
-        {
-            /* We have to convert from the remote view to our remote view */
-            DPRINT("Reconverting CaptureBuffer\n");
-            ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
-                                                 ApiMessage->CsrCaptureData -
-                                                 CsrPortMemoryDelta);
-
-            /* Get the pointer information */
-            PointerCount = CaptureBuffer->PointerCount;
-            Pointers = CaptureBuffer->PointerArray;
-
-            /* Loop through every pointer and convert it */
-            while (PointerCount--)
-            {
-                /* Get this pointer and check if it's valid */
-                if ((CurrentPointer = *Pointers++))
-                {
-                    /* Update it */
-                    CurrentPointer += (ULONG_PTR)ApiMessage;
-                    Pointers[-1] = CurrentPointer;
-                    *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
-                }
-            }
-        }
-
-        /* Check for success */
-        if (!NT_SUCCESS(Status))
-        {
-            /* We failed. Overwrite the return value with the failure */
-            DPRINT1("LPC Failed: %lx\n", Status);
-            ApiMessage->Status = Status;
-        }
-    }
-    else
-    {
-        /* This is a server-to-server call. Save our CID and do a direct call */
-        DbgBreakPoint();
-        ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
-        Status = CsrServerApiRoutine(&ApiMessage->Header,
-                                     &ApiMessage->Header);
-       
-        /* Check for success */
-        if (!NT_SUCCESS(Status))
-        {
-            /* We failed. Overwrite the return value with the failure */
-            ApiMessage->Status = Status;
-        }
-    }
-
-    /* Return the CSR Result */
-    DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
-    return ApiMessage->Status;
-}
+/* FUNCTIONS ******************************************************************/
 
 NTSTATUS
 NTAPI
-CsrConnectToServer(IN PWSTR ObjectDirectory)
+CsrpConnectToServer(IN PWSTR ObjectDirectory)
 {
+    NTSTATUS Status;
     ULONG PortNameLength;
     UNICODE_STRING PortName;
     LARGE_INTEGER CsrSectionViewSize;
-    NTSTATUS Status;
     HANDLE CsrSectionHandle;
     PORT_VIEW LpcWrite;
     REMOTE_PORT_VIEW LpcRead;
     SECURITY_QUALITY_OF_SERVICE SecurityQos;
     SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
     PSID SystemSid = NULL;
-    CSR_CONNECTION_INFO ConnectionInfo;
-    ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
+    CSR_API_CONNECTINFO ConnectionInfo;
+    ULONG ConnectionInfoLength = sizeof(CSR_API_CONNECTINFO);
 
     DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
 
@@ -203,7 +64,7 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
     PortName.MaximumLength = PortNameLength;
 
     /* Allocate a buffer for it */
-    PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
+    PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
     if (PortName.Buffer == NULL)
     {
         return STATUS_INSUFFICIENT_RESOURCES;
@@ -221,7 +82,7 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
                              NULL,
                              &CsrSectionViewSize,
                              PAGE_READWRITE,
-                             SEC_COMMIT,
+                             SEC_RESERVE,
                              NULL);
     if (!NT_SUCCESS(Status))
     {
@@ -246,20 +107,20 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
     SecurityQos.EffectiveOnly = TRUE;
 
     /* Setup the connection info */
-    ConnectionInfo.Version = 0x10000;
+    ConnectionInfo.DebugFlags = 0;
 
     /* Create a SID for us */
     Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
-                                          1,
-                                          SECURITY_LOCAL_SYSTEM_RID,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          &SystemSid);
+                                         1,
+                                         SECURITY_LOCAL_SYSTEM_RID,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         &SystemSid);
     if (!NT_SUCCESS(Status))
     {
         /* Failure */
@@ -278,6 +139,7 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
                                  NULL,
                                  &ConnectionInfo,
                                  &ConnectionInfoLength);
+    RtlFreeSid(SystemSid);
     NtClose(CsrSectionHandle);
     if (!NT_SUCCESS(Status))
     {
@@ -291,12 +153,12 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
                          (ULONG_PTR)LpcWrite.ViewBase;
 
     /* Save the Process */
-    CsrProcessId = ConnectionInfo.ProcessId;
+    CsrProcessId = ConnectionInfo.ServerProcessId;
 
     /* Save CSR Section data */
     NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
     NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
-    NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
+    NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
 
     /* Create the port heap */
     CsrPortHeap = RtlCreateHeap(0,
@@ -307,6 +169,8 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
                                 0);
     if (CsrPortHeap == NULL)
     {
+        /* Failure */
+        DPRINT1("Couldn't create heap for CSR port\n");
         NtClose(CsrApiPort);
         CsrApiPort = NULL;
         return STATUS_INSUFFICIENT_RESOURCES;
@@ -321,24 +185,24 @@ CsrConnectToServer(IN PWSTR ObjectDirectory)
  */
 NTSTATUS
 NTAPI
-CsrClientConnectToServer(PWSTR ObjectDirectory,
-                         ULONG ServerId,
-                         PVOID ConnectionInfo,
-                         PULONG ConnectionInfoSize,
-                         PBOOLEAN ServerToServerCall)
+CsrClientConnectToServer(IN PWSTR ObjectDirectory,
+                         IN ULONG ServerId,
+                         IN PVOID ConnectionInfo,
+                         IN OUT PULONG ConnectionInfoSize,
+                         OUT PBOOLEAN ServerToServerCall)
 {
     NTSTATUS Status;
     PIMAGE_NT_HEADERS NtHeader;
     UNICODE_STRING CsrSrvName;
     HANDLE hCsrSrv;
     ANSI_STRING CsrServerRoutineName;
+    CSR_API_MESSAGE ApiMessage;
+    PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
     PCSR_CAPTURE_BUFFER CaptureBuffer;
-    CSR_API_MESSAGE RosApiMessage;
-    CSR_API_MESSAGE2 ApiMessage;
-    PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;
 
-    /* Validate the Connection Info */
     DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
+
+    /* Validate the Connection Info */
     if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
     {
         DPRINT1("Connection info given, but no length\n");
@@ -366,10 +230,10 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
     InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
 
     /* Now we can check if we are inside or not */
-    if (InsideCsrProcess && !UsingOldCsr)
+    if (InsideCsrProcess)
     {
         /* We're inside, so let's find csrsrv */
-        DbgBreakPoint();
+        DPRINT("Next-GEN CSRSS support\n");
         RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
         Status = LdrGetDllHandle(NULL,
                                  NULL,
@@ -401,7 +265,8 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
         if (!CsrApiPort)
         {
             /* No, set it up now */
-            if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))
+            Status = CsrpConnectToServer(ObjectDirectory);
+            if (!NT_SUCCESS(Status))
             {
                 /* Failed */
                 DPRINT1("Failure to connect to CSR\n");
@@ -414,38 +279,34 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
         ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
 
         /* Setup a buffer for the connection info */
-        CaptureBuffer = CsrAllocateCaptureBuffer(1,
-                                                 ClientConnect->ConnectionInfoSize);
+        CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
         if (CaptureBuffer == NULL)
         {
             return STATUS_INSUFFICIENT_RESOURCES;
         }
 
-        /* Allocate a pointer for the connection info*/
-        CsrAllocateMessagePointer(CaptureBuffer,
-                                  ClientConnect->ConnectionInfoSize,
-                                  &ClientConnect->ConnectionInfo);
-
-        /* Copy the data into the buffer */
-        RtlMoveMemory(ClientConnect->ConnectionInfo,
-                      ConnectionInfo,
-                      ClientConnect->ConnectionInfoSize);
+        /* Capture the connection info data */
+        CsrCaptureMessageBuffer(CaptureBuffer,
+                                ConnectionInfo,
+                                ClientConnect->ConnectionInfoSize,
+                                &ClientConnect->ConnectionInfo);
 
         /* Return the allocated length */
         *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
 
         /* Call CSR */
-#if 0
         Status = CsrClientCallServer(&ApiMessage,
                                      CaptureBuffer,
-                                     CSR_MAKE_OPCODE(CsrSrvClientConnect,
-                                                     CSR_SRV_DLL),
+                                     CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
                                      sizeof(CSR_CLIENT_CONNECT));
-#endif
-        Status = CsrClientCallServer(&RosApiMessage,
-                                     NULL,
-                                     MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),
-                                     sizeof(CSR_API_MESSAGE));
+
+        /* Copy the updated connection info data back into the user buffer */
+        RtlMoveMemory(ConnectionInfo,
+                      ClientConnect->ConnectionInfo,
+                      *ConnectionInfoSize);
+
+        /* Free the capture buffer */
+        CsrFreeCaptureBuffer(CaptureBuffer);
     }
     else
     {
@@ -456,7 +317,174 @@ CsrClientConnectToServer(PWSTR ObjectDirectory,
     /* Let the caller know if this was server to server */
     DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
     if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
+
     return Status;
 }
 
+#if 0 
+//
+// Structures can be padded at the end, causing the size of the entire structure
+// minus the size of the last field, not to be equal to the offset of the last
+// field.
+//
+typedef struct _TEST_EMBEDDED
+{
+    ULONG One;
+    ULONG Two;
+    ULONG Three;
+} TEST_EMBEDDED;
+
+typedef struct _TEST
+{
+    PORT_MESSAGE h;
+    TEST_EMBEDDED Three;
+} TEST;
+
+C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
+C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);
+C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);
+
+C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE)));
+C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
+#endif
+
+/*
+ * @implemented
+ */
+NTSTATUS 
+NTAPI
+CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
+                    IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
+                    IN CSR_API_NUMBER ApiNumber,
+                    IN ULONG DataLength)
+{
+    NTSTATUS Status;
+    ULONG PointerCount;
+    PULONG_PTR OffsetPointer;
+
+    /* Fill out the Port Message Header */
+    ApiMessage->Header.u2.ZeroInit = 0;
+    ApiMessage->Header.u1.s1.TotalLength = DataLength +
+        sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data); // FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
+    ApiMessage->Header.u1.s1.DataLength = DataLength +
+        FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header);// ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+
+    /* Fill out the CSR Header */
+    ApiMessage->ApiNumber = ApiNumber;
+    ApiMessage->CsrCaptureData = NULL;
+
+    DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
+           ApiNumber,
+           ApiMessage->Header.u1.s1.DataLength,
+           ApiMessage->Header.u1.s1.TotalLength);
+                
+    /* Check if we are already inside a CSR Server */
+    if (!InsideCsrProcess)
+    {
+        /* Check if we got a Capture Buffer */
+        if (CaptureBuffer)
+        {
+            /*
+             * We have to convert from our local (client) view
+             * to the remote (server) view.
+             */
+            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+                ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
+
+            /* Lock the buffer. */
+            CaptureBuffer->BufferEnd = NULL;
+
+            /*
+             * Each client pointer inside the CSR message is converted into
+             * a server pointer, and each pointer to these message pointers
+             * is converted into an offset.
+             */
+            PointerCount  = CaptureBuffer->PointerCount;
+            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+            while (PointerCount--)
+            {
+                if (*OffsetPointer != 0)
+                {
+                    *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
+                    *OffsetPointer -= (ULONG_PTR)ApiMessage;
+                }
+                ++OffsetPointer;
+            }
+        }
+
+        /* Send the LPC Message */
+        Status = NtRequestWaitReplyPort(CsrApiPort,
+                                        &ApiMessage->Header,
+                                        &ApiMessage->Header);
+
+        /* Check if we got a Capture Buffer */
+        if (CaptureBuffer)
+        {
+            /*
+             * We have to convert back from the remote (server) view
+             * to our local (client) view.
+             */
+            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+                ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
+
+            /*
+             * Convert back the offsets into pointers to CSR message
+             * pointers, and convert back these message server pointers
+             * into client pointers.
+             */
+            PointerCount  = CaptureBuffer->PointerCount;
+            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+            while (PointerCount--)
+            {
+                if (*OffsetPointer != 0)
+                {
+                    *OffsetPointer += (ULONG_PTR)ApiMessage;
+                    *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
+                }
+                ++OffsetPointer;
+            }
+        }
+
+        /* Check for success */
+        if (!NT_SUCCESS(Status))
+        {
+            /* We failed. Overwrite the return value with the failure. */
+            DPRINT1("LPC Failed: %lx\n", Status);
+            ApiMessage->Status = Status;
+        }
+    }
+    else
+    {
+        /* This is a server-to-server call. Save our CID and do a direct call. */
+        DPRINT("Next gen server-to-server call\n");
+
+        /* We check this equality inside CsrValidateMessageBuffer */
+        ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
+
+        Status = CsrServerApiRoutine(&ApiMessage->Header,
+                                     &ApiMessage->Header);
+
+        /* Check for success */
+        if (!NT_SUCCESS(Status))
+        {
+            /* We failed. Overwrite the return value with the failure. */
+            ApiMessage->Status = Status;
+        }
+    }
+
+    /* Return the CSR Result */
+    DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
+    return ApiMessage->Status;
+}
+
+/*
+ * @implemented
+ */
+HANDLE
+NTAPI
+CsrGetProcessId(VOID)
+{
+    return CsrProcessId;
+}
+
 /* EOF */