[SDK][ACGENRAL] Add the shim IgnoreFreeLibrary
[reactos.git] / dll / ntdll / csr / connect.c
index 63d14d3..a1a9509 100644 (file)
@@ -1,7 +1,7 @@
 /*
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS kernel
- * FILE:            lib/ntdll/csr/connect.c
+ * FILE:            dll/ntdll/csr/connect.c
  * PURPOSE:         Routines for connecting and calling CSR
  * PROGRAMMER:      Alex Ionescu (alex@relsoft.net)
  */
@@ -9,6 +9,10 @@
 /* INCLUDES *******************************************************************/
 
 #include <ntdll.h>
+
+#include <ndk/lpcfuncs.h>
+#include <csr/csrsrv.h>
+
 #define NDEBUG
 #include <debug.h>
 
@@ -30,161 +34,22 @@ PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
 
 /* FUNCTIONS ******************************************************************/
 
-/*
- * @implemented
- */
-HANDLE
-NTAPI
-CsrGetProcessId(VOID)
-{
-    return CsrProcessId;
-}
-
-/*
- * @implemented
- */
-NTSTATUS 
-NTAPI
-CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
-                    IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
-                    IN CSR_API_NUMBER ApiNumber,
-                    IN ULONG DataLength)
-{
-    NTSTATUS Status;
-    ULONG PointerCount;
-    PULONG_PTR OffsetPointer;
-
-    /* Fill out the Port Message Header. */
-    ApiMessage->Header.u2.ZeroInit = 0;
-    ApiMessage->Header.u1.s1.TotalLength =
-        FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
-    ApiMessage->Header.u1.s1.DataLength =
-        ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
-
-    /* Fill out the CSR Header. */
-    ApiMessage->ApiNumber = ApiNumber;
-    ApiMessage->CsrCaptureData = NULL;
-
-    DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
-           ApiNumber,
-           ApiMessage->Header.u1.s1.DataLength,
-           ApiMessage->Header.u1.s1.TotalLength);
-                
-    /* Check if we are already inside a CSR Server. */
-    if (!InsideCsrProcess)
-    {
-        /* Check if we got a Capture Buffer. */
-        if (CaptureBuffer)
-        {
-            /*
-             * We have to convert from our local (client) view
-             * to the remote (server) view.
-             */
-            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
-                ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
-
-            /* Lock the buffer. */
-            CaptureBuffer->BufferEnd = NULL;
-
-            /*
-             * Each client pointer inside the CSR message is converted into
-             * a server pointer, and each pointer to these message pointers
-             * is converted into an offset.
-             */
-            PointerCount  = CaptureBuffer->PointerCount;
-            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
-            while (PointerCount--)
-            {
-                if (*OffsetPointer != 0)
-                {
-                    *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
-                    *OffsetPointer -= (ULONG_PTR)ApiMessage;
-                }
-                ++OffsetPointer;
-            }
-        }
-
-        /* Send the LPC Message. */
-        Status = NtRequestWaitReplyPort(CsrApiPort,
-                                        &ApiMessage->Header,
-                                        &ApiMessage->Header);
-
-        /* Check if we got a Capture Buffer. */
-        if (CaptureBuffer)
-        {
-            /*
-             * We have to convert back from the remote (server) view
-             * to our local (client) view.
-             */
-            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
-                ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
-
-            /*
-             * Convert back the offsets into pointers to CSR message
-             * pointers, and convert back these message server pointers
-             * into client pointers.
-             */
-            PointerCount  = CaptureBuffer->PointerCount;
-            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
-            while (PointerCount--)
-            {
-                if (*OffsetPointer != 0)
-                {
-                    *OffsetPointer += (ULONG_PTR)ApiMessage;
-                    *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
-                }
-                ++OffsetPointer;
-            }
-        }
-
-        /* Check for success. */
-        if (!NT_SUCCESS(Status))
-        {
-            /* We failed. Overwrite the return value with the failure. */
-            DPRINT1("LPC Failed: %lx\n", Status);
-            ApiMessage->Status = Status;
-        }
-    }
-    else
-    {
-        /* This is a server-to-server call. Save our CID and do a direct call. */
-        DPRINT1("Next gen server-to-server call\n");
-
-        /* We check this equality inside CsrValidateMessageBuffer. */
-        ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
-
-        Status = CsrServerApiRoutine(&ApiMessage->Header,
-                                     &ApiMessage->Header);
-
-        /* Check for success. */
-        if (!NT_SUCCESS(Status))
-        {
-            /* We failed. Overwrite the return value with the failure. */
-            ApiMessage->Status = Status;
-        }
-    }
-
-    /* Return the CSR Result. */
-    DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
-    return ApiMessage->Status;
-}
-
 NTSTATUS
 NTAPI
 CsrpConnectToServer(IN PWSTR ObjectDirectory)
 {
+    NTSTATUS Status;
     ULONG PortNameLength;
     UNICODE_STRING PortName;
     LARGE_INTEGER CsrSectionViewSize;
-    NTSTATUS Status;
     HANDLE CsrSectionHandle;
     PORT_VIEW LpcWrite;
     REMOTE_PORT_VIEW LpcRead;
     SECURITY_QUALITY_OF_SERVICE SecurityQos;
     SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
     PSID SystemSid = NULL;
-    CSR_CONNECTION_INFO ConnectionInfo;
-    ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
+    CSR_API_CONNECTINFO ConnectionInfo;
+    ULONG ConnectionInfoLength = sizeof(CSR_API_CONNECTINFO);
 
     DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
 
@@ -221,7 +86,7 @@ CsrpConnectToServer(IN PWSTR ObjectDirectory)
                              NULL,
                              &CsrSectionViewSize,
                              PAGE_READWRITE,
-                             SEC_COMMIT,
+                             SEC_RESERVE,
                              NULL);
     if (!NT_SUCCESS(Status))
     {
@@ -246,20 +111,20 @@ CsrpConnectToServer(IN PWSTR ObjectDirectory)
     SecurityQos.EffectiveOnly = TRUE;
 
     /* Setup the connection info */
-    ConnectionInfo.Version = 0x10000;
+    ConnectionInfo.DebugFlags = 0;
 
     /* Create a SID for us */
     Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
-                                          1,
-                                          SECURITY_LOCAL_SYSTEM_RID,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          0,
-                                          &SystemSid);
+                                         1,
+                                         SECURITY_LOCAL_SYSTEM_RID,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         0,
+                                         &SystemSid);
     if (!NT_SUCCESS(Status))
     {
         /* Failure */
@@ -278,6 +143,7 @@ CsrpConnectToServer(IN PWSTR ObjectDirectory)
                                  NULL,
                                  &ConnectionInfo,
                                  &ConnectionInfoLength);
+    RtlFreeSid(SystemSid);
     NtClose(CsrSectionHandle);
     if (!NT_SUCCESS(Status))
     {
@@ -291,12 +157,12 @@ CsrpConnectToServer(IN PWSTR ObjectDirectory)
                          (ULONG_PTR)LpcWrite.ViewBase;
 
     /* Save the Process */
-    CsrProcessId = ConnectionInfo.ProcessId;
+    CsrProcessId = ConnectionInfo.ServerProcessId;
 
     /* Save CSR Section data */
     NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
     NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
-    NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
+    NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
 
     /* Create the port heap */
     CsrPortHeap = RtlCreateHeap(0,
@@ -371,7 +237,7 @@ CsrClientConnectToServer(IN PWSTR ObjectDirectory,
     if (InsideCsrProcess)
     {
         /* We're inside, so let's find csrsrv */
-        DPRINT1("Next-GEN CSRSS support\n");
+        DPRINT("Next-GEN CSRSS support\n");
         RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
         Status = LdrGetDllHandle(NULL,
                                  NULL,
@@ -389,21 +255,22 @@ CsrClientConnectToServer(IN PWSTR ObjectDirectory,
         CsrPortHeap = RtlGetProcessHeap();
 
         /* Tell the caller we're inside the server */
-        *ServerToServerCall = InsideCsrProcess;
+        if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
         return STATUS_SUCCESS;
     }
 
     /* Now check if connection info is given */
     if (ConnectionInfo)
     {
-        /* Well, we're defintely in a client now */
+        /* Well, we're definitely in a client now */
         InsideCsrProcess = FALSE;
 
         /* Do we have a connection to CSR yet? */
         if (!CsrApiPort)
         {
             /* No, set it up now */
-            if (!NT_SUCCESS(Status = CsrpConnectToServer(ObjectDirectory)))
+            Status = CsrpConnectToServer(ObjectDirectory);
+            if (!NT_SUCCESS(Status))
             {
                 /* Failed */
                 DPRINT1("Failure to connect to CSR\n");
@@ -458,4 +325,170 @@ CsrClientConnectToServer(IN PWSTR ObjectDirectory,
     return Status;
 }
 
+#if 0
+//
+// Structures can be padded at the end, causing the size of the entire structure
+// minus the size of the last field, not to be equal to the offset of the last
+// field.
+//
+typedef struct _TEST_EMBEDDED
+{
+    ULONG One;
+    ULONG Two;
+    ULONG Three;
+} TEST_EMBEDDED;
+
+typedef struct _TEST
+{
+    PORT_MESSAGE h;
+    TEST_EMBEDDED Three;
+} TEST;
+
+C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
+C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);
+C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);
+
+C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE)));
+C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
+#endif
+
+/*
+ * @implemented
+ */
+NTSTATUS
+NTAPI
+CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
+                    IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
+                    IN CSR_API_NUMBER ApiNumber,
+                    IN ULONG DataLength)
+{
+    NTSTATUS Status;
+    ULONG PointerCount;
+    PULONG_PTR OffsetPointer;
+
+    /* Fill out the Port Message Header */
+    ApiMessage->Header.u2.ZeroInit = 0;
+    ApiMessage->Header.u1.s1.TotalLength = DataLength +
+        sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data); // FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
+    ApiMessage->Header.u1.s1.DataLength = DataLength +
+        FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header); // ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+
+    /* Fill out the CSR Header */
+    ApiMessage->ApiNumber = ApiNumber;
+    ApiMessage->CsrCaptureData = NULL;
+
+    DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
+           ApiNumber,
+           ApiMessage->Header.u1.s1.DataLength,
+           ApiMessage->Header.u1.s1.TotalLength);
+
+    /* Check if we are already inside a CSR Server */
+    if (!InsideCsrProcess)
+    {
+        /* Check if we got a Capture Buffer */
+        if (CaptureBuffer)
+        {
+            /*
+             * We have to convert from our local (client) view
+             * to the remote (server) view.
+             */
+            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+                ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
+
+            /* Lock the buffer. */
+            CaptureBuffer->BufferEnd = NULL;
+
+            /*
+             * Each client pointer inside the CSR message is converted into
+             * a server pointer, and each pointer to these message pointers
+             * is converted into an offset.
+             */
+            PointerCount  = CaptureBuffer->PointerCount;
+            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+            while (PointerCount--)
+            {
+                if (*OffsetPointer != 0)
+                {
+                    *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
+                    *OffsetPointer -= (ULONG_PTR)ApiMessage;
+                }
+                ++OffsetPointer;
+            }
+        }
+
+        /* Send the LPC Message */
+        Status = NtRequestWaitReplyPort(CsrApiPort,
+                                        &ApiMessage->Header,
+                                        &ApiMessage->Header);
+
+        /* Check if we got a Capture Buffer */
+        if (CaptureBuffer)
+        {
+            /*
+             * We have to convert back from the remote (server) view
+             * to our local (client) view.
+             */
+            ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+                ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
+
+            /*
+             * Convert back the offsets into pointers to CSR message
+             * pointers, and convert back these message server pointers
+             * into client pointers.
+             */
+            PointerCount  = CaptureBuffer->PointerCount;
+            OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+            while (PointerCount--)
+            {
+                if (*OffsetPointer != 0)
+                {
+                    *OffsetPointer += (ULONG_PTR)ApiMessage;
+                    *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
+                }
+                ++OffsetPointer;
+            }
+        }
+
+        /* Check for success */
+        if (!NT_SUCCESS(Status))
+        {
+            /* We failed. Overwrite the return value with the failure. */
+            DPRINT1("LPC Failed: %lx\n", Status);
+            ApiMessage->Status = Status;
+        }
+    }
+    else
+    {
+        /* This is a server-to-server call. Save our CID and do a direct call. */
+        DPRINT("Next gen server-to-server call\n");
+
+        /* We check this equality inside CsrValidateMessageBuffer */
+        ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
+
+        Status = CsrServerApiRoutine(&ApiMessage->Header,
+                                     &ApiMessage->Header);
+
+        /* Check for success */
+        if (!NT_SUCCESS(Status))
+        {
+            /* We failed. Overwrite the return value with the failure. */
+            ApiMessage->Status = Status;
+        }
+    }
+
+    /* Return the CSR Result */
+    DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
+    return ApiMessage->Status;
+}
+
+/*
+ * @implemented
+ */
+HANDLE
+NTAPI
+CsrGetProcessId(VOID)
+{
+    return CsrProcessId;
+}
+
 /* EOF */