[NTDLL/CSRSRV]
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS ********************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22
23 typedef NTSTATUS
24 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
25 IN PPORT_MESSAGE Reply);
26
27 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
28
29 #define UNICODE_PATH_SEP L"\\"
30
31 /* FUNCTIONS ******************************************************************/
32
33 /*
34 * @implemented
35 */
36 HANDLE
37 NTAPI
38 CsrGetProcessId(VOID)
39 {
40 return CsrProcessId;
41 }
42
43 /*
44 * @implemented
45 */
46 NTSTATUS
47 NTAPI
48 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
49 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
50 IN CSR_API_NUMBER ApiNumber,
51 IN ULONG DataLength)
52 {
53 NTSTATUS Status;
54 ULONG PointerCount;
55 PULONG_PTR OffsetPointer;
56
57 /* Fill out the Port Message Header. */
58 ApiMessage->Header.u2.ZeroInit = 0;
59 ApiMessage->Header.u1.s1.TotalLength =
60 FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
61 ApiMessage->Header.u1.s1.DataLength =
62 ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
63
64 /* Fill out the CSR Header. */
65 ApiMessage->ApiNumber = ApiNumber;
66 ApiMessage->CsrCaptureData = NULL;
67
68 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
69 ApiNumber,
70 ApiMessage->Header.u1.s1.DataLength,
71 ApiMessage->Header.u1.s1.TotalLength);
72
73 /* Check if we are already inside a CSR Server. */
74 if (!InsideCsrProcess)
75 {
76 /* Check if we got a Capture Buffer. */
77 if (CaptureBuffer)
78 {
79 /*
80 * We have to convert from our local (client) view
81 * to the remote (server) view.
82 */
83 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
84 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
85
86 /* Lock the buffer. */
87 CaptureBuffer->BufferEnd = NULL;
88
89 /*
90 * Each client pointer inside the CSR message is converted into
91 * a server pointer, and each pointer to these message pointers
92 * is converted into an offset.
93 */
94 PointerCount = CaptureBuffer->PointerCount;
95 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
96 while (PointerCount--)
97 {
98 if (*OffsetPointer != 0)
99 {
100 *(PULONG_PTR)*OffsetPointer += CsrPortMemoryDelta;
101 *OffsetPointer -= (ULONG_PTR)ApiMessage;
102 }
103 ++OffsetPointer;
104 }
105 }
106
107 /* Send the LPC Message. */
108 Status = NtRequestWaitReplyPort(CsrApiPort,
109 &ApiMessage->Header,
110 &ApiMessage->Header);
111
112 /* Check if we got a Capture Buffer. */
113 if (CaptureBuffer)
114 {
115 /*
116 * We have to convert back from the remote (server) view
117 * to our local (client) view.
118 */
119 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
120 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
121
122 /*
123 * Convert back the offsets into pointers to CSR message
124 * pointers, and convert back these message server pointers
125 * into client pointers.
126 */
127 PointerCount = CaptureBuffer->PointerCount;
128 OffsetPointer = CaptureBuffer->PointerOffsetsArray;
129 while (PointerCount--)
130 {
131 if (*OffsetPointer != 0)
132 {
133 *OffsetPointer += (ULONG_PTR)ApiMessage;
134 *(PULONG_PTR)*OffsetPointer -= CsrPortMemoryDelta;
135 }
136 ++OffsetPointer;
137 }
138 }
139
140 /* Check for success. */
141 if (!NT_SUCCESS(Status))
142 {
143 /* We failed. Overwrite the return value with the failure. */
144 DPRINT1("LPC Failed: %lx\n", Status);
145 ApiMessage->Status = Status;
146 }
147 }
148 else
149 {
150 /* This is a server-to-server call. Save our CID and do a direct call. */
151 DPRINT1("Next gen server-to-server call\n");
152
153 /* We check this equality inside CsrValidateMessageBuffer. */
154 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
155
156 Status = CsrServerApiRoutine(&ApiMessage->Header,
157 &ApiMessage->Header);
158
159 /* Check for success. */
160 if (!NT_SUCCESS(Status))
161 {
162 /* We failed. Overwrite the return value with the failure. */
163 ApiMessage->Status = Status;
164 }
165 }
166
167 /* Return the CSR Result. */
168 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
169 return ApiMessage->Status;
170 }
171
172 NTSTATUS
173 NTAPI
174 CsrpConnectToServer(IN PWSTR ObjectDirectory)
175 {
176 ULONG PortNameLength;
177 UNICODE_STRING PortName;
178 LARGE_INTEGER CsrSectionViewSize;
179 NTSTATUS Status;
180 HANDLE CsrSectionHandle;
181 PORT_VIEW LpcWrite;
182 REMOTE_PORT_VIEW LpcRead;
183 SECURITY_QUALITY_OF_SERVICE SecurityQos;
184 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
185 PSID SystemSid = NULL;
186 CSR_CONNECTION_INFO ConnectionInfo;
187 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
188
189 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
190
191 /* Binary compatibility with MS KERNEL32 */
192 if (NULL == ObjectDirectory)
193 {
194 ObjectDirectory = L"\\Windows";
195 }
196
197 /* Calculate the total port name size */
198 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
199 sizeof(CSR_PORT_NAME);
200
201 /* Set the port name */
202 PortName.Length = 0;
203 PortName.MaximumLength = PortNameLength;
204
205 /* Allocate a buffer for it */
206 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
207 if (PortName.Buffer == NULL)
208 {
209 return STATUS_INSUFFICIENT_RESOURCES;
210 }
211
212 /* Create the name */
213 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
214 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
215 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
216
217 /* Create a section for the port memory */
218 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
219 Status = NtCreateSection(&CsrSectionHandle,
220 SECTION_ALL_ACCESS,
221 NULL,
222 &CsrSectionViewSize,
223 PAGE_READWRITE,
224 SEC_COMMIT,
225 NULL);
226 if (!NT_SUCCESS(Status))
227 {
228 DPRINT1("Failure allocating CSR Section\n");
229 return Status;
230 }
231
232 /* Set up the port view structures to match them with the section */
233 LpcWrite.Length = sizeof(PORT_VIEW);
234 LpcWrite.SectionHandle = CsrSectionHandle;
235 LpcWrite.SectionOffset = 0;
236 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
237 LpcWrite.ViewBase = 0;
238 LpcWrite.ViewRemoteBase = 0;
239 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
240 LpcRead.ViewSize = 0;
241 LpcRead.ViewBase = 0;
242
243 /* Setup the QoS */
244 SecurityQos.ImpersonationLevel = SecurityImpersonation;
245 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
246 SecurityQos.EffectiveOnly = TRUE;
247
248 /* Setup the connection info */
249 ConnectionInfo.Version = 0x10000;
250
251 /* Create a SID for us */
252 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
253 1,
254 SECURITY_LOCAL_SYSTEM_RID,
255 0,
256 0,
257 0,
258 0,
259 0,
260 0,
261 0,
262 &SystemSid);
263 if (!NT_SUCCESS(Status))
264 {
265 /* Failure */
266 DPRINT1("Couldn't allocate SID\n");
267 NtClose(CsrSectionHandle);
268 return Status;
269 }
270
271 /* Connect to the port */
272 Status = NtSecureConnectPort(&CsrApiPort,
273 &PortName,
274 &SecurityQos,
275 &LpcWrite,
276 SystemSid,
277 &LpcRead,
278 NULL,
279 &ConnectionInfo,
280 &ConnectionInfoLength);
281 NtClose(CsrSectionHandle);
282 if (!NT_SUCCESS(Status))
283 {
284 /* Failure */
285 DPRINT1("Couldn't connect to CSR port\n");
286 return Status;
287 }
288
289 /* Save the delta between the sections, for capture usage later */
290 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
291 (ULONG_PTR)LpcWrite.ViewBase;
292
293 /* Save the Process */
294 CsrProcessId = ConnectionInfo.ProcessId;
295
296 /* Save CSR Section data */
297 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
298 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
299 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
300
301 /* Create the port heap */
302 CsrPortHeap = RtlCreateHeap(0,
303 LpcWrite.ViewBase,
304 LpcWrite.ViewSize,
305 PAGE_SIZE,
306 0,
307 0);
308 if (CsrPortHeap == NULL)
309 {
310 /* Failure */
311 DPRINT1("Couldn't create heap for CSR port\n");
312 NtClose(CsrApiPort);
313 CsrApiPort = NULL;
314 return STATUS_INSUFFICIENT_RESOURCES;
315 }
316
317 /* Return success */
318 return STATUS_SUCCESS;
319 }
320
321 /*
322 * @implemented
323 */
324 NTSTATUS
325 NTAPI
326 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
327 IN ULONG ServerId,
328 IN PVOID ConnectionInfo,
329 IN OUT PULONG ConnectionInfoSize,
330 OUT PBOOLEAN ServerToServerCall)
331 {
332 NTSTATUS Status;
333 PIMAGE_NT_HEADERS NtHeader;
334 UNICODE_STRING CsrSrvName;
335 HANDLE hCsrSrv;
336 ANSI_STRING CsrServerRoutineName;
337 CSR_API_MESSAGE ApiMessage;
338 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
339 PCSR_CAPTURE_BUFFER CaptureBuffer;
340
341 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
342
343 /* Validate the Connection Info */
344 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
345 {
346 DPRINT1("Connection info given, but no length\n");
347 return STATUS_INVALID_PARAMETER;
348 }
349
350 /* Check if we're inside a CSR Process */
351 if (InsideCsrProcess)
352 {
353 /* Tell the client that we're already inside CSR */
354 if (ServerToServerCall) *ServerToServerCall = TRUE;
355 return STATUS_SUCCESS;
356 }
357
358 /*
359 * We might be in a CSR Process but not know it, if this is the first call.
360 * So let's find out.
361 */
362 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
363 {
364 /* The image isn't valid */
365 DPRINT1("Invalid image\n");
366 return STATUS_INVALID_IMAGE_FORMAT;
367 }
368 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
369
370 /* Now we can check if we are inside or not */
371 if (InsideCsrProcess)
372 {
373 /* We're inside, so let's find csrsrv */
374 DPRINT1("Next-GEN CSRSS support\n");
375 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
376 Status = LdrGetDllHandle(NULL,
377 NULL,
378 &CsrSrvName,
379 &hCsrSrv);
380
381 /* Now get the Server to Server routine */
382 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
383 Status = LdrGetProcedureAddress(hCsrSrv,
384 &CsrServerRoutineName,
385 0L,
386 (PVOID*)&CsrServerApiRoutine);
387
388 /* Use the local heap as port heap */
389 CsrPortHeap = RtlGetProcessHeap();
390
391 /* Tell the caller we're inside the server */
392 *ServerToServerCall = InsideCsrProcess;
393 return STATUS_SUCCESS;
394 }
395
396 /* Now check if connection info is given */
397 if (ConnectionInfo)
398 {
399 /* Well, we're defintely in a client now */
400 InsideCsrProcess = FALSE;
401
402 /* Do we have a connection to CSR yet? */
403 if (!CsrApiPort)
404 {
405 /* No, set it up now */
406 if (!NT_SUCCESS(Status = CsrpConnectToServer(ObjectDirectory)))
407 {
408 /* Failed */
409 DPRINT1("Failure to connect to CSR\n");
410 return Status;
411 }
412 }
413
414 /* Setup the connect message header */
415 ClientConnect->ServerId = ServerId;
416 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
417
418 /* Setup a buffer for the connection info */
419 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
420 if (CaptureBuffer == NULL)
421 {
422 return STATUS_INSUFFICIENT_RESOURCES;
423 }
424
425 /* Capture the connection info data */
426 CsrCaptureMessageBuffer(CaptureBuffer,
427 ConnectionInfo,
428 ClientConnect->ConnectionInfoSize,
429 &ClientConnect->ConnectionInfo);
430
431 /* Return the allocated length */
432 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
433
434 /* Call CSR */
435 Status = CsrClientCallServer(&ApiMessage,
436 CaptureBuffer,
437 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
438 sizeof(CSR_CLIENT_CONNECT));
439
440 /* Copy the updated connection info data back into the user buffer */
441 RtlMoveMemory(ConnectionInfo,
442 ClientConnect->ConnectionInfo,
443 *ConnectionInfoSize);
444
445 /* Free the capture buffer */
446 CsrFreeCaptureBuffer(CaptureBuffer);
447 }
448 else
449 {
450 /* No connection info, just return */
451 Status = STATUS_SUCCESS;
452 }
453
454 /* Let the caller know if this was server to server */
455 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
456 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
457
458 return Status;
459 }
460
461 /* EOF */