X-Git-Url: https://git.reactos.org/?p=reactos.git;a=blobdiff_plain;f=dll%2Fntdll%2Fcsr%2Fconnect.c;h=a1a95096e37d288630d3ace6ba077214e5edf8d2;hp=e356cc772d5ecee6bf8195df235483521a8178c1;hb=8b15a5ecd703433de03e58f6354b18f490a7e0bc;hpb=4666acdde5d0fec237e472f7768cf0e8c5c769e6 diff --git a/dll/ntdll/csr/connect.c b/dll/ntdll/csr/connect.c index e356cc772d5..a1a95096e37 100644 --- a/dll/ntdll/csr/connect.c +++ b/dll/ntdll/csr/connect.c @@ -1,18 +1,22 @@ /* * COPYRIGHT: See COPYING in the top level directory * PROJECT: ReactOS kernel - * FILE: lib/ntdll/csr/connect.c + * FILE: dll/ntdll/csr/connect.c * PURPOSE: Routines for connecting and calling CSR * PROGRAMMER: Alex Ionescu (alex@relsoft.net) */ -/* INCLUDES *****************************************************************/ +/* INCLUDES *******************************************************************/ #include + +#include +#include + #define NDEBUG #include -/* GLOBALS *******************************************************************/ +/* GLOBALS ********************************************************************/ HANDLE CsrApiPort; HANDLE CsrProcessId; @@ -27,163 +31,25 @@ typedef NTSTATUS PCSR_SERVER_API_ROUTINE CsrServerApiRoutine; #define UNICODE_PATH_SEP L"\\" -#define CSR_PORT_NAME L"ApiPort" - -/* FUNCTIONS *****************************************************************/ - -/* - * @implemented - */ -HANDLE -NTAPI -CsrGetProcessId(VOID) -{ - return CsrProcessId; -} - -/* - * @implemented - */ -NTSTATUS -NTAPI -CsrClientCallServer(PCSR_API_MESSAGE ApiMessage, - PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL, - CSR_API_NUMBER ApiNumber, - ULONG RequestLength) -{ - NTSTATUS Status; - ULONG PointerCount; - PULONG_PTR Pointers; - ULONG_PTR CurrentPointer; - DPRINT("CsrClientCallServer\n"); - /* Fill out the Port Message Header */ - ApiMessage->Header.u2.ZeroInit = 0; - ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE); - ApiMessage->Header.u1.s1.TotalLength = RequestLength; - - /* Fill out the CSR Header */ - ApiMessage->Type = ApiNumber; - //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR - ApiMessage->CsrCaptureData = NULL; - - DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n", - ApiNumber, - ApiMessage->Header.u1.s1.DataLength, - ApiMessage->Header.u1.s1.TotalLength); - - /* Check if we are already inside a CSR Server */ - if (!InsideCsrProcess) - { - /* Check if we got a a Capture Buffer */ - if (CaptureBuffer) - { - /* We have to convert from our local view to the remote view */ - ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer + - CsrPortMemoryDelta); - - /* Lock the buffer */ - CaptureBuffer->BufferEnd = 0; - - /* Get the pointer information */ - PointerCount = CaptureBuffer->PointerCount; - Pointers = CaptureBuffer->PointerArray; - - /* Loop through every pointer and convert it */ - DPRINT("PointerCount: %lx\n", PointerCount); - while (PointerCount--) - { - /* Get this pointer and check if it's valid */ - DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n", - &Pointers, Pointers, *Pointers); - if ((CurrentPointer = *Pointers++)) - { - /* Update it */ - DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer); - *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta; - Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage; - DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer); - } - } - } - - /* Send the LPC Message */ - Status = NtRequestWaitReplyPort(CsrApiPort, - &ApiMessage->Header, - &ApiMessage->Header); - - /* Check if we got a a Capture Buffer */ - if (CaptureBuffer) - { - /* We have to convert from the remote view to our remote view */ - DPRINT("Reconverting CaptureBuffer\n"); - ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR) - ApiMessage->CsrCaptureData - - CsrPortMemoryDelta); - - /* Get the pointer information */ - PointerCount = CaptureBuffer->PointerCount; - Pointers = CaptureBuffer->PointerArray; - - /* Loop through every pointer and convert it */ - while (PointerCount--) - { - /* Get this pointer and check if it's valid */ - if ((CurrentPointer = *Pointers++)) - { - /* Update it */ - CurrentPointer += (ULONG_PTR)ApiMessage; - Pointers[-1] = CurrentPointer; - *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta; - } - } - } - - /* Check for success */ - if (!NT_SUCCESS(Status)) - { - /* We failed. Overwrite the return value with the failure */ - DPRINT1("LPC Failed: %lx\n", Status); - ApiMessage->Status = Status; - } - } - else - { - /* This is a server-to-server call. Save our CID and do a direct call */ - DPRINT1("Next gen server-to-server call\n"); - ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId; - Status = CsrServerApiRoutine(&ApiMessage->Header, - &ApiMessage->Header); - - /* Check for success */ - if (!NT_SUCCESS(Status)) - { - /* We failed. Overwrite the return value with the failure */ - ApiMessage->Status = Status; - } - } - - /* Return the CSR Result */ - DPRINT("Got back: 0x%lx\n", ApiMessage->Status); - return ApiMessage->Status; -} +/* FUNCTIONS ******************************************************************/ NTSTATUS NTAPI -CsrConnectToServer(IN PWSTR ObjectDirectory) +CsrpConnectToServer(IN PWSTR ObjectDirectory) { + NTSTATUS Status; ULONG PortNameLength; UNICODE_STRING PortName; LARGE_INTEGER CsrSectionViewSize; - NTSTATUS Status; HANDLE CsrSectionHandle; PORT_VIEW LpcWrite; REMOTE_PORT_VIEW LpcRead; SECURITY_QUALITY_OF_SERVICE SecurityQos; SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY}; PSID SystemSid = NULL; - CSR_CONNECTION_INFO ConnectionInfo; - ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO); + CSR_API_CONNECTINFO ConnectionInfo; + ULONG ConnectionInfoLength = sizeof(CSR_API_CONNECTINFO); DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory); @@ -202,7 +68,7 @@ CsrConnectToServer(IN PWSTR ObjectDirectory) PortName.MaximumLength = PortNameLength; /* Allocate a buffer for it */ - PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength); + PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength); if (PortName.Buffer == NULL) { return STATUS_INSUFFICIENT_RESOURCES; @@ -220,7 +86,7 @@ CsrConnectToServer(IN PWSTR ObjectDirectory) NULL, &CsrSectionViewSize, PAGE_READWRITE, - SEC_COMMIT, + SEC_RESERVE, NULL); if (!NT_SUCCESS(Status)) { @@ -245,20 +111,20 @@ CsrConnectToServer(IN PWSTR ObjectDirectory) SecurityQos.EffectiveOnly = TRUE; /* Setup the connection info */ - ConnectionInfo.Version = 0x10000; + ConnectionInfo.DebugFlags = 0; /* Create a SID for us */ Status = RtlAllocateAndInitializeSid(&NtSidAuthority, - 1, - SECURITY_LOCAL_SYSTEM_RID, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - &SystemSid); + 1, + SECURITY_LOCAL_SYSTEM_RID, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + &SystemSid); if (!NT_SUCCESS(Status)) { /* Failure */ @@ -277,6 +143,7 @@ CsrConnectToServer(IN PWSTR ObjectDirectory) NULL, &ConnectionInfo, &ConnectionInfoLength); + RtlFreeSid(SystemSid); NtClose(CsrSectionHandle); if (!NT_SUCCESS(Status)) { @@ -290,12 +157,12 @@ CsrConnectToServer(IN PWSTR ObjectDirectory) (ULONG_PTR)LpcWrite.ViewBase; /* Save the Process */ - CsrProcessId = ConnectionInfo.ProcessId; + CsrProcessId = ConnectionInfo.ServerProcessId; /* Save CSR Section data */ NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase; NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap; - NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData; + NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData; /* Create the port heap */ CsrPortHeap = RtlCreateHeap(0, @@ -306,6 +173,8 @@ CsrConnectToServer(IN PWSTR ObjectDirectory) 0); if (CsrPortHeap == NULL) { + /* Failure */ + DPRINT1("Couldn't create heap for CSR port\n"); NtClose(CsrApiPort); CsrApiPort = NULL; return STATUS_INSUFFICIENT_RESOURCES; @@ -320,24 +189,24 @@ CsrConnectToServer(IN PWSTR ObjectDirectory) */ NTSTATUS NTAPI -CsrClientConnectToServer(PWSTR ObjectDirectory, - ULONG ServerId, - PVOID ConnectionInfo, - PULONG ConnectionInfoSize, - PBOOLEAN ServerToServerCall) +CsrClientConnectToServer(IN PWSTR ObjectDirectory, + IN ULONG ServerId, + IN PVOID ConnectionInfo, + IN OUT PULONG ConnectionInfoSize, + OUT PBOOLEAN ServerToServerCall) { NTSTATUS Status; PIMAGE_NT_HEADERS NtHeader; UNICODE_STRING CsrSrvName; HANDLE hCsrSrv; ANSI_STRING CsrServerRoutineName; + CSR_API_MESSAGE ApiMessage; + PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect; PCSR_CAPTURE_BUFFER CaptureBuffer; - CSR_API_MESSAGE RosApiMessage; - CSR_API_MESSAGE2 ApiMessage; - PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect; - /* Validate the Connection Info */ DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo); + + /* Validate the Connection Info */ if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize)) { DPRINT1("Connection info given, but no length\n"); @@ -368,7 +237,7 @@ CsrClientConnectToServer(PWSTR ObjectDirectory, if (InsideCsrProcess) { /* We're inside, so let's find csrsrv */ - DPRINT1("Next-GEN CSRSS support\n"); + DPRINT("Next-GEN CSRSS support\n"); RtlInitUnicodeString(&CsrSrvName, L"csrsrv"); Status = LdrGetDllHandle(NULL, NULL, @@ -386,21 +255,22 @@ CsrClientConnectToServer(PWSTR ObjectDirectory, CsrPortHeap = RtlGetProcessHeap(); /* Tell the caller we're inside the server */ - *ServerToServerCall = InsideCsrProcess; + if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess; return STATUS_SUCCESS; } /* Now check if connection info is given */ if (ConnectionInfo) { - /* Well, we're defintely in a client now */ + /* Well, we're definitely in a client now */ InsideCsrProcess = FALSE; /* Do we have a connection to CSR yet? */ if (!CsrApiPort) { /* No, set it up now */ - if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory))) + Status = CsrpConnectToServer(ObjectDirectory); + if (!NT_SUCCESS(Status)) { /* Failed */ DPRINT1("Failure to connect to CSR\n"); @@ -413,38 +283,34 @@ CsrClientConnectToServer(PWSTR ObjectDirectory, ClientConnect->ConnectionInfoSize = *ConnectionInfoSize; /* Setup a buffer for the connection info */ - CaptureBuffer = CsrAllocateCaptureBuffer(1, - ClientConnect->ConnectionInfoSize); + CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize); if (CaptureBuffer == NULL) { return STATUS_INSUFFICIENT_RESOURCES; } - /* Allocate a pointer for the connection info*/ - CsrAllocateMessagePointer(CaptureBuffer, - ClientConnect->ConnectionInfoSize, - &ClientConnect->ConnectionInfo); - - /* Copy the data into the buffer */ - RtlMoveMemory(ClientConnect->ConnectionInfo, - ConnectionInfo, - ClientConnect->ConnectionInfoSize); + /* Capture the connection info data */ + CsrCaptureMessageBuffer(CaptureBuffer, + ConnectionInfo, + ClientConnect->ConnectionInfoSize, + &ClientConnect->ConnectionInfo); /* Return the allocated length */ *ConnectionInfoSize = ClientConnect->ConnectionInfoSize; /* Call CSR */ -#if 0 Status = CsrClientCallServer(&ApiMessage, CaptureBuffer, - CSR_MAKE_OPCODE(CsrpClientConnect, - CSR_SRV_DLL), + CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect), sizeof(CSR_CLIENT_CONNECT)); -#endif - Status = CsrClientCallServer(&RosApiMessage, - NULL, - MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE), - sizeof(CSR_API_MESSAGE)); + + /* Copy the updated connection info data back into the user buffer */ + RtlMoveMemory(ConnectionInfo, + ClientConnect->ConnectionInfo, + *ConnectionInfoSize); + + /* Free the capture buffer */ + CsrFreeCaptureBuffer(CaptureBuffer); } else { @@ -455,7 +321,174 @@ CsrClientConnectToServer(PWSTR ObjectDirectory, /* Let the caller know if this was server to server */ DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess); if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess; + return Status; } +#if 0 +// +// Structures can be padded at the end, causing the size of the entire structure +// minus the size of the last field, not to be equal to the offset of the last +// field. +// +typedef struct _TEST_EMBEDDED +{ + ULONG One; + ULONG Two; + ULONG Three; +} TEST_EMBEDDED; + +typedef struct _TEST +{ + PORT_MESSAGE h; + TEST_EMBEDDED Three; +} TEST; + +C_ASSERT(sizeof(PORT_MESSAGE) == 0x18); +C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18); +C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC); + +C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE))); +C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three)); +#endif + +/* + * @implemented + */ +NTSTATUS +NTAPI +CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage, + IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL, + IN CSR_API_NUMBER ApiNumber, + IN ULONG DataLength) +{ + NTSTATUS Status; + ULONG PointerCount; + PULONG_PTR OffsetPointer; + + /* Fill out the Port Message Header */ + ApiMessage->Header.u2.ZeroInit = 0; + ApiMessage->Header.u1.s1.TotalLength = DataLength + + sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data); // FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength; + ApiMessage->Header.u1.s1.DataLength = DataLength + + FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header); // ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE); + + /* Fill out the CSR Header */ + ApiMessage->ApiNumber = ApiNumber; + ApiMessage->CsrCaptureData = NULL; + + DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n", + ApiNumber, + ApiMessage->Header.u1.s1.DataLength, + ApiMessage->Header.u1.s1.TotalLength); + + /* Check if we are already inside a CSR Server */ + if (!InsideCsrProcess) + { + /* Check if we got a Capture Buffer */ + if (CaptureBuffer) + { + /* + * We have to convert from our local (client) view + * to the remote (server) view. + */ + ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER) + ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta); + + /* Lock the buffer. */ + CaptureBuffer->BufferEnd = NULL; + + /* + * Each client pointer inside the CSR message is converted into + * a server pointer, and each pointer to these message pointers + * is converted into an offset. + */ + PointerCount = CaptureBuffer->PointerCount; + OffsetPointer = CaptureBuffer->PointerOffsetsArray; + while (PointerCount--) + { + if (*OffsetPointer != 0) + { + *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta; + *OffsetPointer -= (ULONG_PTR)ApiMessage; + } + ++OffsetPointer; + } + } + + /* Send the LPC Message */ + Status = NtRequestWaitReplyPort(CsrApiPort, + &ApiMessage->Header, + &ApiMessage->Header); + + /* Check if we got a Capture Buffer */ + if (CaptureBuffer) + { + /* + * We have to convert back from the remote (server) view + * to our local (client) view. + */ + ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER) + ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta); + + /* + * Convert back the offsets into pointers to CSR message + * pointers, and convert back these message server pointers + * into client pointers. + */ + PointerCount = CaptureBuffer->PointerCount; + OffsetPointer = CaptureBuffer->PointerOffsetsArray; + while (PointerCount--) + { + if (*OffsetPointer != 0) + { + *OffsetPointer += (ULONG_PTR)ApiMessage; + *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta; + } + ++OffsetPointer; + } + } + + /* Check for success */ + if (!NT_SUCCESS(Status)) + { + /* We failed. Overwrite the return value with the failure. */ + DPRINT1("LPC Failed: %lx\n", Status); + ApiMessage->Status = Status; + } + } + else + { + /* This is a server-to-server call. Save our CID and do a direct call. */ + DPRINT("Next gen server-to-server call\n"); + + /* We check this equality inside CsrValidateMessageBuffer */ + ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId; + + Status = CsrServerApiRoutine(&ApiMessage->Header, + &ApiMessage->Header); + + /* Check for success */ + if (!NT_SUCCESS(Status)) + { + /* We failed. Overwrite the return value with the failure. */ + ApiMessage->Status = Status; + } + } + + /* Return the CSR Result */ + DPRINT("Got back: 0x%lx\n", ApiMessage->Status); + return ApiMessage->Status; +} + +/* + * @implemented + */ +HANDLE +NTAPI +CsrGetProcessId(VOID) +{ + return CsrProcessId; +} + /* EOF */