[BASESRV-CSRSRV]
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
 * FILE:            dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS ********************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22
23 typedef NTSTATUS
24 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
25 IN PPORT_MESSAGE Reply);
26
27 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
28
29 #define UNICODE_PATH_SEP L"\\"
30
31 /* FUNCTIONS ******************************************************************/
32
33 /*
34 * @implemented
35 */
36 HANDLE
37 NTAPI
38 CsrGetProcessId(VOID)
39 {
40 return CsrProcessId;
41 }
42
43 /*
44 * @implemented
45 */
46 NTSTATUS
47 NTAPI
48 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
49 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
50 IN CSR_API_NUMBER ApiNumber,
51 IN ULONG DataLength)
52 {
53 NTSTATUS Status;
54 ULONG PointerCount;
55 PULONG_PTR OffsetPointer;
56
57 /* Fill out the Port Message Header. */
58 ApiMessage->Header.u2.ZeroInit = 0;
59 ApiMessage->Header.u1.s1.TotalLength =
60 FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
61 ApiMessage->Header.u1.s1.DataLength =
62 ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
63
64 /* Fill out the CSR Header. */
65 ApiMessage->ApiNumber = ApiNumber;
66 ApiMessage->CsrCaptureData = NULL;
67
68 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
69 ApiNumber,
70 ApiMessage->Header.u1.s1.DataLength,
71 ApiMessage->Header.u1.s1.TotalLength);
72
73 /* Check if we are already inside a CSR Server. */
74 if (!InsideCsrProcess)
75 {
76 /* Check if we got a Capture Buffer. */
77 if (CaptureBuffer)
78 {
79 /*
80 * We have to convert from our local (client) view
81 * to the remote (server) view.
82 */
83 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
84 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
85
86 /* Lock the buffer. */
87 CaptureBuffer->BufferEnd = NULL;
88
89 /*
90 * Each client pointer inside the CSR message is converted into
91 * a server pointer, and each pointer to these message pointers
92 * is converted into an offset.
93 */
94 PointerCount = CaptureBuffer->PointerCount;
95 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
96 while (PointerCount--)
97 {
98 if (*OffsetPointer != 0)
99 {
100 *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
101 *OffsetPointer -= (ULONG_PTR)ApiMessage;
102 }
103 ++OffsetPointer;
104 }
105 }
106
107 /* Send the LPC Message. */
108 Status = NtRequestWaitReplyPort(CsrApiPort,
109 &ApiMessage->Header,
110 &ApiMessage->Header);
111
112 /* Check if we got a Capture Buffer. */
113 if (CaptureBuffer)
114 {
115 /*
116 * We have to convert back from the remote (server) view
117 * to our local (client) view.
118 */
119 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
120 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
121
122 /*
123 * Convert back the offsets into pointers to CSR message
124 * pointers, and convert back these message server pointers
125 * into client pointers.
126 */
127 PointerCount = CaptureBuffer->PointerCount;
128 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
129 while (PointerCount--)
130 {
131 if (*OffsetPointer != 0)
132 {
133 *OffsetPointer += (ULONG_PTR)ApiMessage;
134 *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
135 }
136 ++OffsetPointer;
137 }
138 }
139
140 /* Check for success. */
141 if (!NT_SUCCESS(Status))
142 {
143 /* We failed. Overwrite the return value with the failure. */
144 DPRINT1("LPC Failed: %lx\n", Status);
145 ApiMessage->Status = Status;
146 }
147 }
148 else
149 {
150 /* This is a server-to-server call. Save our CID and do a direct call. */
151 DPRINT1("Next gen server-to-server call\n");
152
153 /* We check this equality inside CsrValidateMessageBuffer. */
154 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
155
156 Status = CsrServerApiRoutine(&ApiMessage->Header,
157 &ApiMessage->Header);
158
159 /* Check for success. */
160 if (!NT_SUCCESS(Status))
161 {
162 /* We failed. Overwrite the return value with the failure. */
163 ApiMessage->Status = Status;
164 }
165 }
166
167 /* Return the CSR Result. */
168 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
169 return ApiMessage->Status;
170 }
171
172 NTSTATUS
173 NTAPI
174 CsrpConnectToServer(IN PWSTR ObjectDirectory)
175 {
176 NTSTATUS Status;
177 ULONG PortNameLength;
178 UNICODE_STRING PortName;
179 LARGE_INTEGER CsrSectionViewSize;
180 HANDLE CsrSectionHandle;
181 PORT_VIEW LpcWrite;
182 REMOTE_PORT_VIEW LpcRead;
183 SECURITY_QUALITY_OF_SERVICE SecurityQos;
184 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
185 PSID SystemSid = NULL;
186 CSR_CONNECTION_INFO ConnectionInfo;
187 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
188
189 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
190
191 /* Binary compatibility with MS KERNEL32 */
192 if (NULL == ObjectDirectory)
193 {
194 ObjectDirectory = L"\\Windows";
195 }
196
197 /* Calculate the total port name size */
198 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
199 sizeof(CSR_PORT_NAME);
200
201 /* Set the port name */
202 PortName.Length = 0;
203 PortName.MaximumLength = PortNameLength;
204
205 /* Allocate a buffer for it */
206 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
207 if (PortName.Buffer == NULL)
208 {
209 return STATUS_INSUFFICIENT_RESOURCES;
210 }
211
212 /* Create the name */
213 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
214 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
215 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
216
217 /* Create a section for the port memory */
218 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
219 Status = NtCreateSection(&CsrSectionHandle,
220 SECTION_ALL_ACCESS,
221 NULL,
222 &CsrSectionViewSize,
223 PAGE_READWRITE,
224 SEC_RESERVE,
225 NULL);
226 if (!NT_SUCCESS(Status))
227 {
228 DPRINT1("Failure allocating CSR Section\n");
229 return Status;
230 }
231
232 /* Set up the port view structures to match them with the section */
233 LpcWrite.Length = sizeof(PORT_VIEW);
234 LpcWrite.SectionHandle = CsrSectionHandle;
235 LpcWrite.SectionOffset = 0;
236 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
237 LpcWrite.ViewBase = 0;
238 LpcWrite.ViewRemoteBase = 0;
239 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
240 LpcRead.ViewSize = 0;
241 LpcRead.ViewBase = 0;
242
243 /* Setup the QoS */
244 SecurityQos.ImpersonationLevel = SecurityImpersonation;
245 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
246 SecurityQos.EffectiveOnly = TRUE;
247
248 /* Setup the connection info */
249 ConnectionInfo.Version = 0x10000;
250
251 /* Create a SID for us */
252 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
253 1,
254 SECURITY_LOCAL_SYSTEM_RID,
255 0,
256 0,
257 0,
258 0,
259 0,
260 0,
261 0,
262 &SystemSid);
263 if (!NT_SUCCESS(Status))
264 {
265 /* Failure */
266 DPRINT1("Couldn't allocate SID\n");
267 NtClose(CsrSectionHandle);
268 return Status;
269 }
270
271 /* Connect to the port */
272 Status = NtSecureConnectPort(&CsrApiPort,
273 &PortName,
274 &SecurityQos,
275 &LpcWrite,
276 SystemSid,
277 &LpcRead,
278 NULL,
279 &ConnectionInfo,
280 &ConnectionInfoLength);
281 RtlFreeSid(SystemSid);
282 NtClose(CsrSectionHandle);
283 if (!NT_SUCCESS(Status))
284 {
285 /* Failure */
286 DPRINT1("Couldn't connect to CSR port\n");
287 return Status;
288 }
289
290 /* Save the delta between the sections, for capture usage later */
291 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
292 (ULONG_PTR)LpcWrite.ViewBase;
293
294 /* Save the Process */
295 CsrProcessId = ConnectionInfo.ProcessId;
296
297 /* Save CSR Section data */
298 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
299 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
300 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
301
302 /* Create the port heap */
303 CsrPortHeap = RtlCreateHeap(0,
304 LpcWrite.ViewBase,
305 LpcWrite.ViewSize,
306 PAGE_SIZE,
307 0,
308 0);
309 if (CsrPortHeap == NULL)
310 {
311 /* Failure */
312 DPRINT1("Couldn't create heap for CSR port\n");
313 NtClose(CsrApiPort);
314 CsrApiPort = NULL;
315 return STATUS_INSUFFICIENT_RESOURCES;
316 }
317
318 /* Return success */
319 return STATUS_SUCCESS;
320 }
321
322 /*
323 * @implemented
324 */
325 NTSTATUS
326 NTAPI
327 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
328 IN ULONG ServerId,
329 IN PVOID ConnectionInfo,
330 IN OUT PULONG ConnectionInfoSize,
331 OUT PBOOLEAN ServerToServerCall)
332 {
333 NTSTATUS Status;
334 PIMAGE_NT_HEADERS NtHeader;
335 UNICODE_STRING CsrSrvName;
336 HANDLE hCsrSrv;
337 ANSI_STRING CsrServerRoutineName;
338 CSR_API_MESSAGE ApiMessage;
339 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
340 PCSR_CAPTURE_BUFFER CaptureBuffer;
341
342 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
343
344 /* Validate the Connection Info */
345 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
346 {
347 DPRINT1("Connection info given, but no length\n");
348 return STATUS_INVALID_PARAMETER;
349 }
350
351 /* Check if we're inside a CSR Process */
352 if (InsideCsrProcess)
353 {
354 /* Tell the client that we're already inside CSR */
355 if (ServerToServerCall) *ServerToServerCall = TRUE;
356 return STATUS_SUCCESS;
357 }
358
359 /*
360 * We might be in a CSR Process but not know it, if this is the first call.
361 * So let's find out.
362 */
363 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
364 {
365 /* The image isn't valid */
366 DPRINT1("Invalid image\n");
367 return STATUS_INVALID_IMAGE_FORMAT;
368 }
369 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
370
371 /* Now we can check if we are inside or not */
372 if (InsideCsrProcess)
373 {
374 /* We're inside, so let's find csrsrv */
375 DPRINT1("Next-GEN CSRSS support\n");
376 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
377 Status = LdrGetDllHandle(NULL,
378 NULL,
379 &CsrSrvName,
380 &hCsrSrv);
381
382 /* Now get the Server to Server routine */
383 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
384 Status = LdrGetProcedureAddress(hCsrSrv,
385 &CsrServerRoutineName,
386 0L,
387 (PVOID*)&CsrServerApiRoutine);
388
389 /* Use the local heap as port heap */
390 CsrPortHeap = RtlGetProcessHeap();
391
392 /* Tell the caller we're inside the server */
393 *ServerToServerCall = InsideCsrProcess;
394 return STATUS_SUCCESS;
395 }
396
397 /* Now check if connection info is given */
398 if (ConnectionInfo)
399 {
400 /* Well, we're defintely in a client now */
401 InsideCsrProcess = FALSE;
402
403 /* Do we have a connection to CSR yet? */
404 if (!CsrApiPort)
405 {
406 /* No, set it up now */
407 if (!NT_SUCCESS(Status = CsrpConnectToServer(ObjectDirectory)))
408 {
409 /* Failed */
410 DPRINT1("Failure to connect to CSR\n");
411 return Status;
412 }
413 }
414
415 /* Setup the connect message header */
416 ClientConnect->ServerId = ServerId;
417 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
418
419 /* Setup a buffer for the connection info */
420 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
421 if (CaptureBuffer == NULL)
422 {
423 return STATUS_INSUFFICIENT_RESOURCES;
424 }
425
426 /* Capture the connection info data */
427 CsrCaptureMessageBuffer(CaptureBuffer,
428 ConnectionInfo,
429 ClientConnect->ConnectionInfoSize,
430 &ClientConnect->ConnectionInfo);
431
432 /* Return the allocated length */
433 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
434
435 /* Call CSR */
436 Status = CsrClientCallServer(&ApiMessage,
437 CaptureBuffer,
438 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
439 sizeof(CSR_CLIENT_CONNECT));
440
441 /* Copy the updated connection info data back into the user buffer */
442 RtlMoveMemory(ConnectionInfo,
443 ClientConnect->ConnectionInfo,
444 *ConnectionInfoSize);
445
446 /* Free the capture buffer */
447 CsrFreeCaptureBuffer(CaptureBuffer);
448 }
449 else
450 {
451 /* No connection info, just return */
452 Status = STATUS_SUCCESS;
453 }
454
455 /* Let the caller know if this was server to server */
456 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
457 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
458
459 return Status;
460 }
461
462 /* EOF */