Create a branch for console restructuration work.
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12
13 #include <ndk/lpcfuncs.h>
14 #include <csr/csrsrv.h>
15
16 #define NDEBUG
17 #include <debug.h>
18
19 /* GLOBALS ********************************************************************/
20
/* LPC port connected to the CSR API port (set by CsrpConnectToServer) */
HANDLE CsrApiPort;
/* Server process ID, taken from the connection info returned by the server */
HANDLE CsrProcessId;
/*
 * Heap for capture buffers: created over the local view of the shared port
 * section for clients, or the process heap when running inside CSR.
 */
HANDLE CsrPortHeap;
/* Server view base minus local view base of the port section; used to translate capture-buffer pointers */
ULONG_PTR CsrPortMemoryDelta;
/*
 * TRUE when this process is treated as the CSR server itself (its image is a
 * native-subsystem image); API calls then bypass LPC and go straight to
 * CsrServerApiRoutine.
 */
BOOLEAN InsideCsrProcess = FALSE;

/* Signature of the in-process server dispatcher exported by csrsrv */
typedef NTSTATUS
(NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
                                 IN PPORT_MESSAGE Reply);

/* Resolved from csrsrv ("CsrCallServerFromServer") when running inside CSR */
PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;

/* Separator placed between the object directory and the CSR port name */
#define UNICODE_PATH_SEP L"\\"
34
35 /* FUNCTIONS ******************************************************************/
36
37 NTSTATUS
38 NTAPI
39 CsrpConnectToServer(IN PWSTR ObjectDirectory)
40 {
41 NTSTATUS Status;
42 ULONG PortNameLength;
43 UNICODE_STRING PortName;
44 LARGE_INTEGER CsrSectionViewSize;
45 HANDLE CsrSectionHandle;
46 PORT_VIEW LpcWrite;
47 REMOTE_PORT_VIEW LpcRead;
48 SECURITY_QUALITY_OF_SERVICE SecurityQos;
49 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
50 PSID SystemSid = NULL;
51 CSR_API_CONNECTINFO ConnectionInfo;
52 ULONG ConnectionInfoLength = sizeof(CSR_API_CONNECTINFO);
53
54 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
55
56 /* Binary compatibility with MS KERNEL32 */
57 if (NULL == ObjectDirectory)
58 {
59 ObjectDirectory = L"\\Windows";
60 }
61
62 /* Calculate the total port name size */
63 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
64 sizeof(CSR_PORT_NAME);
65
66 /* Set the port name */
67 PortName.Length = 0;
68 PortName.MaximumLength = PortNameLength;
69
70 /* Allocate a buffer for it */
71 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
72 if (PortName.Buffer == NULL)
73 {
74 return STATUS_INSUFFICIENT_RESOURCES;
75 }
76
77 /* Create the name */
78 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
79 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
80 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
81
82 /* Create a section for the port memory */
83 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
84 Status = NtCreateSection(&CsrSectionHandle,
85 SECTION_ALL_ACCESS,
86 NULL,
87 &CsrSectionViewSize,
88 PAGE_READWRITE,
89 SEC_RESERVE,
90 NULL);
91 if (!NT_SUCCESS(Status))
92 {
93 DPRINT1("Failure allocating CSR Section\n");
94 return Status;
95 }
96
97 /* Set up the port view structures to match them with the section */
98 LpcWrite.Length = sizeof(PORT_VIEW);
99 LpcWrite.SectionHandle = CsrSectionHandle;
100 LpcWrite.SectionOffset = 0;
101 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
102 LpcWrite.ViewBase = 0;
103 LpcWrite.ViewRemoteBase = 0;
104 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
105 LpcRead.ViewSize = 0;
106 LpcRead.ViewBase = 0;
107
108 /* Setup the QoS */
109 SecurityQos.ImpersonationLevel = SecurityImpersonation;
110 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
111 SecurityQos.EffectiveOnly = TRUE;
112
113 /* Setup the connection info */
114 ConnectionInfo.DebugFlags = 0;
115
116 /* Create a SID for us */
117 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
118 1,
119 SECURITY_LOCAL_SYSTEM_RID,
120 0,
121 0,
122 0,
123 0,
124 0,
125 0,
126 0,
127 &SystemSid);
128 if (!NT_SUCCESS(Status))
129 {
130 /* Failure */
131 DPRINT1("Couldn't allocate SID\n");
132 NtClose(CsrSectionHandle);
133 return Status;
134 }
135
136 /* Connect to the port */
137 Status = NtSecureConnectPort(&CsrApiPort,
138 &PortName,
139 &SecurityQos,
140 &LpcWrite,
141 SystemSid,
142 &LpcRead,
143 NULL,
144 &ConnectionInfo,
145 &ConnectionInfoLength);
146 RtlFreeSid(SystemSid);
147 NtClose(CsrSectionHandle);
148 if (!NT_SUCCESS(Status))
149 {
150 /* Failure */
151 DPRINT1("Couldn't connect to CSR port\n");
152 return Status;
153 }
154
155 /* Save the delta between the sections, for capture usage later */
156 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
157 (ULONG_PTR)LpcWrite.ViewBase;
158
159 /* Save the Process */
160 CsrProcessId = ConnectionInfo.ServerProcessId;
161
162 /* Save CSR Section data */
163 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
164 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
165 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedStaticServerData;
166
167 /* Create the port heap */
168 CsrPortHeap = RtlCreateHeap(0,
169 LpcWrite.ViewBase,
170 LpcWrite.ViewSize,
171 PAGE_SIZE,
172 0,
173 0);
174 if (CsrPortHeap == NULL)
175 {
176 /* Failure */
177 DPRINT1("Couldn't create heap for CSR port\n");
178 NtClose(CsrApiPort);
179 CsrApiPort = NULL;
180 return STATUS_INSUFFICIENT_RESOURCES;
181 }
182
183 /* Return success */
184 return STATUS_SUCCESS;
185 }
186
187 /*
188 * @implemented
189 */
190 NTSTATUS
191 NTAPI
192 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
193 IN ULONG ServerId,
194 IN PVOID ConnectionInfo,
195 IN OUT PULONG ConnectionInfoSize,
196 OUT PBOOLEAN ServerToServerCall)
197 {
198 NTSTATUS Status;
199 PIMAGE_NT_HEADERS NtHeader;
200 UNICODE_STRING CsrSrvName;
201 HANDLE hCsrSrv;
202 ANSI_STRING CsrServerRoutineName;
203 CSR_API_MESSAGE ApiMessage;
204 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
205 PCSR_CAPTURE_BUFFER CaptureBuffer;
206
207 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
208
209 /* Validate the Connection Info */
210 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
211 {
212 DPRINT1("Connection info given, but no length\n");
213 return STATUS_INVALID_PARAMETER;
214 }
215
216 /* Check if we're inside a CSR Process */
217 if (InsideCsrProcess)
218 {
219 /* Tell the client that we're already inside CSR */
220 if (ServerToServerCall) *ServerToServerCall = TRUE;
221 return STATUS_SUCCESS;
222 }
223
224 /*
225 * We might be in a CSR Process but not know it, if this is the first call.
226 * So let's find out.
227 */
228 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
229 {
230 /* The image isn't valid */
231 DPRINT1("Invalid image\n");
232 return STATUS_INVALID_IMAGE_FORMAT;
233 }
234 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
235
236 /* Now we can check if we are inside or not */
237 if (InsideCsrProcess)
238 {
239 /* We're inside, so let's find csrsrv */
240 DPRINT("Next-GEN CSRSS support\n");
241 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
242 Status = LdrGetDllHandle(NULL,
243 NULL,
244 &CsrSrvName,
245 &hCsrSrv);
246
247 /* Now get the Server to Server routine */
248 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
249 Status = LdrGetProcedureAddress(hCsrSrv,
250 &CsrServerRoutineName,
251 0L,
252 (PVOID*)&CsrServerApiRoutine);
253
254 /* Use the local heap as port heap */
255 CsrPortHeap = RtlGetProcessHeap();
256
257 /* Tell the caller we're inside the server */
258 *ServerToServerCall = InsideCsrProcess;
259 return STATUS_SUCCESS;
260 }
261
262 /* Now check if connection info is given */
263 if (ConnectionInfo)
264 {
265 /* Well, we're defintely in a client now */
266 InsideCsrProcess = FALSE;
267
268 /* Do we have a connection to CSR yet? */
269 if (!CsrApiPort)
270 {
271 /* No, set it up now */
272 Status = CsrpConnectToServer(ObjectDirectory);
273 if (!NT_SUCCESS(Status))
274 {
275 /* Failed */
276 DPRINT1("Failure to connect to CSR\n");
277 return Status;
278 }
279 }
280
281 /* Setup the connect message header */
282 ClientConnect->ServerId = ServerId;
283 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
284
285 /* Setup a buffer for the connection info */
286 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
287 if (CaptureBuffer == NULL)
288 {
289 return STATUS_INSUFFICIENT_RESOURCES;
290 }
291
292 /* Capture the connection info data */
293 CsrCaptureMessageBuffer(CaptureBuffer,
294 ConnectionInfo,
295 ClientConnect->ConnectionInfoSize,
296 &ClientConnect->ConnectionInfo);
297
298 /* Return the allocated length */
299 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
300
301 /* Call CSR */
302 Status = CsrClientCallServer(&ApiMessage,
303 CaptureBuffer,
304 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
305 sizeof(CSR_CLIENT_CONNECT));
306
307 /* Copy the updated connection info data back into the user buffer */
308 RtlMoveMemory(ConnectionInfo,
309 ClientConnect->ConnectionInfo,
310 *ConnectionInfoSize);
311
312 /* Free the capture buffer */
313 CsrFreeCaptureBuffer(CaptureBuffer);
314 }
315 else
316 {
317 /* No connection info, just return */
318 Status = STATUS_SUCCESS;
319 }
320
321 /* Let the caller know if this was server to server */
322 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
323 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
324
325 return Status;
326 }
327
#if 0
//
// A structure may end with padding. So the size of the whole structure minus
// the size of its last field does not have to equal the offset of that last
// field. The disabled assertions below illustrate this with a PORT_MESSAGE
// header followed by a 12-byte embedded structure.
//
// NOTE(review): the hard-coded sizes (e.g. sizeof(PORT_MESSAGE) == 0x18) look
// specific to one target architecture; presumably this is the background for
// the TotalLength computation in CsrClientCallServer. Confirm before enabling.
//
typedef struct _TEST_EMBEDDED
{
    ULONG One;
    ULONG Two;
    ULONG Three;
} TEST_EMBEDDED;

typedef struct _TEST
{
    PORT_MESSAGE h;
    TEST_EMBEDDED Three;
} TEST;

C_ASSERT(sizeof(PORT_MESSAGE) == 0x18);
C_ASSERT(FIELD_OFFSET(TEST, Three) == 0x18);
C_ASSERT(sizeof(TEST_EMBEDDED) == 0xC);

C_ASSERT(sizeof(TEST) != (sizeof(TEST_EMBEDDED) + sizeof(PORT_MESSAGE)));
C_ASSERT((sizeof(TEST) - sizeof(TEST_EMBEDDED)) != FIELD_OFFSET(TEST, Three));
#endif
354
355 /*
356 * @implemented
357 */
358 NTSTATUS
359 NTAPI
360 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
361 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
362 IN CSR_API_NUMBER ApiNumber,
363 IN ULONG DataLength)
364 {
365 NTSTATUS Status;
366 ULONG PointerCount;
367 PULONG_PTR OffsetPointer;
368
369 /* Fill out the Port Message Header */
370 ApiMessage->Header.u2.ZeroInit = 0;
371 ApiMessage->Header.u1.s1.TotalLength = DataLength +
372 sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data); // FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
373 ApiMessage->Header.u1.s1.DataLength = DataLength +
374 FIELD_OFFSET(CSR_API_MESSAGE, Data) - sizeof(ApiMessage->Header);// ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
375
376 /* Fill out the CSR Header */
377 ApiMessage->ApiNumber = ApiNumber;
378 ApiMessage->CsrCaptureData = NULL;
379
380 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
381 ApiNumber,
382 ApiMessage->Header.u1.s1.DataLength,
383 ApiMessage->Header.u1.s1.TotalLength);
384
385 /* Check if we are already inside a CSR Server */
386 if (!InsideCsrProcess)
387 {
388 /* Check if we got a Capture Buffer */
389 if (CaptureBuffer)
390 {
391 /*
392 * We have to convert from our local (client) view
393 * to the remote (server) view.
394 */
395 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
396 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
397
398 /* Lock the buffer. */
399 CaptureBuffer->BufferEnd = NULL;
400
401 /*
402 * Each client pointer inside the CSR message is converted into
403 * a server pointer, and each pointer to these message pointers
404 * is converted into an offset.
405 */
406 PointerCount = CaptureBuffer->PointerCount;
407 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
408 while (PointerCount--)
409 {
410 if (*OffsetPointer != 0)
411 {
412 *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
413 *OffsetPointer -= (ULONG_PTR)ApiMessage;
414 }
415 ++OffsetPointer;
416 }
417 }
418
419 /* Send the LPC Message */
420 Status = NtRequestWaitReplyPort(CsrApiPort,
421 &ApiMessage->Header,
422 &ApiMessage->Header);
423
424 /* Check if we got a Capture Buffer */
425 if (CaptureBuffer)
426 {
427 /*
428 * We have to convert back from the remote (server) view
429 * to our local (client) view.
430 */
431 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
432 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
433
434 /*
435 * Convert back the offsets into pointers to CSR message
436 * pointers, and convert back these message server pointers
437 * into client pointers.
438 */
439 PointerCount = CaptureBuffer->PointerCount;
440 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
441 while (PointerCount--)
442 {
443 if (*OffsetPointer != 0)
444 {
445 *OffsetPointer += (ULONG_PTR)ApiMessage;
446 *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
447 }
448 ++OffsetPointer;
449 }
450 }
451
452 /* Check for success */
453 if (!NT_SUCCESS(Status))
454 {
455 /* We failed. Overwrite the return value with the failure. */
456 DPRINT1("LPC Failed: %lx\n", Status);
457 ApiMessage->Status = Status;
458 }
459 }
460 else
461 {
462 /* This is a server-to-server call. Save our CID and do a direct call. */
463 DPRINT("Next gen server-to-server call\n");
464
465 /* We check this equality inside CsrValidateMessageBuffer */
466 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
467
468 Status = CsrServerApiRoutine(&ApiMessage->Header,
469 &ApiMessage->Header);
470
471 /* Check for success */
472 if (!NT_SUCCESS(Status))
473 {
474 /* We failed. Overwrite the return value with the failure. */
475 ApiMessage->Status = Status;
476 }
477 }
478
479 /* Return the CSR Result */
480 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
481 return ApiMessage->Status;
482 }
483
484 /*
485 * @implemented
486 */
487 HANDLE
488 NTAPI
489 CsrGetProcessId(VOID)
490 {
491 return CsrProcessId;
492 }
493
494 /* EOF */