[AVIFIL32]
[reactos.git] / reactos / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS ********************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22
23 typedef NTSTATUS
24 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
25 IN PPORT_MESSAGE Reply);
26
27 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
28
29 #define UNICODE_PATH_SEP L"\\"
30
31 /* FUNCTIONS ******************************************************************/
32
33 NTSTATUS
34 NTAPI
35 CsrpConnectToServer(IN PWSTR ObjectDirectory)
36 {
37 NTSTATUS Status;
38 ULONG PortNameLength;
39 UNICODE_STRING PortName;
40 LARGE_INTEGER CsrSectionViewSize;
41 HANDLE CsrSectionHandle;
42 PORT_VIEW LpcWrite;
43 REMOTE_PORT_VIEW LpcRead;
44 SECURITY_QUALITY_OF_SERVICE SecurityQos;
45 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
46 PSID SystemSid = NULL;
47 CSR_API_CONNECTINFO ConnectionInfo;
48 ULONG ConnectionInfoLength = sizeof(CSR_API_CONNECTINFO);
49
50 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
51
52 /* Binary compatibility with MS KERNEL32 */
53 if (NULL == ObjectDirectory)
54 {
55 ObjectDirectory = L"\\Windows";
56 }
57
58 /* Calculate the total port name size */
59 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
60 sizeof(CSR_PORT_NAME);
61
62 /* Set the port name */
63 PortName.Length = 0;
64 PortName.MaximumLength = PortNameLength;
65
66 /* Allocate a buffer for it */
67 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
68 if (PortName.Buffer == NULL)
69 {
70 return STATUS_INSUFFICIENT_RESOURCES;
71 }
72
73 /* Create the name */
74 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
75 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
76 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
77
78 /* Create a section for the port memory */
79 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
80 Status = NtCreateSection(&CsrSectionHandle,
81 SECTION_ALL_ACCESS,
82 NULL,
83 &CsrSectionViewSize,
84 PAGE_READWRITE,
85 SEC_RESERVE,
86 NULL);
87 if (!NT_SUCCESS(Status))
88 {
89 DPRINT1("Failure allocating CSR Section\n");
90 return Status;
91 }
92
93 /* Set up the port view structures to match them with the section */
94 LpcWrite.Length = sizeof(PORT_VIEW);
95 LpcWrite.SectionHandle = CsrSectionHandle;
96 LpcWrite.SectionOffset = 0;
97 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
98 LpcWrite.ViewBase = 0;
99 LpcWrite.ViewRemoteBase = 0;
100 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
101 LpcRead.ViewSize = 0;
102 LpcRead.ViewBase = 0;
103
104 /* Setup the QoS */
105 SecurityQos.ImpersonationLevel = SecurityImpersonation;
106 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
107 SecurityQos.EffectiveOnly = TRUE;
108
109 /* Setup the connection info */
110 ConnectionInfo.DebugFlags = 0;
111
112 /* Create a SID for us */
113 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
114 1,
115 SECURITY_LOCAL_SYSTEM_RID,
116 0,
117 0,
118 0,
119 0,
120 0,
121 0,
122 0,
123 &SystemSid);
124 if (!NT_SUCCESS(Status))
125 {
126 /* Failure */
127 DPRINT1("Couldn't allocate SID\n");
128 NtClose(CsrSectionHandle);
129 return Status;
130 }
131
132 /* Connect to the port */
133 Status = NtSecureConnectPort(&CsrApiPort,
134 &PortName,
135 &SecurityQos,
136 &LpcWrite,
137 SystemSid,
138 &LpcRead,
139 NULL,
140 &ConnectionInfo,
141 &ConnectionInfoLength);
142 RtlFreeSid(SystemSid);
143 NtClose(CsrSectionHandle);
144 if (!NT_SUCCESS(Status))
145 {
146 /* Failure */
147 DPRINT1("Couldn't connect to CSR port\n");
148 return Status;
149 }
150
151 /* Save the delta between the sections, for capture usage later */
152 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
153 (ULONG_PTR)LpcWrite.ViewBase;
154
155 /* Save the Process */
156 CsrProcessId = ConnectionInfo.ServerProcessId;
157
158 /* Save CSR Section data */
159 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
160 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
161 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
162
163 /* Create the port heap */
164 CsrPortHeap = RtlCreateHeap(0,
165 LpcWrite.ViewBase,
166 LpcWrite.ViewSize,
167 PAGE_SIZE,
168 0,
169 0);
170 if (CsrPortHeap == NULL)
171 {
172 /* Failure */
173 DPRINT1("Couldn't create heap for CSR port\n");
174 NtClose(CsrApiPort);
175 CsrApiPort = NULL;
176 return STATUS_INSUFFICIENT_RESOURCES;
177 }
178
179 /* Return success */
180 return STATUS_SUCCESS;
181 }
182
183 /*
184 * @implemented
185 */
186 NTSTATUS
187 NTAPI
188 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
189 IN ULONG ServerId,
190 IN PVOID ConnectionInfo,
191 IN OUT PULONG ConnectionInfoSize,
192 OUT PBOOLEAN ServerToServerCall)
193 {
194 NTSTATUS Status;
195 PIMAGE_NT_HEADERS NtHeader;
196 UNICODE_STRING CsrSrvName;
197 HANDLE hCsrSrv;
198 ANSI_STRING CsrServerRoutineName;
199 CSR_API_MESSAGE ApiMessage;
200 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
201 PCSR_CAPTURE_BUFFER CaptureBuffer;
202
203 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
204
205 /* Validate the Connection Info */
206 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
207 {
208 DPRINT1("Connection info given, but no length\n");
209 return STATUS_INVALID_PARAMETER;
210 }
211
212 /* Check if we're inside a CSR Process */
213 if (InsideCsrProcess)
214 {
215 /* Tell the client that we're already inside CSR */
216 if (ServerToServerCall) *ServerToServerCall = TRUE;
217 return STATUS_SUCCESS;
218 }
219
220 /*
221 * We might be in a CSR Process but not know it, if this is the first call.
222 * So let's find out.
223 */
224 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
225 {
226 /* The image isn't valid */
227 DPRINT1("Invalid image\n");
228 return STATUS_INVALID_IMAGE_FORMAT;
229 }
230 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
231
232 /* Now we can check if we are inside or not */
233 if (InsideCsrProcess)
234 {
235 /* We're inside, so let's find csrsrv */
236 DPRINT("Next-GEN CSRSS support\n");
237 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
238 Status = LdrGetDllHandle(NULL,
239 NULL,
240 &CsrSrvName,
241 &hCsrSrv);
242
243 /* Now get the Server to Server routine */
244 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
245 Status = LdrGetProcedureAddress(hCsrSrv,
246 &CsrServerRoutineName,
247 0L,
248 (PVOID*)&CsrServerApiRoutine);
249
250 /* Use the local heap as port heap */
251 CsrPortHeap = RtlGetProcessHeap();
252
253 /* Tell the caller we're inside the server */
254 *ServerToServerCall = InsideCsrProcess;
255 return STATUS_SUCCESS;
256 }
257
258 /* Now check if connection info is given */
259 if (ConnectionInfo)
260 {
261 /* Well, we're defintely in a client now */
262 InsideCsrProcess = FALSE;
263
264 /* Do we have a connection to CSR yet? */
265 if (!CsrApiPort)
266 {
267 /* No, set it up now */
268 Status = CsrpConnectToServer(ObjectDirectory);
269 if (!NT_SUCCESS(Status))
270 {
271 /* Failed */
272 DPRINT1("Failure to connect to CSR\n");
273 return Status;
274 }
275 }
276
277 /* Setup the connect message header */
278 ClientConnect->ServerId = ServerId;
279 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
280
281 /* Setup a buffer for the connection info */
282 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
283 if (CaptureBuffer == NULL)
284 {
285 return STATUS_INSUFFICIENT_RESOURCES;
286 }
287
288 /* Capture the connection info data */
289 CsrCaptureMessageBuffer(CaptureBuffer,
290 ConnectionInfo,
291 ClientConnect->ConnectionInfoSize,
292 &ClientConnect->ConnectionInfo);
293
294 /* Return the allocated length */
295 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
296
297 /* Call CSR */
298 Status = CsrClientCallServer(&ApiMessage,
299 CaptureBuffer,
300 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
301 sizeof(CSR_CLIENT_CONNECT));
302
303 /* Copy the updated connection info data back into the user buffer */
304 RtlMoveMemory(ConnectionInfo,
305 ClientConnect->ConnectionInfo,
306 *ConnectionInfoSize);
307
308 /* Free the capture buffer */
309 CsrFreeCaptureBuffer(CaptureBuffer);
310 }
311 else
312 {
313 /* No connection info, just return */
314 Status = STATUS_SUCCESS;
315 }
316
317 /* Let the caller know if this was server to server */
318 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
319 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
320
321 return Status;
322 }
323
#if 0
//
// Disabled compile-time check showing that a structure can be padded at its
// end. When that happens, the size of the whole structure minus the size of
// its last field is NOT equal to the offset of that last field. The C_ASSERTs
// below state exactly this for TEST, given the sizes they also assert for
// PORT_MESSAGE (0x18) and TEST_EMBEDDED (0xC).
//
// NOTE(review): this presumably documents why CsrClientCallServer computes
// TotalLength from sizeof(CSR_API_MESSAGE) rather than FIELD_OFFSET. The
// 0x18 PORT_MESSAGE size looks specific to 32-bit builds; confirm before
// re-enabling.
//
typedef struct _TEST_EMBEDDED
{
    ULONG One;
    ULONG Two;
    ULONG Three;
} TEST_EMBEDDED;

typedef struct _TEST
{
    PORT_MESSAGE h;
    TEST_EMBEDDED Three;
} TEST;

C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);
C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);

C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE)));
C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
#endif
350
351 /*
352 * @implemented
353 */
354 NTSTATUS
355 NTAPI
356 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
357 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
358 IN CSR_API_NUMBER ApiNumber,
359 IN ULONG DataLength)
360 {
361 NTSTATUS Status;
362 ULONG PointerCount;
363 PULONG_PTR OffsetPointer;
364
365 /* Fill out the Port Message Header */
366 ApiMessage->Header.u2.ZeroInit = 0;
367 ApiMessage->Header.u1.s1.TotalLength = DataLength +
368 sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data); // FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
369 ApiMessage->Header.u1.s1.DataLength = DataLength +
370 FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header);// ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
371
372 /* Fill out the CSR Header */
373 ApiMessage->ApiNumber = ApiNumber;
374 ApiMessage->CsrCaptureData = NULL;
375
376 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
377 ApiNumber,
378 ApiMessage->Header.u1.s1.DataLength,
379 ApiMessage->Header.u1.s1.TotalLength);
380
381 /* Check if we are already inside a CSR Server */
382 if (!InsideCsrProcess)
383 {
384 /* Check if we got a Capture Buffer */
385 if (CaptureBuffer)
386 {
387 /*
388 * We have to convert from our local (client) view
389 * to the remote (server) view.
390 */
391 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
392 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
393
394 /* Lock the buffer. */
395 CaptureBuffer->BufferEnd = NULL;
396
397 /*
398 * Each client pointer inside the CSR message is converted into
399 * a server pointer, and each pointer to these message pointers
400 * is converted into an offset.
401 */
402 PointerCount = CaptureBuffer->PointerCount;
403 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
404 while (PointerCount--)
405 {
406 if (*OffsetPointer != 0)
407 {
408 *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
409 *OffsetPointer -= (ULONG_PTR)ApiMessage;
410 }
411 ++OffsetPointer;
412 }
413 }
414
415 /* Send the LPC Message */
416 Status = NtRequestWaitReplyPort(CsrApiPort,
417 &ApiMessage->Header,
418 &ApiMessage->Header);
419
420 /* Check if we got a Capture Buffer */
421 if (CaptureBuffer)
422 {
423 /*
424 * We have to convert back from the remote (server) view
425 * to our local (client) view.
426 */
427 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
428 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
429
430 /*
431 * Convert back the offsets into pointers to CSR message
432 * pointers, and convert back these message server pointers
433 * into client pointers.
434 */
435 PointerCount = CaptureBuffer->PointerCount;
436 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
437 while (PointerCount--)
438 {
439 if (*OffsetPointer != 0)
440 {
441 *OffsetPointer += (ULONG_PTR)ApiMessage;
442 *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
443 }
444 ++OffsetPointer;
445 }
446 }
447
448 /* Check for success */
449 if (!NT_SUCCESS(Status))
450 {
451 /* We failed. Overwrite the return value with the failure. */
452 DPRINT1("LPC Failed: %lx\n", Status);
453 ApiMessage->Status = Status;
454 }
455 }
456 else
457 {
458 /* This is a server-to-server call. Save our CID and do a direct call. */
459 DPRINT("Next gen server-to-server call\n");
460
461 /* We check this equality inside CsrValidateMessageBuffer */
462 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
463
464 Status = CsrServerApiRoutine(&ApiMessage->Header,
465 &ApiMessage->Header);
466
467 /* Check for success */
468 if (!NT_SUCCESS(Status))
469 {
470 /* We failed. Overwrite the return value with the failure. */
471 ApiMessage->Status = Status;
472 }
473 }
474
475 /* Return the CSR Result */
476 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
477 return ApiMessage->Status;
478 }
479
480 /*
481 * @implemented
482 */
483 HANDLE
484 NTAPI
485 CsrGetProcessId(VOID)
486 {
487 return CsrProcessId;
488 }
489
490 /* EOF */