--- /dev/null
+/*\r
+ * COPYRIGHT: See COPYING in the top level directory\r
+ * PROJECT: ReactOS kernel\r
+ * FILE: lib/ntdll/csr/connect.c\r
+ * PURPOSE: Routines for connecting and calling CSR\r
+ * PROGRAMMER: Alex Ionescu (alex@relsoft.net)\r
+ */\r
+\r
+/* INCLUDES *****************************************************************/\r
+\r
+#include <ntdll.h>\r
+#define NDEBUG\r
+#include <debug.h>\r
+\r
+/* GLOBALS *******************************************************************/\r
+\r
+HANDLE CsrApiPort;\r
+HANDLE CsrProcessId;\r
+HANDLE CsrPortHeap;\r
+ULONG_PTR CsrPortMemoryDelta;\r
+BOOLEAN InsideCsrProcess = FALSE;\r
+BOOLEAN UsingOldCsr = TRUE;\r
+\r
+typedef NTSTATUS\r
+(NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,\r
+ IN PPORT_MESSAGE Reply);\r
+\r
+PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;\r
+\r
+#define UNICODE_PATH_SEP L"\\"\r
+#define CSR_PORT_NAME L"ApiPort"\r
+\r
+/* FUNCTIONS *****************************************************************/\r
+\r
+/*\r
+ * @implemented\r
+ */\r
+HANDLE\r
+NTAPI\r
+CsrGetProcessId(VOID)\r
+{\r
+ return CsrProcessId;\r
+}\r
+\r
+/*\r
+ * @implemented\r
+ */\r
+NTSTATUS \r
+NTAPI\r
+CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,\r
+ PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,\r
+ CSR_API_NUMBER ApiNumber,\r
+ ULONG RequestLength)\r
+{\r
+ NTSTATUS Status;\r
+ ULONG PointerCount;\r
+ PULONG_PTR Pointers;\r
+ ULONG_PTR CurrentPointer;\r
+ DPRINT("CsrClientCallServer\n");\r
+ \r
+ /* Fill out the Port Message Header */\r
+ ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);\r
+ ApiMessage->Header.u1.s1.TotalLength = RequestLength;\r
+\r
+ /* Fill out the CSR Header */\r
+ ApiMessage->Type = ApiNumber;\r
+ //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR\r
+ ApiMessage->CsrCaptureData = NULL;\r
+\r
+ DPRINT("API: %x, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n", \r
+ ApiNumber,\r
+ ApiMessage->Header.u1.s1.DataLength,\r
+ ApiMessage->Header.u1.s1.TotalLength);\r
+ \r
+ /* Check if we are already inside a CSR Server */\r
+ if (!InsideCsrProcess)\r
+ {\r
+ /* Check if we got a a Capture Buffer */\r
+ if (CaptureBuffer)\r
+ {\r
+ /* We have to convert from our local view to the remote view */\r
+ DPRINT1("Converting CaptureBuffer\n");\r
+ ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +\r
+ CsrPortMemoryDelta);\r
+\r
+ /* Lock the buffer */\r
+ CaptureBuffer->BufferEnd = 0;\r
+\r
+ /* Get the pointer information */\r
+ PointerCount = CaptureBuffer->PointerCount;\r
+ Pointers = CaptureBuffer->PointerArray;\r
+\r
+ /* Loop through every pointer and convert it */\r
+ while (PointerCount--)\r
+ {\r
+ /* Get this pointer and check if it's valid */\r
+ if ((CurrentPointer = *Pointers++))\r
+ {\r
+ /* Update it */\r
+ *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;\r
+ Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;\r
+ }\r
+ }\r
+ }\r
+\r
+ /* Send the LPC Message */\r
+ Status = NtRequestWaitReplyPort(CsrApiPort,\r
+ &ApiMessage->Header,\r
+ &ApiMessage->Header);\r
+\r
+ /* Check if we got a a Capture Buffer */\r
+ if (CaptureBuffer)\r
+ {\r
+ /* We have to convert from the remote view to our remote view */\r
+ DPRINT1("Reconverting CaptureBuffer\n");\r
+ ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)\r
+ ApiMessage->CsrCaptureData -\r
+ CsrPortMemoryDelta);\r
+\r
+ /* Get the pointer information */\r
+ PointerCount = CaptureBuffer->PointerCount;\r
+ Pointers = CaptureBuffer->PointerArray;\r
+\r
+ /* Loop through every pointer and convert it */\r
+ while (PointerCount--)\r
+ {\r
+ /* Get this pointer and check if it's valid */\r
+ if ((CurrentPointer = *Pointers++))\r
+ {\r
+ /* Update it */\r
+ CurrentPointer += (ULONG_PTR)ApiMessage;\r
+ Pointers[-1] = CurrentPointer;\r
+ *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;\r
+ }\r
+ }\r
+ }\r
+\r
+ /* Check for success */\r
+ if (!NT_SUCCESS(Status))\r
+ {\r
+ /* We failed. Overwrite the return value with the failure */\r
+ DPRINT1("LPC Failed: %lx\n", Status);\r
+ ApiMessage->Status = Status;\r
+ }\r
+ }\r
+ else\r
+ {\r
+ /* This is a server-to-server call. Save our CID and do a direct call */\r
+ DbgBreakPoint();\r
+ ApiMessage->Header.ClientId = NtCurrentTeb()->Cid;\r
+ Status = CsrServerApiRoutine(&ApiMessage->Header,\r
+ &ApiMessage->Header);\r
+ \r
+ /* Check for success */\r
+ if (!NT_SUCCESS(Status))\r
+ {\r
+ /* We failed. Overwrite the return value with the failure */\r
+ ApiMessage->Status = Status;\r
+ }\r
+ }\r
+\r
+ /* Return the CSR Result */\r
+ DPRINT("Got back: %x\n", ApiMessage->Status);\r
+ return ApiMessage->Status;\r
+}\r
+\r
+NTSTATUS\r
+NTAPI\r
+CsrConnectToServer(IN PWSTR ObjectDirectory)\r
+{\r
+ ULONG PortNameLength;\r
+ UNICODE_STRING PortName;\r
+ LARGE_INTEGER CsrSectionViewSize;\r
+ NTSTATUS Status;\r
+ HANDLE CsrSectionHandle;\r
+ PORT_VIEW LpcWrite;\r
+ REMOTE_PORT_VIEW LpcRead;\r
+ SECURITY_QUALITY_OF_SERVICE SecurityQos;\r
+ SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};\r
+ PSID SystemSid = NULL;\r
+ CSR_CONNECTION_INFO ConnectionInfo;\r
+ ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);\r
+ DPRINT("CsrConnectToServer\n");\r
+\r
+ /* Calculate the total port name size */\r
+ PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +\r
+ sizeof(CSR_PORT_NAME);\r
+\r
+ /* Set the port name */\r
+ PortName.Length = 0;\r
+ PortName.MaximumLength = PortNameLength;\r
+\r
+ /* Allocate a buffer for it */\r
+ PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);\r
+\r
+ /* Create the name */\r
+ RtlAppendUnicodeToString(&PortName, ObjectDirectory );\r
+ RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);\r
+ RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);\r
+\r
+ /* Create a section for the port memory */\r
+ CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;\r
+ Status = NtCreateSection(&CsrSectionHandle,\r
+ SECTION_ALL_ACCESS,\r
+ NULL,\r
+ &CsrSectionViewSize,\r
+ PAGE_READWRITE,\r
+ SEC_COMMIT,\r
+ NULL);\r
+ if (!NT_SUCCESS(Status))\r
+ {\r
+ DPRINT1("Failure allocating CSR Section\n");\r
+ return Status;\r
+ }\r
+\r
+ /* Set up the port view structures to match them with the section */\r
+ LpcWrite.Length = sizeof(PORT_VIEW);\r
+ LpcWrite.SectionHandle = CsrSectionHandle;\r
+ LpcWrite.SectionOffset = 0;\r
+ LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;\r
+ LpcWrite.ViewBase = 0;\r
+ LpcWrite.ViewRemoteBase = 0;\r
+ LpcRead.Length = sizeof(REMOTE_PORT_VIEW);\r
+ LpcRead.ViewSize = 0;\r
+ LpcRead.ViewBase = 0;\r
+\r
+ /* Setup the QoS */\r
+ SecurityQos.ImpersonationLevel = SecurityImpersonation;\r
+ SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;\r
+ SecurityQos.EffectiveOnly = TRUE;\r
+\r
+ /* Setup the connection info */\r
+ ConnectionInfo.Version = 0x10000;\r
+\r
+ /* Create a SID for us */\r
+ Status = RtlAllocateAndInitializeSid(&NtSidAuthority,\r
+ 1,\r
+ SECURITY_LOCAL_SYSTEM_RID,\r
+ 0,\r
+ 0,\r
+ 0,\r
+ 0,\r
+ 0,\r
+ 0,\r
+ 0,\r
+ &SystemSid);\r
+\r
+ /* Connect to the port */\r
+ Status = NtSecureConnectPort(&CsrApiPort,\r
+ &PortName,\r
+ &SecurityQos,\r
+ &LpcWrite,\r
+ SystemSid,\r
+ &LpcRead,\r
+ NULL,\r
+ &ConnectionInfo,\r
+ &ConnectionInfoLength);\r
+ NtClose(CsrSectionHandle);\r
+ if (!NT_SUCCESS(Status))\r
+ {\r
+ /* Failure */\r
+ DPRINT1("Couldn't connect to CSR port\n");\r
+ return Status;\r
+ }\r
+\r
+ /* Save the delta between the sections, for capture usage later */\r
+ CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -\r
+ (ULONG_PTR)LpcWrite.ViewBase;\r
+\r
+ /* Save the Process */\r
+ CsrProcessId = ConnectionInfo.ProcessId;\r
+\r
+ /* Save CSR Section data */\r
+ NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;\r
+ NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;\r
+ NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;\r
+\r
+ /* Create the port heap */\r
+ CsrPortHeap = RtlCreateHeap(0,\r
+ LpcWrite.ViewBase,\r
+ LpcWrite.ViewSize,\r
+ PAGE_SIZE,\r
+ 0,\r
+ 0);\r
+\r
+ /* Return success */\r
+ return STATUS_SUCCESS;\r
+}\r
+\r
+/*\r
+ * @implemented\r
+ */\r
+NTSTATUS\r
+NTAPI\r
+CsrClientConnectToServer(PWSTR ObjectDirectory,\r
+ ULONG ServerId,\r
+ PVOID ConnectionInfo,\r
+ PULONG ConnectionInfoSize,\r
+ PBOOLEAN ServerToServerCall)\r
+{\r
+ NTSTATUS Status;\r
+ PIMAGE_NT_HEADERS NtHeader;\r
+ UNICODE_STRING CsrSrvName;\r
+ HANDLE hCsrSrv;\r
+ ANSI_STRING CsrServerRoutineName;\r
+ PCSR_CAPTURE_BUFFER CaptureBuffer;\r
+ CSR_API_MESSAGE RosApiMessage;\r
+ CSR_API_MESSAGE2 ApiMessage;\r
+ PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;\r
+\r
+ /* Validate the Connection Info */\r
+ DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);\r
+ if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))\r
+ {\r
+ DPRINT1("Connection info given, but no length\n");\r
+ return STATUS_INVALID_PARAMETER;\r
+ }\r
+\r
+ /* Check if we're inside a CSR Process */\r
+ if (InsideCsrProcess)\r
+ {\r
+ /* Tell the client that we're already inside CSR */\r
+ if (ServerToServerCall) *ServerToServerCall = TRUE;\r
+ return STATUS_SUCCESS;\r
+ }\r
+\r
+ /*\r
+ * We might be in a CSR Process but not know it, if this is the first call.\r
+ * So let's find out.\r
+ */\r
+ if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))\r
+ {\r
+ /* The image isn't valid */\r
+ DPRINT1("Invalid image\n");\r
+ return STATUS_INVALID_IMAGE_FORMAT;\r
+ }\r
+ InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);\r
+\r
+ /* Now we can check if we are inside or not */\r
+ if (InsideCsrProcess && !UsingOldCsr)\r
+ {\r
+ /* We're inside, so let's find csrsrv */\r
+ DbgBreakPoint();\r
+ RtlInitUnicodeString(&CsrSrvName, L"csrsrv");\r
+ Status = LdrGetDllHandle(NULL,\r
+ NULL,\r
+ &CsrSrvName,\r
+ &hCsrSrv);\r
+ RtlFreeUnicodeString(&CsrSrvName);\r
+\r
+ /* Now get the Server to Server routine */\r
+ RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");\r
+ Status = LdrGetProcedureAddress(hCsrSrv,\r
+ &CsrServerRoutineName,\r
+ 0L,\r
+ (PVOID*)&CsrServerApiRoutine);\r
+\r
+ /* Use the local heap as port heap */\r
+ CsrPortHeap = RtlGetProcessHeap();\r
+\r
+ /* Tell the caller we're inside the server */\r
+ *ServerToServerCall = InsideCsrProcess;\r
+ return STATUS_SUCCESS;\r
+ }\r
+\r
+ /* Now check if connection info is given */\r
+ if (ConnectionInfo)\r
+ {\r
+ /* Well, we're defintely in a client now */\r
+ InsideCsrProcess = FALSE;\r
+\r
+ /* Do we have a connection to CSR yet? */\r
+ if (!CsrApiPort)\r
+ {\r
+ /* No, set it up now */\r
+ if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))\r
+ {\r
+ /* Failed */\r
+ DPRINT1("Failure to connect to CSR\n");\r
+ return Status;\r
+ }\r
+ }\r
+\r
+ /* Setup the connect message header */\r
+ ClientConnect->ServerId = ServerId;\r
+ ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;\r
+\r
+ /* Setup a buffer for the connection info */\r
+ CaptureBuffer = CsrAllocateCaptureBuffer(1,\r
+ ClientConnect->ConnectionInfoSize);\r
+\r
+ /* Allocate a pointer for the connection info*/\r
+ CsrAllocateMessagePointer(CaptureBuffer,\r
+ ClientConnect->ConnectionInfoSize,\r
+ &ClientConnect->ConnectionInfo);\r
+\r
+ /* Copy the data into the buffer */\r
+ RtlMoveMemory(ClientConnect->ConnectionInfo,\r
+ ConnectionInfo,\r
+ ClientConnect->ConnectionInfoSize);\r
+\r
+ /* Return the allocated length */\r
+ *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;\r
+\r
+ /* Call CSR */\r
+#if 0\r
+ Status = CsrClientCallServer(&ApiMessage,\r
+ CaptureBuffer,\r
+ CSR_MAKE_OPCODE(CsrSrvClientConnect,\r
+ CSR_SRV_DLL),\r
+ sizeof(CSR_CLIENT_CONNECT));\r
+#endif\r
+ Status = CsrClientCallServer(&RosApiMessage,\r
+ NULL,\r
+ MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),\r
+ sizeof(CSR_API_MESSAGE));\r
+ }\r
+ else\r
+ {\r
+ /* No connection info, just return */\r
+ Status = STATUS_SUCCESS;\r
+ }\r
+\r
+ /* Let the caller know if this was server to server */\r
+ DPRINT("Status was: %lx. Are we in server: %lx\n", Status, InsideCsrProcess);\r
+ if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;\r
+ return Status;\r
+}\r
+\r
+/* EOF */\r