/*
* COPYRIGHT: See COPYING in the top level directory
* PROJECT: ReactOS kernel
- * FILE: lib/ntdll/csr/connect.c
+ * FILE: dll/ntdll/csr/connect.c
* PURPOSE: Routines for connecting and calling CSR
* PROGRAMMER: Alex Ionescu (alex@relsoft.net)
*/
-/* INCLUDES *****************************************************************/
+/* INCLUDES *******************************************************************/
#include <ntdll.h>
#define NDEBUG
#include <debug.h>
-/* GLOBALS *******************************************************************/
+/* GLOBALS ********************************************************************/
HANDLE CsrApiPort;
HANDLE CsrProcessId;
HANDLE CsrPortHeap;
ULONG_PTR CsrPortMemoryDelta;
BOOLEAN InsideCsrProcess = FALSE;
-BOOLEAN UsingOldCsr = TRUE;
typedef NTSTATUS
(NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
#define UNICODE_PATH_SEP L"\\"
-#define CSR_PORT_NAME L"ApiPort"
-/* FUNCTIONS *****************************************************************/
-
-/*
- * @implemented
- */
-HANDLE
-NTAPI
-CsrGetProcessId(VOID)
-{
- return CsrProcessId;
-}
-
-/*
- * @implemented
- */
-NTSTATUS
-NTAPI
-CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,
- PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
- CSR_API_NUMBER ApiNumber,
- ULONG RequestLength)
-{
- NTSTATUS Status;
- ULONG PointerCount;
- PULONG_PTR Pointers;
- ULONG_PTR CurrentPointer;
- DPRINT("CsrClientCallServer\n");
-
- /* Fill out the Port Message Header */
- ApiMessage->Header.u2.ZeroInit = 0;
- ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);
- ApiMessage->Header.u1.s1.TotalLength = RequestLength;
-
- /* Fill out the CSR Header */
- ApiMessage->Type = ApiNumber;
- //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR
- ApiMessage->CsrCaptureData = NULL;
-
- DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
- ApiNumber,
- ApiMessage->Header.u1.s1.DataLength,
- ApiMessage->Header.u1.s1.TotalLength);
-
- /* Check if we are already inside a CSR Server */
- if (!InsideCsrProcess)
- {
- /* Check if we got a a Capture Buffer */
- if (CaptureBuffer)
- {
- /* We have to convert from our local view to the remote view */
- ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
- CsrPortMemoryDelta);
-
- /* Lock the buffer */
- CaptureBuffer->BufferEnd = 0;
-
- /* Get the pointer information */
- PointerCount = CaptureBuffer->PointerCount;
- Pointers = CaptureBuffer->PointerArray;
-
- /* Loop through every pointer and convert it */
- DPRINT("PointerCount: %lx\n", PointerCount);
- while (PointerCount--)
- {
- /* Get this pointer and check if it's valid */
- DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n",
- &Pointers, Pointers, *Pointers);
- if ((CurrentPointer = *Pointers++))
- {
- /* Update it */
- DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
- *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
- Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
- DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
- }
- }
- }
-
- /* Send the LPC Message */
- Status = NtRequestWaitReplyPort(CsrApiPort,
- &ApiMessage->Header,
- &ApiMessage->Header);
-
- /* Check if we got a a Capture Buffer */
- if (CaptureBuffer)
- {
- /* We have to convert from the remote view to our remote view */
- DPRINT("Reconverting CaptureBuffer\n");
- ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
- ApiMessage->CsrCaptureData -
- CsrPortMemoryDelta);
-
- /* Get the pointer information */
- PointerCount = CaptureBuffer->PointerCount;
- Pointers = CaptureBuffer->PointerArray;
-
- /* Loop through every pointer and convert it */
- while (PointerCount--)
- {
- /* Get this pointer and check if it's valid */
- if ((CurrentPointer = *Pointers++))
- {
- /* Update it */
- CurrentPointer += (ULONG_PTR)ApiMessage;
- Pointers[-1] = CurrentPointer;
- *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
- }
- }
- }
-
- /* Check for success */
- if (!NT_SUCCESS(Status))
- {
- /* We failed. Overwrite the return value with the failure */
- DPRINT1("LPC Failed: %lx\n", Status);
- ApiMessage->Status = Status;
- }
- }
- else
- {
- /* This is a server-to-server call. Save our CID and do a direct call */
- DbgBreakPoint();
- ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
- Status = CsrServerApiRoutine(&ApiMessage->Header,
- &ApiMessage->Header);
-
- /* Check for success */
- if (!NT_SUCCESS(Status))
- {
- /* We failed. Overwrite the return value with the failure */
- ApiMessage->Status = Status;
- }
- }
-
- /* Return the CSR Result */
- DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
- return ApiMessage->Status;
-}
+/* FUNCTIONS ******************************************************************/
NTSTATUS
NTAPI
-CsrConnectToServer(IN PWSTR ObjectDirectory)
+CsrpConnectToServer(IN PWSTR ObjectDirectory)
{
+ NTSTATUS Status;
ULONG PortNameLength;
UNICODE_STRING PortName;
LARGE_INTEGER CsrSectionViewSize;
- NTSTATUS Status;
HANDLE CsrSectionHandle;
PORT_VIEW LpcWrite;
REMOTE_PORT_VIEW LpcRead;
SECURITY_QUALITY_OF_SERVICE SecurityQos;
SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
PSID SystemSid = NULL;
- CSR_CONNECTION_INFO ConnectionInfo;
- ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
+ CSR_API_CONNECTINFO ConnectionInfo;
+ ULONG ConnectionInfoLength = sizeof(CSR_API_CONNECTINFO);
DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
PortName.MaximumLength = PortNameLength;
/* Allocate a buffer for it */
- PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
+ PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
if (PortName.Buffer == NULL)
{
return STATUS_INSUFFICIENT_RESOURCES;
NULL,
&CsrSectionViewSize,
PAGE_READWRITE,
- SEC_COMMIT,
+ SEC_RESERVE,
NULL);
if (!NT_SUCCESS(Status))
{
SecurityQos.EffectiveOnly = TRUE;
/* Setup the connection info */
- ConnectionInfo.Version = 0x10000;
+ ConnectionInfo.DebugFlags = 0;
/* Create a SID for us */
Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
- 1,
- SECURITY_LOCAL_SYSTEM_RID,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- &SystemSid);
+ 1,
+ SECURITY_LOCAL_SYSTEM_RID,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ &SystemSid);
if (!NT_SUCCESS(Status))
{
/* Failure */
NULL,
&ConnectionInfo,
&ConnectionInfoLength);
+ RtlFreeSid(SystemSid);
NtClose(CsrSectionHandle);
if (!NT_SUCCESS(Status))
{
(ULONG_PTR)LpcWrite.ViewBase;
/* Save the Process */
- CsrProcessId = ConnectionInfo.ProcessId;
+ CsrProcessId = ConnectionInfo.ServerProcessId;
/* Save CSR Section data */
NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
- NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
+ NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
/* Create the port heap */
CsrPortHeap = RtlCreateHeap(0,
0);
if (CsrPortHeap == NULL)
{
+ /* Failure */
+ DPRINT1("Couldn't create heap for CSR port\n");
NtClose(CsrApiPort);
CsrApiPort = NULL;
return STATUS_INSUFFICIENT_RESOURCES;
*/
NTSTATUS
NTAPI
-CsrClientConnectToServer(PWSTR ObjectDirectory,
- ULONG ServerId,
- PVOID ConnectionInfo,
- PULONG ConnectionInfoSize,
- PBOOLEAN ServerToServerCall)
+CsrClientConnectToServer(IN PWSTR ObjectDirectory,
+ IN ULONG ServerId,
+ IN PVOID ConnectionInfo,
+ IN OUT PULONG ConnectionInfoSize,
+ OUT PBOOLEAN ServerToServerCall)
{
NTSTATUS Status;
PIMAGE_NT_HEADERS NtHeader;
UNICODE_STRING CsrSrvName;
HANDLE hCsrSrv;
ANSI_STRING CsrServerRoutineName;
+ CSR_API_MESSAGE ApiMessage;
+ PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
PCSR_CAPTURE_BUFFER CaptureBuffer;
- CSR_API_MESSAGE RosApiMessage;
- CSR_API_MESSAGE2 ApiMessage;
- PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;
- /* Validate the Connection Info */
DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
+
+ /* Validate the Connection Info */
if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
{
DPRINT1("Connection info given, but no length\n");
InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
/* Now we can check if we are inside or not */
- if (InsideCsrProcess && !UsingOldCsr)
+ if (InsideCsrProcess)
{
/* We're inside, so let's find csrsrv */
- DbgBreakPoint();
+ DPRINT("Next-GEN CSRSS support\n");
RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
Status = LdrGetDllHandle(NULL,
NULL,
if (!CsrApiPort)
{
/* No, set it up now */
- if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))
+ Status = CsrpConnectToServer(ObjectDirectory);
+ if (!NT_SUCCESS(Status))
{
/* Failed */
DPRINT1("Failure to connect to CSR\n");
ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
/* Setup a buffer for the connection info */
- CaptureBuffer = CsrAllocateCaptureBuffer(1,
- ClientConnect->ConnectionInfoSize);
+ CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
if (CaptureBuffer == NULL)
{
return STATUS_INSUFFICIENT_RESOURCES;
}
- /* Allocate a pointer for the connection info*/
- CsrAllocateMessagePointer(CaptureBuffer,
- ClientConnect->ConnectionInfoSize,
- &ClientConnect->ConnectionInfo);
-
- /* Copy the data into the buffer */
- RtlMoveMemory(ClientConnect->ConnectionInfo,
- ConnectionInfo,
- ClientConnect->ConnectionInfoSize);
+ /* Capture the connection info data */
+ CsrCaptureMessageBuffer(CaptureBuffer,
+ ConnectionInfo,
+ ClientConnect->ConnectionInfoSize,
+ &ClientConnect->ConnectionInfo);
/* Return the allocated length */
*ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
/* Call CSR */
-#if 0
Status = CsrClientCallServer(&ApiMessage,
CaptureBuffer,
- CSR_MAKE_OPCODE(CsrSrvClientConnect,
- CSR_SRV_DLL),
+ CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
sizeof(CSR_CLIENT_CONNECT));
-#endif
- Status = CsrClientCallServer(&RosApiMessage,
- NULL,
- MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),
- sizeof(CSR_API_MESSAGE));
+
+ /* Copy the updated connection info data back into the user buffer */
+ RtlMoveMemory(ConnectionInfo,
+ ClientConnect->ConnectionInfo,
+ *ConnectionInfoSize);
+
+ /* Free the capture buffer */
+ CsrFreeCaptureBuffer(CaptureBuffer);
}
else
{
/* Let the caller know if this was server to server */
DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
+
return Status;
}
+#if 0
+//
+// Structures can be padded at the end, causing the size of the entire structure
+// minus the size of the last field, not to be equal to the offset of the last
+// field.
+//
+typedef struct _TEST_EMBEDDED
+{
+ ULONG One;
+ ULONG Two;
+ ULONG Three;
+} TEST_EMBEDDED;
+
+typedef struct _TEST
+{
+ PORT_MESSAGE h;
+ TEST_EMBEDDED Three;
+} TEST;
+
+C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
+C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);
+C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);
+
+C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE)));
+C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
+#endif
+
+/*
+ * @implemented
+ */
+NTSTATUS
+NTAPI
+CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
+ IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
+ IN CSR_API_NUMBER ApiNumber,
+ IN ULONG DataLength)
+{
+ NTSTATUS Status;
+ ULONG PointerCount;
+ PULONG_PTR OffsetPointer;
+
+ /* Fill out the Port Message Header */
+ ApiMessage->Header.u2.ZeroInit = 0;
+ ApiMessage->Header.u1.s1.TotalLength = DataLength +
+ sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data); // FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
+ ApiMessage->Header.u1.s1.DataLength = DataLength +
+ FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header);// ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
+
+ /* Fill out the CSR Header */
+ ApiMessage->ApiNumber = ApiNumber;
+ ApiMessage->CsrCaptureData = NULL;
+
+ DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
+ ApiNumber,
+ ApiMessage->Header.u1.s1.DataLength,
+ ApiMessage->Header.u1.s1.TotalLength);
+
+ /* Check if we are already inside a CSR Server */
+ if (!InsideCsrProcess)
+ {
+ /* Check if we got a Capture Buffer */
+ if (CaptureBuffer)
+ {
+ /*
+ * We have to convert from our local (client) view
+ * to the remote (server) view.
+ */
+ ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+ ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
+
+ /* Lock the buffer. */
+ CaptureBuffer->BufferEnd = NULL;
+
+ /*
+ * Each client pointer inside the CSR message is converted into
+ * a server pointer, and each pointer to these message pointers
+ * is converted into an offset.
+ */
+ PointerCount = CaptureBuffer->PointerCount;
+ OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+ while (PointerCount--)
+ {
+ if (*OffsetPointer != 0)
+ {
+ *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
+ *OffsetPointer -= (ULONG_PTR)ApiMessage;
+ }
+ ++OffsetPointer;
+ }
+ }
+
+ /* Send the LPC Message */
+ Status = NtRequestWaitReplyPort(CsrApiPort,
+ &ApiMessage->Header,
+ &ApiMessage->Header);
+
+ /* Check if we got a Capture Buffer */
+ if (CaptureBuffer)
+ {
+ /*
+ * We have to convert back from the remote (server) view
+ * to our local (client) view.
+ */
+ ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
+ ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
+
+ /*
+ * Convert back the offsets into pointers to CSR message
+ * pointers, and convert back these message server pointers
+ * into client pointers.
+ */
+ PointerCount = CaptureBuffer->PointerCount;
+ OffsetPointer = CaptureBuffer->PointerOffsetsArray;
+ while (PointerCount--)
+ {
+ if (*OffsetPointer != 0)
+ {
+ *OffsetPointer += (ULONG_PTR)ApiMessage;
+ *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
+ }
+ ++OffsetPointer;
+ }
+ }
+
+ /* Check for success */
+ if (!NT_SUCCESS(Status))
+ {
+ /* We failed. Overwrite the return value with the failure. */
+ DPRINT1("LPC Failed: %lx\n", Status);
+ ApiMessage->Status = Status;
+ }
+ }
+ else
+ {
+ /* This is a server-to-server call. Save our CID and do a direct call. */
+ DPRINT("Next gen server-to-server call\n");
+
+ /* We check this equality inside CsrValidateMessageBuffer */
+ ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
+
+ Status = CsrServerApiRoutine(&ApiMessage->Header,
+ &ApiMessage->Header);
+
+ /* Check for success */
+ if (!NT_SUCCESS(Status))
+ {
+ /* We failed. Overwrite the return value with the failure. */
+ ApiMessage->Status = Status;
+ }
+ }
+
+ /* Return the CSR Result */
+ DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
+ return ApiMessage->Status;
+}
+
+/*
+ * @implemented
+ */
+HANDLE
+NTAPI
+CsrGetProcessId(VOID)
+{
+ return CsrProcessId;
+}
+
/* EOF */