[MMIXER] Fix additional data size initialization for different audio formats (#6753)
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12
13 #include <ndk/lpcfuncs.h>
14 #include <csr/csrsrv.h>
15
16 #define NDEBUG
17 #include <debug.h>
18
19 /* GLOBALS ********************************************************************/
20
21 HANDLE CsrApiPort;
22 HANDLE CsrProcessId;
23 HANDLE CsrPortHeap;
24 ULONG_PTR CsrPortMemoryDelta;
25 BOOLEAN InsideCsrProcess = FALSE;
26
27 typedef NTSTATUS
28 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
29 IN PPORT_MESSAGE Reply);
30
31 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
32
33 #define UNICODE_PATH_SEP L"\\"
34
35 /* FUNCTIONS ******************************************************************/
36
37 NTSTATUS
38 NTAPI
39 CsrpConnectToServer(IN PWSTR ObjectDirectory)
40 {
41 NTSTATUS Status;
42 SIZE_T PortNameLength;
43 UNICODE_STRING PortName;
44 LARGE_INTEGER CsrSectionViewSize;
45 HANDLE CsrSectionHandle;
46 PORT_VIEW LpcWrite;
47 REMOTE_PORT_VIEW LpcRead;
48 SECURITY_QUALITY_OF_SERVICE SecurityQos;
49 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
50 PSID SystemSid = NULL;
51 CSR_API_CONNECTINFO ConnectionInfo;
52 ULONG ConnectionInfoLength = sizeof(CSR_API_CONNECTINFO);
53
54 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
55
56 /* Binary compatibility with MS KERNEL32 */
57 if (NULL == ObjectDirectory)
58 {
59 ObjectDirectory = L"\\Windows";
60 }
61
62 /* Calculate the total port name size */
63 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
64 sizeof(CSR_PORT_NAME);
65 if (PortNameLength > UNICODE_STRING_MAX_BYTES)
66 {
67 DPRINT1("PortNameLength too big: %Iu", PortNameLength);
68 return STATUS_NAME_TOO_LONG;
69 }
70
71 /* Set the port name */
72 PortName.Length = 0;
73 PortName.MaximumLength = (USHORT)PortNameLength;
74
75 /* Allocate a buffer for it */
76 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
77 if (PortName.Buffer == NULL)
78 {
79 return STATUS_INSUFFICIENT_RESOURCES;
80 }
81
82 /* Create the name */
83 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
84 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
85 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
86
87 /* Create a section for the port memory */
88 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
89 Status = NtCreateSection(&CsrSectionHandle,
90 SECTION_ALL_ACCESS,
91 NULL,
92 &CsrSectionViewSize,
93 PAGE_READWRITE,
94 SEC_RESERVE,
95 NULL);
96 if (!NT_SUCCESS(Status))
97 {
98 DPRINT1("Failure allocating CSR Section\n");
99 return Status;
100 }
101
102 /* Set up the port view structures to match them with the section */
103 LpcWrite.Length = sizeof(PORT_VIEW);
104 LpcWrite.SectionHandle = CsrSectionHandle;
105 LpcWrite.SectionOffset = 0;
106 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
107 LpcWrite.ViewBase = 0;
108 LpcWrite.ViewRemoteBase = 0;
109 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
110 LpcRead.ViewSize = 0;
111 LpcRead.ViewBase = 0;
112
113 /* Setup the QoS */
114 SecurityQos.ImpersonationLevel = SecurityImpersonation;
115 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
116 SecurityQos.EffectiveOnly = TRUE;
117
118 /* Setup the connection info */
119 ConnectionInfo.DebugFlags = 0;
120
121 /* Create a SID for us */
122 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
123 1,
124 SECURITY_LOCAL_SYSTEM_RID,
125 0,
126 0,
127 0,
128 0,
129 0,
130 0,
131 0,
132 &SystemSid);
133 if (!NT_SUCCESS(Status))
134 {
135 /* Failure */
136 DPRINT1("Couldn't allocate SID\n");
137 NtClose(CsrSectionHandle);
138 return Status;
139 }
140
141 /* Connect to the port */
142 Status = NtSecureConnectPort(&CsrApiPort,
143 &PortName,
144 &SecurityQos,
145 &LpcWrite,
146 SystemSid,
147 &LpcRead,
148 NULL,
149 &ConnectionInfo,
150 &ConnectionInfoLength);
151 RtlFreeSid(SystemSid);
152 NtClose(CsrSectionHandle);
153 if (!NT_SUCCESS(Status))
154 {
155 /* Failure */
156 DPRINT1("Couldn't connect to CSR port\n");
157 return Status;
158 }
159
160 /* Save the delta between the sections, for capture usage later */
161 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
162 (ULONG_PTR)LpcWrite.ViewBase;
163
164 /* Save the Process */
165 CsrProcessId = ConnectionInfo.ServerProcessId;
166
167 /* Save CSR Section data */
168 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
169 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
170 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
171
172 /* Create the port heap */
173 CsrPortHeap = RtlCreateHeap(0,
174 LpcWrite.ViewBase,
175 LpcWrite.ViewSize,
176 PAGE_SIZE,
177 0,
178 0);
179 if (CsrPortHeap == NULL)
180 {
181 /* Failure */
182 DPRINT1("Couldn't create heap for CSR port\n");
183 NtClose(CsrApiPort);
184 CsrApiPort = NULL;
185 return STATUS_INSUFFICIENT_RESOURCES;
186 }
187
188 /* Return success */
189 return STATUS_SUCCESS;
190 }
191
192 /*
193 * @implemented
194 */
195 NTSTATUS
196 NTAPI
197 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
198 IN ULONG ServerId,
199 IN PVOID ConnectionInfo,
200 IN OUT PULONG ConnectionInfoSize,
201 OUT PBOOLEAN ServerToServerCall)
202 {
203 NTSTATUS Status;
204 PIMAGE_NT_HEADERS NtHeader;
205 UNICODE_STRING CsrSrvName;
206 HANDLE hCsrSrv;
207 ANSI_STRING CsrServerRoutineName;
208 CSR_API_MESSAGE ApiMessage;
209 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
210 PCSR_CAPTURE_BUFFER CaptureBuffer;
211
212 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
213
214 /* Validate the Connection Info */
215 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
216 {
217 DPRINT1("Connection info given, but no length\n");
218 return STATUS_INVALID_PARAMETER;
219 }
220
221 /* Check if we're inside a CSR Process */
222 if (InsideCsrProcess)
223 {
224 /* Tell the client that we're already inside CSR */
225 if (ServerToServerCall) *ServerToServerCall = TRUE;
226 return STATUS_SUCCESS;
227 }
228
229 /*
230 * We might be in a CSR Process but not know it, if this is the first call.
231 * So let's find out.
232 */
233 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
234 {
235 /* The image isn't valid */
236 DPRINT1("Invalid image\n");
237 return STATUS_INVALID_IMAGE_FORMAT;
238 }
239 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
240
241 /* Now we can check if we are inside or not */
242 if (InsideCsrProcess)
243 {
244 /* We're inside, so let's find csrsrv */
245 DPRINT("Next-GEN CSRSS support\n");
246 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
247 Status = LdrGetDllHandle(NULL,
248 NULL,
249 &CsrSrvName,
250 &hCsrSrv);
251
252 /* Now get the Server to Server routine */
253 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
254 Status = LdrGetProcedureAddress(hCsrSrv,
255 &CsrServerRoutineName,
256 0L,
257 (PVOID*)&CsrServerApiRoutine);
258
259 /* Use the local heap as port heap */
260 CsrPortHeap = RtlGetProcessHeap();
261
262 /* Tell the caller we're inside the server */
263 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
264 return STATUS_SUCCESS;
265 }
266
267 /* Now check if connection info is given */
268 if (ConnectionInfo)
269 {
270 /* Well, we're definitely in a client now */
271 InsideCsrProcess = FALSE;
272
273 /* Do we have a connection to CSR yet? */
274 if (!CsrApiPort)
275 {
276 /* No, set it up now */
277 Status = CsrpConnectToServer(ObjectDirectory);
278 if (!NT_SUCCESS(Status))
279 {
280 /* Failed */
281 DPRINT1("Failure to connect to CSR\n");
282 return Status;
283 }
284 }
285
286 /* Setup the connect message header */
287 ClientConnect->ServerId = ServerId;
288 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
289
290 /* Setup a buffer for the connection info */
291 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
292 if (CaptureBuffer == NULL)
293 {
294 return STATUS_INSUFFICIENT_RESOURCES;
295 }
296
297 /* Capture the connection info data */
298 CsrCaptureMessageBuffer(CaptureBuffer,
299 ConnectionInfo,
300 ClientConnect->ConnectionInfoSize,
301 &ClientConnect->ConnectionInfo);
302
303 /* Return the allocated length */
304 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
305
306 /* Call CSR */
307 Status = CsrClientCallServer(&ApiMessage,
308 CaptureBuffer,
309 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
310 sizeof(CSR_CLIENT_CONNECT));
311
312 /* Copy the updated connection info data back into the user buffer */
313 RtlMoveMemory(ConnectionInfo,
314 ClientConnect->ConnectionInfo,
315 *ConnectionInfoSize);
316
317 /* Free the capture buffer */
318 CsrFreeCaptureBuffer(CaptureBuffer);
319 }
320 else
321 {
322 /* No connection info, just return */
323 Status = STATUS_SUCCESS;
324 }
325
326 /* Let the caller know if this was server to server */
327 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
328 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
329
330 return Status;
331 }
332
#if 0
/*
 * Disabled compile-time illustration: padding at the end of a structure means
 * that sizeof(structure) minus the size of its last field need not equal the
 * offset of that last field. It apparently serves as background for the length
 * computation in CsrClientCallServer, which uses
 * sizeof(CSR_API_MESSAGE) - sizeof(Data) rather than
 * FIELD_OFFSET(CSR_API_MESSAGE, Data).
 *
 * NOTE(review): the 0x18 / 0xC sizes asserted below presumably describe one
 * particular (32-bit) build; confirm before enabling this block elsewhere.
 */
typedef struct _TEST_EMBEDDED
{
    ULONG One;
    ULONG Two;
    ULONG Three;
} TEST_EMBEDDED;

typedef struct _TEST
{
    PORT_MESSAGE h;
    TEST_EMBEDDED Three;
} TEST;

/* Sizes and offsets the illustration is built on */
C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);
C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);

/* Trailing padding makes both naive computations disagree with the layout */
C_ASSERT(sizeof(TEST) != (sizeof(PORT_MESSAGE) + sizeof(TEST_EMBEDDED)));
C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
#endif
359
360 /*
361 * @implemented
362 */
363 NTSTATUS
364 NTAPI
365 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
366 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
367 IN CSR_API_NUMBER ApiNumber,
368 IN ULONG DataLength)
369 {
370 NTSTATUS Status;
371 ULONG PointerCount;
372 PULONG_PTR OffsetPointer;
373
374 /* Make sure the length is valid */
375 if (DataLength > (MAXSHORT - sizeof(CSR_API_MESSAGE)))
376 {
377 DPRINT1("DataLength too big: %lu", DataLength);
378 return STATUS_INVALID_PARAMETER;
379 }
380
381 /* Fill out the Port Message Header */
382 ApiMessage->Header.u2.ZeroInit = 0;
383 ApiMessage->Header.u1.s1.TotalLength = (CSHORT)DataLength +
384 sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data); // FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
385 ApiMessage->Header.u1.s1.DataLength = (CSHORT)DataLength +
386 FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header); // ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
387
388 /* Fill out the CSR Header */
389 ApiMessage->ApiNumber = ApiNumber;
390 ApiMessage->CsrCaptureData = NULL;
391
392 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
393 ApiNumber,
394 ApiMessage->Header.u1.s1.DataLength,
395 ApiMessage->Header.u1.s1.TotalLength);
396
397 /* Check if we are already inside a CSR Server */
398 if (!InsideCsrProcess)
399 {
400 /* Check if we got a Capture Buffer */
401 if (CaptureBuffer)
402 {
403 /*
404 * We have to convert from our local (client) view
405 * to the remote (server) view.
406 */
407 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
408 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
409
410 /* Lock the buffer. */
411 CaptureBuffer->BufferEnd = NULL;
412
413 /*
414 * Each client pointer inside the CSR message is converted into
415 * a server pointer, and each pointer to these message pointers
416 * is converted into an offset.
417 */
418 PointerCount = CaptureBuffer->PointerCount;
419 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
420 while (PointerCount--)
421 {
422 if (*OffsetPointer != 0)
423 {
424 *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
425 *OffsetPointer -= (ULONG_PTR)ApiMessage;
426 }
427 ++OffsetPointer;
428 }
429 }
430
431 /* Send the LPC Message */
432 Status = NtRequestWaitReplyPort(CsrApiPort,
433 &ApiMessage->Header,
434 &ApiMessage->Header);
435
436 /* Check if we got a Capture Buffer */
437 if (CaptureBuffer)
438 {
439 /*
440 * We have to convert back from the remote (server) view
441 * to our local (client) view.
442 */
443 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
444 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
445
446 /*
447 * Convert back the offsets into pointers to CSR message
448 * pointers, and convert back these message server pointers
449 * into client pointers.
450 */
451 PointerCount = CaptureBuffer->PointerCount;
452 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
453 while (PointerCount--)
454 {
455 if (*OffsetPointer != 0)
456 {
457 *OffsetPointer += (ULONG_PTR)ApiMessage;
458 *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
459 }
460 ++OffsetPointer;
461 }
462 }
463
464 /* Check for success */
465 if (!NT_SUCCESS(Status))
466 {
467 /* We failed. Overwrite the return value with the failure. */
468 DPRINT1("LPC Failed: %lx\n", Status);
469 ApiMessage->Status = Status;
470 }
471 }
472 else
473 {
474 /* This is a server-to-server call. Save our CID and do a direct call. */
475 DPRINT("Next gen server-to-server call\n");
476
477 /* We check this equality inside CsrValidateMessageBuffer */
478 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
479
480 Status = CsrServerApiRoutine(&ApiMessage->Header,
481 &ApiMessage->Header);
482
483 /* Check for success */
484 if (!NT_SUCCESS(Status))
485 {
486 /* We failed. Overwrite the return value with the failure. */
487 ApiMessage->Status = Status;
488 }
489 }
490
491 /* Return the CSR Result */
492 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
493 return ApiMessage->Status;
494 }
495
496 /*
497 * @implemented
498 */
499 HANDLE
500 NTAPI
501 CsrGetProcessId(VOID)
502 {
503 return CsrProcessId;
504 }
505
506 /* EOF */