8a12194ce0039bb3e9a613a12c4eef6ba5884dc9
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS ********************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22
23 typedef NTSTATUS
24 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
25 IN PPORT_MESSAGE Reply);
26
27 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
28
29 #define UNICODE_PATH_SEP L"\\"
30
31 /* FUNCTIONS ******************************************************************/
32
33 /*
34 * @implemented
35 */
36 HANDLE
37 NTAPI
38 CsrGetProcessId(VOID)
39 {
40 return CsrProcessId;
41 }
42
43 /*
44 * @implemented
45 */
46 NTSTATUS
47 NTAPI
48 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
49 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
50 IN CSR_API_NUMBER ApiNumber,
51 IN ULONG DataLength)
52 {
53 NTSTATUS Status;
54 ULONG i;
55
56 /* Fill out the Port Message Header. */
57 ApiMessage->Header.u2.ZeroInit = 0;
58 ApiMessage->Header.u1.s1.TotalLength =
59 FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
60 ApiMessage->Header.u1.s1.DataLength =
61 ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
62
63 /* Fill out the CSR Header. */
64 ApiMessage->ApiNumber = ApiNumber;
65 ApiMessage->CsrCaptureData = NULL;
66
67 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
68 ApiNumber,
69 ApiMessage->Header.u1.s1.DataLength,
70 ApiMessage->Header.u1.s1.TotalLength);
71
72 /* Check if we are already inside a CSR Server. */
73 if (!InsideCsrProcess)
74 {
75 /* Check if we got a Capture Buffer. */
76 if (CaptureBuffer)
77 {
78 /*
79 * We have to convert from our local (client) view
80 * to the remote (server) view.
81 */
82 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
83 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
84
85 /* Lock the buffer. */
86 CaptureBuffer->BufferEnd = NULL;
87
88 /*
89 * Each client pointer inside the CSR message is converted into
90 * a server pointer, and each pointer to these message pointers
91 * is converted into an offset.
92 */
93 for (i = 0 ; i < CaptureBuffer->PointerCount ; ++i)
94 {
95 if (CaptureBuffer->PointerOffsetsArray[i] != 0)
96 {
97 *(PULONG_PTR)CaptureBuffer->PointerOffsetsArray[i] += CsrPortMemoryDelta;
98 CaptureBuffer->PointerOffsetsArray[i] -= (ULONG_PTR)ApiMessage;
99 }
100 }
101 }
102
103 /* Send the LPC Message. */
104 Status = NtRequestWaitReplyPort(CsrApiPort,
105 &ApiMessage->Header,
106 &ApiMessage->Header);
107
108 /* Check if we got a Capture Buffer. */
109 if (CaptureBuffer)
110 {
111 /*
112 * We have to convert back from the remote (server) view
113 * to our local (client) view.
114 */
115 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
116 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
117
118 /*
119 * Convert back the offsets into pointers to CSR message
120 * pointers, and convert back these message server pointers
121 * into client pointers.
122 */
123 for (i = 0 ; i < CaptureBuffer->PointerCount ; ++i)
124 {
125 if (CaptureBuffer->PointerOffsetsArray[i] != 0)
126 {
127 CaptureBuffer->PointerOffsetsArray[i] += (ULONG_PTR)ApiMessage;
128 *(PULONG_PTR)CaptureBuffer->PointerOffsetsArray[i] -= CsrPortMemoryDelta;
129 }
130 }
131 }
132
133 /* Check for success. */
134 if (!NT_SUCCESS(Status))
135 {
136 /* We failed. Overwrite the return value with the failure. */
137 DPRINT1("LPC Failed: %lx\n", Status);
138 ApiMessage->Status = Status;
139 }
140 }
141 else
142 {
143 /* This is a server-to-server call. Save our CID and do a direct call. */
144 DPRINT1("Next gen server-to-server call\n");
145
146 /* We check this equality inside CsrValidateMessageBuffer. */
147 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
148
149 Status = CsrServerApiRoutine(&ApiMessage->Header,
150 &ApiMessage->Header);
151
152 /* Check for success. */
153 if (!NT_SUCCESS(Status))
154 {
155 /* We failed. Overwrite the return value with the failure. */
156 ApiMessage->Status = Status;
157 }
158 }
159
160 /* Return the CSR Result. */
161 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
162 return ApiMessage->Status;
163 }
164
165 NTSTATUS
166 NTAPI
167 CsrpConnectToServer(IN PWSTR ObjectDirectory)
168 {
169 ULONG PortNameLength;
170 UNICODE_STRING PortName;
171 LARGE_INTEGER CsrSectionViewSize;
172 NTSTATUS Status;
173 HANDLE CsrSectionHandle;
174 PORT_VIEW LpcWrite;
175 REMOTE_PORT_VIEW LpcRead;
176 SECURITY_QUALITY_OF_SERVICE SecurityQos;
177 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
178 PSID SystemSid = NULL;
179 CSR_CONNECTION_INFO ConnectionInfo;
180 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
181
182 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
183
184 /* Binary compatibility with MS KERNEL32 */
185 if (NULL == ObjectDirectory)
186 {
187 ObjectDirectory = L"\\Windows";
188 }
189
190 /* Calculate the total port name size */
191 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
192 sizeof(CSR_PORT_NAME);
193
194 /* Set the port name */
195 PortName.Length = 0;
196 PortName.MaximumLength = PortNameLength;
197
198 /* Allocate a buffer for it */
199 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
200 if (PortName.Buffer == NULL)
201 {
202 return STATUS_INSUFFICIENT_RESOURCES;
203 }
204
205 /* Create the name */
206 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
207 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
208 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
209
210 /* Create a section for the port memory */
211 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
212 Status = NtCreateSection(&CsrSectionHandle,
213 SECTION_ALL_ACCESS,
214 NULL,
215 &CsrSectionViewSize,
216 PAGE_READWRITE,
217 SEC_COMMIT,
218 NULL);
219 if (!NT_SUCCESS(Status))
220 {
221 DPRINT1("Failure allocating CSR Section\n");
222 return Status;
223 }
224
225 /* Set up the port view structures to match them with the section */
226 LpcWrite.Length = sizeof(PORT_VIEW);
227 LpcWrite.SectionHandle = CsrSectionHandle;
228 LpcWrite.SectionOffset = 0;
229 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
230 LpcWrite.ViewBase = 0;
231 LpcWrite.ViewRemoteBase = 0;
232 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
233 LpcRead.ViewSize = 0;
234 LpcRead.ViewBase = 0;
235
236 /* Setup the QoS */
237 SecurityQos.ImpersonationLevel = SecurityImpersonation;
238 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
239 SecurityQos.EffectiveOnly = TRUE;
240
241 /* Setup the connection info */
242 ConnectionInfo.Version = 0x10000;
243
244 /* Create a SID for us */
245 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
246 1,
247 SECURITY_LOCAL_SYSTEM_RID,
248 0,
249 0,
250 0,
251 0,
252 0,
253 0,
254 0,
255 &SystemSid);
256 if (!NT_SUCCESS(Status))
257 {
258 /* Failure */
259 DPRINT1("Couldn't allocate SID\n");
260 NtClose(CsrSectionHandle);
261 return Status;
262 }
263
264 /* Connect to the port */
265 Status = NtSecureConnectPort(&CsrApiPort,
266 &PortName,
267 &SecurityQos,
268 &LpcWrite,
269 SystemSid,
270 &LpcRead,
271 NULL,
272 &ConnectionInfo,
273 &ConnectionInfoLength);
274 NtClose(CsrSectionHandle);
275 if (!NT_SUCCESS(Status))
276 {
277 /* Failure */
278 DPRINT1("Couldn't connect to CSR port\n");
279 return Status;
280 }
281
282 /* Save the delta between the sections, for capture usage later */
283 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
284 (ULONG_PTR)LpcWrite.ViewBase;
285
286 /* Save the Process */
287 CsrProcessId = ConnectionInfo.ProcessId;
288
289 /* Save CSR Section data */
290 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
291 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
292 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
293
294 /* Create the port heap */
295 CsrPortHeap = RtlCreateHeap(0,
296 LpcWrite.ViewBase,
297 LpcWrite.ViewSize,
298 PAGE_SIZE,
299 0,
300 0);
301 if (CsrPortHeap == NULL)
302 {
303 /* Failure */
304 DPRINT1("Couldn't create heap for CSR port\n");
305 NtClose(CsrApiPort);
306 CsrApiPort = NULL;
307 return STATUS_INSUFFICIENT_RESOURCES;
308 }
309
310 /* Return success */
311 return STATUS_SUCCESS;
312 }
313
314 /*
315 * @implemented
316 */
317 NTSTATUS
318 NTAPI
319 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
320 IN ULONG ServerId,
321 IN PVOID ConnectionInfo,
322 IN OUT PULONG ConnectionInfoSize,
323 OUT PBOOLEAN ServerToServerCall)
324 {
325 NTSTATUS Status;
326 PIMAGE_NT_HEADERS NtHeader;
327 UNICODE_STRING CsrSrvName;
328 HANDLE hCsrSrv;
329 ANSI_STRING CsrServerRoutineName;
330 CSR_API_MESSAGE ApiMessage;
331 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
332 PCSR_CAPTURE_BUFFER CaptureBuffer;
333
334 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
335
336 /* Validate the Connection Info */
337 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
338 {
339 DPRINT1("Connection info given, but no length\n");
340 return STATUS_INVALID_PARAMETER;
341 }
342
343 /* Check if we're inside a CSR Process */
344 if (InsideCsrProcess)
345 {
346 /* Tell the client that we're already inside CSR */
347 if (ServerToServerCall) *ServerToServerCall = TRUE;
348 return STATUS_SUCCESS;
349 }
350
351 /*
352 * We might be in a CSR Process but not know it, if this is the first call.
353 * So let's find out.
354 */
355 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
356 {
357 /* The image isn't valid */
358 DPRINT1("Invalid image\n");
359 return STATUS_INVALID_IMAGE_FORMAT;
360 }
361 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
362
363 /* Now we can check if we are inside or not */
364 if (InsideCsrProcess)
365 {
366 /* We're inside, so let's find csrsrv */
367 DPRINT1("Next-GEN CSRSS support\n");
368 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
369 Status = LdrGetDllHandle(NULL,
370 NULL,
371 &CsrSrvName,
372 &hCsrSrv);
373
374 /* Now get the Server to Server routine */
375 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
376 Status = LdrGetProcedureAddress(hCsrSrv,
377 &CsrServerRoutineName,
378 0L,
379 (PVOID*)&CsrServerApiRoutine);
380
381 /* Use the local heap as port heap */
382 CsrPortHeap = RtlGetProcessHeap();
383
384 /* Tell the caller we're inside the server */
385 *ServerToServerCall = InsideCsrProcess;
386 return STATUS_SUCCESS;
387 }
388
389 /* Now check if connection info is given */
390 if (ConnectionInfo)
391 {
392 /* Well, we're defintely in a client now */
393 InsideCsrProcess = FALSE;
394
395 /* Do we have a connection to CSR yet? */
396 if (!CsrApiPort)
397 {
398 /* No, set it up now */
399 if (!NT_SUCCESS(Status = CsrpConnectToServer(ObjectDirectory)))
400 {
401 /* Failed */
402 DPRINT1("Failure to connect to CSR\n");
403 return Status;
404 }
405 }
406
407 /* Setup the connect message header */
408 ClientConnect->ServerId = ServerId;
409 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
410
411 /* Setup a buffer for the connection info */
412 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
413 if (CaptureBuffer == NULL)
414 {
415 return STATUS_INSUFFICIENT_RESOURCES;
416 }
417
418 /* Capture the connection info data */
419 CsrCaptureMessageBuffer(CaptureBuffer,
420 ConnectionInfo,
421 ClientConnect->ConnectionInfoSize,
422 &ClientConnect->ConnectionInfo);
423
424 /* Return the allocated length */
425 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
426
427 /* Call CSR */
428 Status = CsrClientCallServer(&ApiMessage,
429 CaptureBuffer,
430 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
431 sizeof(CSR_CLIENT_CONNECT));
432
433 /* Copy the updated connection info data back into the user buffer */
434 RtlMoveMemory(ConnectionInfo,
435 ClientConnect->ConnectionInfo,
436 *ConnectionInfoSize);
437
438 /* Free the capture buffer */
439 CsrFreeCaptureBuffer(CaptureBuffer);
440 }
441 else
442 {
443 /* No connection info, just return */
444 Status = STATUS_SUCCESS;
445 }
446
447 /* Let the caller know if this was server to server */
448 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
449 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
450
451 return Status;
452 }
453
454 /* EOF */