Implement NTDLL's CSR routines in a compatible way. Fix prototypes, argument count...
[reactos.git] / reactos / lib / ntdll / csr / connect.c
diff --git a/reactos/lib/ntdll/csr/connect.c b/reactos/lib/ntdll/csr/connect.c
new file mode 100644 (file)
index 0000000..26fee8a
--- /dev/null
@@ -0,0 +1,430 @@
+/*\r
+ * COPYRIGHT:       See COPYING in the top level directory\r
+ * PROJECT:         ReactOS kernel\r
+ * FILE:            lib/ntdll/csr/connect.c\r
+ * PURPOSE:         Routines for connecting and calling CSR\r
+ * PROGRAMMER:      Alex Ionescu (alex@relsoft.net)\r
+ */\r
+\r
+/* INCLUDES *****************************************************************/\r
+\r
+#include <ntdll.h>\r
+#define NDEBUG\r
+#include <debug.h>\r
+\r
+/* GLOBALS *******************************************************************/\r
+\r
+HANDLE CsrApiPort;\r
+HANDLE CsrProcessId;\r
+HANDLE CsrPortHeap;\r
+ULONG_PTR CsrPortMemoryDelta;\r
+BOOLEAN InsideCsrProcess = FALSE;\r
+BOOLEAN UsingOldCsr = TRUE;\r
+\r
+typedef NTSTATUS\r
+(NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,\r
+                                 IN PPORT_MESSAGE Reply);\r
+\r
+PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;\r
+\r
+#define UNICODE_PATH_SEP L"\\"\r
+#define CSR_PORT_NAME L"ApiPort"\r
+\r
+/* FUNCTIONS *****************************************************************/\r
+\r
+/*\r
+ * @implemented\r
+ */\r
+HANDLE\r
+NTAPI\r
+CsrGetProcessId(VOID)\r
+{\r
+       return CsrProcessId;\r
+}\r
+\r
+/*\r
+ * @implemented\r
+ */\r
+NTSTATUS \r
+NTAPI\r
+CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,\r
+                    PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,\r
+                    CSR_API_NUMBER ApiNumber,\r
+                    ULONG RequestLength)\r
+{\r
+    NTSTATUS Status;\r
+    ULONG PointerCount;\r
+    PULONG_PTR Pointers;\r
+    ULONG_PTR CurrentPointer;\r
+    DPRINT("CsrClientCallServer\n");\r
+  \r
+    /* Fill out the Port Message Header */\r
+    ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);\r
+    ApiMessage->Header.u1.s1.TotalLength = RequestLength;\r
+\r
+    /* Fill out the CSR Header */\r
+    ApiMessage->Type = ApiNumber;\r
+    //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR\r
+    ApiMessage->CsrCaptureData = NULL;\r
+\r
+    DPRINT("API: %x, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n", \r
+           ApiNumber,\r
+           ApiMessage->Header.u1.s1.DataLength,\r
+           ApiMessage->Header.u1.s1.TotalLength);\r
+                \r
+    /* Check if we are already inside a CSR Server */\r
+    if (!InsideCsrProcess)\r
+    {\r
+        /* Check if we got a a Capture Buffer */\r
+        if (CaptureBuffer)\r
+        {\r
+            /* We have to convert from our local view to the remote view */\r
+            DPRINT1("Converting CaptureBuffer\n");\r
+            ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +\r
+                                                 CsrPortMemoryDelta);\r
+\r
+            /* Lock the buffer */\r
+            CaptureBuffer->BufferEnd = 0;\r
+\r
+            /* Get the pointer information */\r
+            PointerCount = CaptureBuffer->PointerCount;\r
+            Pointers = CaptureBuffer->PointerArray;\r
+\r
+            /* Loop through every pointer and convert it */\r
+            while (PointerCount--)\r
+            {\r
+                /* Get this pointer and check if it's valid */\r
+                if ((CurrentPointer = *Pointers++))\r
+                {\r
+                    /* Update it */\r
+                    *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;\r
+                    Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;\r
+                }\r
+            }\r
+        }\r
+\r
+        /* Send the LPC Message */\r
+        Status = NtRequestWaitReplyPort(CsrApiPort,\r
+                                        &ApiMessage->Header,\r
+                                        &ApiMessage->Header);\r
+\r
+        /* Check if we got a a Capture Buffer */\r
+        if (CaptureBuffer)\r
+        {\r
+            /* We have to convert from the remote view to our remote view */\r
+            DPRINT1("Reconverting CaptureBuffer\n");\r
+            ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)\r
+                                                 ApiMessage->CsrCaptureData -\r
+                                                 CsrPortMemoryDelta);\r
+\r
+            /* Get the pointer information */\r
+            PointerCount = CaptureBuffer->PointerCount;\r
+            Pointers = CaptureBuffer->PointerArray;\r
+\r
+            /* Loop through every pointer and convert it */\r
+            while (PointerCount--)\r
+            {\r
+                /* Get this pointer and check if it's valid */\r
+                if ((CurrentPointer = *Pointers++))\r
+                {\r
+                    /* Update it */\r
+                    CurrentPointer += (ULONG_PTR)ApiMessage;\r
+                    Pointers[-1] = CurrentPointer;\r
+                    *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;\r
+                }\r
+            }\r
+        }\r
+\r
+        /* Check for success */\r
+        if (!NT_SUCCESS(Status))\r
+        {\r
+            /* We failed. Overwrite the return value with the failure */\r
+            DPRINT1("LPC Failed: %lx\n", Status);\r
+            ApiMessage->Status = Status;\r
+        }\r
+    }\r
+    else\r
+    {\r
+        /* This is a server-to-server call. Save our CID and do a direct call */\r
+        DbgBreakPoint();\r
+        ApiMessage->Header.ClientId = NtCurrentTeb()->Cid;\r
+        Status = CsrServerApiRoutine(&ApiMessage->Header,\r
+                                     &ApiMessage->Header);\r
+       \r
+        /* Check for success */\r
+        if (!NT_SUCCESS(Status))\r
+        {\r
+            /* We failed. Overwrite the return value with the failure */\r
+            ApiMessage->Status = Status;\r
+        }\r
+    }\r
+\r
+    /* Return the CSR Result */\r
+    DPRINT("Got back: %x\n", ApiMessage->Status);\r
+    return ApiMessage->Status;\r
+}\r
+\r
+NTSTATUS\r
+NTAPI\r
+CsrConnectToServer(IN PWSTR ObjectDirectory)\r
+{\r
+    ULONG PortNameLength;\r
+    UNICODE_STRING PortName;\r
+    LARGE_INTEGER CsrSectionViewSize;\r
+    NTSTATUS Status;\r
+    HANDLE CsrSectionHandle;\r
+    PORT_VIEW LpcWrite;\r
+    REMOTE_PORT_VIEW LpcRead;\r
+    SECURITY_QUALITY_OF_SERVICE SecurityQos;\r
+    SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};\r
+    PSID SystemSid = NULL;\r
+    CSR_CONNECTION_INFO ConnectionInfo;\r
+    ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);\r
+    DPRINT("CsrConnectToServer\n");\r
+\r
+    /* Calculate the total port name size */\r
+    PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +\r
+                     sizeof(CSR_PORT_NAME);\r
+\r
+    /* Set the port name */\r
+    PortName.Length = 0;\r
+    PortName.MaximumLength = PortNameLength;\r
+\r
+    /* Allocate a buffer for it */\r
+    PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);\r
+\r
+    /* Create the name */\r
+    RtlAppendUnicodeToString(&PortName, ObjectDirectory );\r
+    RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);\r
+    RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);\r
+\r
+    /* Create a section for the port memory */\r
+    CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;\r
+    Status = NtCreateSection(&CsrSectionHandle,\r
+                                    SECTION_ALL_ACCESS,\r
+                                        NULL,\r
+                                        &CsrSectionViewSize,\r
+                                        PAGE_READWRITE,\r
+                                        SEC_COMMIT,\r
+                                        NULL);\r
+    if (!NT_SUCCESS(Status))\r
+    {\r
+        DPRINT1("Failure allocating CSR Section\n");\r
+        return Status;\r
+    }\r
+\r
+    /* Set up the port view structures to match them with the section */\r
+    LpcWrite.Length = sizeof(PORT_VIEW);\r
+    LpcWrite.SectionHandle = CsrSectionHandle;\r
+    LpcWrite.SectionOffset = 0;\r
+    LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;\r
+    LpcWrite.ViewBase = 0;\r
+    LpcWrite.ViewRemoteBase = 0;\r
+    LpcRead.Length = sizeof(REMOTE_PORT_VIEW);\r
+    LpcRead.ViewSize = 0;\r
+    LpcRead.ViewBase = 0;\r
+\r
+    /* Setup the QoS */\r
+    SecurityQos.ImpersonationLevel = SecurityImpersonation;\r
+    SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;\r
+    SecurityQos.EffectiveOnly = TRUE;\r
+\r
+    /* Setup the connection info */\r
+    ConnectionInfo.Version = 0x10000;\r
+\r
+    /* Create a SID for us */\r
+    Status = RtlAllocateAndInitializeSid(&NtSidAuthority,\r
+                                          1,\r
+                                          SECURITY_LOCAL_SYSTEM_RID,\r
+                                          0,\r
+                                          0,\r
+                                          0,\r
+                                          0,\r
+                                          0,\r
+                                          0,\r
+                                          0,\r
+                                          &SystemSid);\r
+\r
+    /* Connect to the port */\r
+    Status = NtSecureConnectPort(&CsrApiPort,\r
+                                 &PortName,\r
+                                            &SecurityQos,\r
+                                            &LpcWrite,\r
+                                 SystemSid,\r
+                                            &LpcRead,\r
+                                            NULL,\r
+                                            &ConnectionInfo,\r
+                                            &ConnectionInfoLength);\r
+    NtClose(CsrSectionHandle);\r
+    if (!NT_SUCCESS(Status))\r
+    {\r
+        /* Failure */\r
+        DPRINT1("Couldn't connect to CSR port\n");\r
+        return Status;\r
+    }\r
+\r
+    /* Save the delta between the sections, for capture usage later */\r
+    CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -\r
+                         (ULONG_PTR)LpcWrite.ViewBase;\r
+\r
+    /* Save the Process */\r
+    CsrProcessId = ConnectionInfo.ProcessId;\r
+\r
+    /* Save CSR Section data */\r
+    NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;\r
+    NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;\r
+    NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;\r
+\r
+    /* Create the port heap */\r
+    CsrPortHeap = RtlCreateHeap(0,\r
+                                LpcWrite.ViewBase,\r
+                                LpcWrite.ViewSize,\r
+                                PAGE_SIZE,\r
+                                0,\r
+                                0);\r
+\r
+    /* Return success */\r
+    return STATUS_SUCCESS;\r
+}\r
+\r
+/*\r
+ * @implemented\r
+ */\r
+NTSTATUS\r
+NTAPI\r
+CsrClientConnectToServer(PWSTR ObjectDirectory,\r
+                         ULONG ServerId,\r
+                         PVOID ConnectionInfo,\r
+                         PULONG ConnectionInfoSize,\r
+                         PBOOLEAN ServerToServerCall)\r
+{\r
+    NTSTATUS Status;\r
+    PIMAGE_NT_HEADERS NtHeader;\r
+    UNICODE_STRING CsrSrvName;\r
+    HANDLE hCsrSrv;\r
+    ANSI_STRING CsrServerRoutineName;\r
+    PCSR_CAPTURE_BUFFER CaptureBuffer;\r
+    CSR_API_MESSAGE RosApiMessage;\r
+    CSR_API_MESSAGE2 ApiMessage;\r
+    PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;\r
+\r
+    /* Validate the Connection Info */\r
+    DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);\r
+    if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))\r
+    {\r
+        DPRINT1("Connection info given, but no length\n");\r
+        return STATUS_INVALID_PARAMETER;\r
+    }\r
+\r
+    /* Check if we're inside a CSR Process */\r
+    if (InsideCsrProcess)\r
+    {\r
+        /* Tell the client that we're already inside CSR */\r
+        if (ServerToServerCall) *ServerToServerCall = TRUE;\r
+        return STATUS_SUCCESS;\r
+    }\r
+\r
+    /*\r
+     * We might be in a CSR Process but not know it, if this is the first call.\r
+     * So let's find out.\r
+     */\r
+    if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))\r
+    {\r
+        /* The image isn't valid */\r
+        DPRINT1("Invalid image\n");\r
+        return STATUS_INVALID_IMAGE_FORMAT;\r
+    }\r
+    InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);\r
+\r
+    /* Now we can check if we are inside or not */\r
+    if (InsideCsrProcess && !UsingOldCsr)\r
+    {\r
+        /* We're inside, so let's find csrsrv */\r
+        DbgBreakPoint();\r
+        RtlInitUnicodeString(&CsrSrvName, L"csrsrv");\r
+        Status = LdrGetDllHandle(NULL,\r
+                                 NULL,\r
+                                 &CsrSrvName,\r
+                                 &hCsrSrv);\r
+        RtlFreeUnicodeString(&CsrSrvName);\r
+\r
+        /* Now get the Server to Server routine */\r
+        RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");\r
+        Status = LdrGetProcedureAddress(hCsrSrv,\r
+                                        &CsrServerRoutineName,\r
+                                        0L,\r
+                                        (PVOID*)&CsrServerApiRoutine);\r
+\r
+        /* Use the local heap as port heap */\r
+        CsrPortHeap = RtlGetProcessHeap();\r
+\r
+        /* Tell the caller we're inside the server */\r
+        *ServerToServerCall = InsideCsrProcess;\r
+        return STATUS_SUCCESS;\r
+    }\r
+\r
+    /* Now check if connection info is given */\r
+    if (ConnectionInfo)\r
+    {\r
+        /* Well, we're defintely in a client now */\r
+        InsideCsrProcess = FALSE;\r
+\r
+        /* Do we have a connection to CSR yet? */\r
+        if (!CsrApiPort)\r
+        {\r
+            /* No, set it up now */\r
+            if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))\r
+            {\r
+                /* Failed */\r
+                DPRINT1("Failure to connect to CSR\n");\r
+                return Status;\r
+            }\r
+        }\r
+\r
+        /* Setup the connect message header */\r
+        ClientConnect->ServerId = ServerId;\r
+        ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;\r
+\r
+        /* Setup a buffer for the connection info */\r
+        CaptureBuffer = CsrAllocateCaptureBuffer(1,\r
+                                                 ClientConnect->ConnectionInfoSize);\r
+\r
+        /* Allocate a pointer for the connection info*/\r
+        CsrAllocateMessagePointer(CaptureBuffer,\r
+                                  ClientConnect->ConnectionInfoSize,\r
+                                  &ClientConnect->ConnectionInfo);\r
+\r
+        /* Copy the data into the buffer */\r
+        RtlMoveMemory(ClientConnect->ConnectionInfo,\r
+                      ConnectionInfo,\r
+                      ClientConnect->ConnectionInfoSize);\r
+\r
+        /* Return the allocated length */\r
+        *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;\r
+\r
+        /* Call CSR */\r
+#if 0\r
+        Status = CsrClientCallServer(&ApiMessage,\r
+                                     CaptureBuffer,\r
+                                     CSR_MAKE_OPCODE(CsrSrvClientConnect,\r
+                                                     CSR_SRV_DLL),\r
+                                     sizeof(CSR_CLIENT_CONNECT));\r
+#endif\r
+        Status = CsrClientCallServer(&RosApiMessage,\r
+                                                    NULL,\r
+                                                    MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),\r
+                                                    sizeof(CSR_API_MESSAGE));\r
+    }\r
+    else\r
+    {\r
+        /* No connection info, just return */\r
+        Status = STATUS_SUCCESS;\r
+    }\r
+\r
+    /* Let the caller know if this was server to server */\r
+    DPRINT("Status was: %lx. Are we in server: %lx\n", Status, InsideCsrProcess);\r
+    if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;\r
+    return Status;\r
+}\r
+\r
+/* EOF */\r