27fbb743b1742e434a7dd9fdbb4fbc4968971d46
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS ********************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22
23 typedef NTSTATUS
24 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
25 IN PPORT_MESSAGE Reply);
26
27 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
28
29 #define UNICODE_PATH_SEP L"\\"
30
31 /* FUNCTIONS ******************************************************************/
32
33 /*
34 * @implemented
35 */
36 HANDLE
37 NTAPI
38 CsrGetProcessId(VOID)
39 {
40 return CsrProcessId;
41 }
42
43 /*
44 * @implemented
45 */
46 NTSTATUS
47 NTAPI
48 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
49 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
50 IN CSR_API_NUMBER ApiNumber,
51 IN ULONG DataLength)
52 {
53 NTSTATUS Status;
54 ULONG PointerCount;
55 PULONG_PTR Pointers;
56 ULONG_PTR CurrentPointer;
57 DPRINT("CsrClientCallServer\n");
58
59 /* Fill out the Port Message Header */
60 ApiMessage->Header.u2.ZeroInit = 0;
61 ApiMessage->Header.u1.s1.TotalLength =
62 FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
63 /* FIELD_OFFSET(CSR_API_MESSAGE, Data) <= sizeof(CSR_API_MESSAGE) - sizeof(ApiMessage->Data) */
64 ApiMessage->Header.u1.s1.DataLength =
65 ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
66
67 /* Fill out the CSR Header */
68 ApiMessage->ApiNumber = ApiNumber;
69 ApiMessage->CsrCaptureData = NULL;
70
71 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
72 ApiNumber,
73 ApiMessage->Header.u1.s1.DataLength,
74 ApiMessage->Header.u1.s1.TotalLength);
75
76 /* Check if we are already inside a CSR Server */
77 if (!InsideCsrProcess)
78 {
79 /* Check if we got a a Capture Buffer */
80 if (CaptureBuffer)
81 {
82 /* We have to convert from our local view to the remote view */
83 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
84 CsrPortMemoryDelta);
85
86 /* Lock the buffer */
87 CaptureBuffer->BufferEnd = 0;
88
89 /* Get the pointer information */
90 PointerCount = CaptureBuffer->PointerCount;
91 Pointers = CaptureBuffer->PointerArray;
92
93 /* Loop through every pointer and convert it */
94 DPRINT("PointerCount: %lx\n", PointerCount);
95 while (PointerCount--)
96 {
97 /* Get this pointer and check if it's valid */
98 DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n",
99 &Pointers, Pointers, *Pointers);
100 if ((CurrentPointer = *Pointers++))
101 {
102 /* Update it */
103 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
104 *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
105 Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
106 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
107 }
108 }
109 }
110
111 /* Send the LPC Message */
112 Status = NtRequestWaitReplyPort(CsrApiPort,
113 &ApiMessage->Header,
114 &ApiMessage->Header);
115
116 /* Check if we got a a Capture Buffer */
117 if (CaptureBuffer)
118 {
119 /* We have to convert back from the remote view to our local view */
120 DPRINT("Reconverting CaptureBuffer\n");
121 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
122 ApiMessage->CsrCaptureData -
123 CsrPortMemoryDelta);
124
125 /* Get the pointer information */
126 PointerCount = CaptureBuffer->PointerCount;
127 Pointers = CaptureBuffer->PointerArray;
128
129 /* Loop through every pointer and convert it */
130 while (PointerCount--)
131 {
132 /* Get this pointer and check if it's valid */
133 if ((CurrentPointer = *Pointers++))
134 {
135 /* Update it */
136 CurrentPointer += (ULONG_PTR)ApiMessage;
137 Pointers[-1] = CurrentPointer;
138 *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
139 }
140 }
141 }
142
143 /* Check for success */
144 if (!NT_SUCCESS(Status))
145 {
146 /* We failed. Overwrite the return value with the failure */
147 DPRINT1("LPC Failed: %lx\n", Status);
148 ApiMessage->Status = Status;
149 }
150 }
151 else
152 {
153 /* This is a server-to-server call. Save our CID and do a direct call */
154 DPRINT1("Next gen server-to-server call\n");
155 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
156 Status = CsrServerApiRoutine(&ApiMessage->Header,
157 &ApiMessage->Header);
158
159 /* Check for success */
160 if (!NT_SUCCESS(Status))
161 {
162 /* We failed. Overwrite the return value with the failure */
163 ApiMessage->Status = Status;
164 }
165 }
166
167 /* Return the CSR Result */
168 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
169 return ApiMessage->Status;
170 }
171
172 NTSTATUS
173 NTAPI
174 CsrpConnectToServer(IN PWSTR ObjectDirectory)
175 {
176 ULONG PortNameLength;
177 UNICODE_STRING PortName;
178 LARGE_INTEGER CsrSectionViewSize;
179 NTSTATUS Status;
180 HANDLE CsrSectionHandle;
181 PORT_VIEW LpcWrite;
182 REMOTE_PORT_VIEW LpcRead;
183 SECURITY_QUALITY_OF_SERVICE SecurityQos;
184 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
185 PSID SystemSid = NULL;
186 CSR_CONNECTION_INFO ConnectionInfo;
187 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
188
189 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
190
191 /* Binary compatibility with MS KERNEL32 */
192 if (NULL == ObjectDirectory)
193 {
194 ObjectDirectory = L"\\Windows";
195 }
196
197 /* Calculate the total port name size */
198 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
199 sizeof(CSR_PORT_NAME);
200
201 /* Set the port name */
202 PortName.Length = 0;
203 PortName.MaximumLength = PortNameLength;
204
205 /* Allocate a buffer for it */
206 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
207 if (PortName.Buffer == NULL)
208 {
209 return STATUS_INSUFFICIENT_RESOURCES;
210 }
211
212 /* Create the name */
213 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
214 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
215 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
216
217 /* Create a section for the port memory */
218 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
219 Status = NtCreateSection(&CsrSectionHandle,
220 SECTION_ALL_ACCESS,
221 NULL,
222 &CsrSectionViewSize,
223 PAGE_READWRITE,
224 SEC_COMMIT,
225 NULL);
226 if (!NT_SUCCESS(Status))
227 {
228 DPRINT1("Failure allocating CSR Section\n");
229 return Status;
230 }
231
232 /* Set up the port view structures to match them with the section */
233 LpcWrite.Length = sizeof(PORT_VIEW);
234 LpcWrite.SectionHandle = CsrSectionHandle;
235 LpcWrite.SectionOffset = 0;
236 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
237 LpcWrite.ViewBase = 0;
238 LpcWrite.ViewRemoteBase = 0;
239 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
240 LpcRead.ViewSize = 0;
241 LpcRead.ViewBase = 0;
242
243 /* Setup the QoS */
244 SecurityQos.ImpersonationLevel = SecurityImpersonation;
245 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
246 SecurityQos.EffectiveOnly = TRUE;
247
248 /* Setup the connection info */
249 ConnectionInfo.Version = 0x10000;
250
251 /* Create a SID for us */
252 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
253 1,
254 SECURITY_LOCAL_SYSTEM_RID,
255 0,
256 0,
257 0,
258 0,
259 0,
260 0,
261 0,
262 &SystemSid);
263 if (!NT_SUCCESS(Status))
264 {
265 /* Failure */
266 DPRINT1("Couldn't allocate SID\n");
267 NtClose(CsrSectionHandle);
268 return Status;
269 }
270
271 /* Connect to the port */
272 Status = NtSecureConnectPort(&CsrApiPort,
273 &PortName,
274 &SecurityQos,
275 &LpcWrite,
276 SystemSid,
277 &LpcRead,
278 NULL,
279 &ConnectionInfo,
280 &ConnectionInfoLength);
281 NtClose(CsrSectionHandle);
282 if (!NT_SUCCESS(Status))
283 {
284 /* Failure */
285 DPRINT1("Couldn't connect to CSR port\n");
286 return Status;
287 }
288
289 /* Save the delta between the sections, for capture usage later */
290 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
291 (ULONG_PTR)LpcWrite.ViewBase;
292
293 /* Save the Process */
294 CsrProcessId = ConnectionInfo.ProcessId;
295
296 /* Save CSR Section data */
297 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
298 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
299 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
300
301 /* Create the port heap */
302 CsrPortHeap = RtlCreateHeap(0,
303 LpcWrite.ViewBase,
304 LpcWrite.ViewSize,
305 PAGE_SIZE,
306 0,
307 0);
308 if (CsrPortHeap == NULL)
309 {
310 NtClose(CsrApiPort);
311 CsrApiPort = NULL;
312 return STATUS_INSUFFICIENT_RESOURCES;
313 }
314
315 /* Return success */
316 return STATUS_SUCCESS;
317 }
318
319 /*
320 * @implemented
321 */
322 NTSTATUS
323 NTAPI
324 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
325 IN ULONG ServerId,
326 IN PVOID ConnectionInfo,
327 IN OUT PULONG ConnectionInfoSize,
328 OUT PBOOLEAN ServerToServerCall)
329 {
330 NTSTATUS Status;
331 PIMAGE_NT_HEADERS NtHeader;
332 UNICODE_STRING CsrSrvName;
333 HANDLE hCsrSrv;
334 ANSI_STRING CsrServerRoutineName;
335 PCSR_CAPTURE_BUFFER CaptureBuffer;
336 CSR_API_MESSAGE ApiMessage;
337 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
338
339 /* Validate the Connection Info */
340 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
341 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
342 {
343 DPRINT1("Connection info given, but no length\n");
344 return STATUS_INVALID_PARAMETER;
345 }
346
347 /* Check if we're inside a CSR Process */
348 if (InsideCsrProcess)
349 {
350 /* Tell the client that we're already inside CSR */
351 if (ServerToServerCall) *ServerToServerCall = TRUE;
352 return STATUS_SUCCESS;
353 }
354
355 /*
356 * We might be in a CSR Process but not know it, if this is the first call.
357 * So let's find out.
358 */
359 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
360 {
361 /* The image isn't valid */
362 DPRINT1("Invalid image\n");
363 return STATUS_INVALID_IMAGE_FORMAT;
364 }
365 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
366
367 /* Now we can check if we are inside or not */
368 if (InsideCsrProcess)
369 {
370 /* We're inside, so let's find csrsrv */
371 DPRINT1("Next-GEN CSRSS support\n");
372 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
373 Status = LdrGetDllHandle(NULL,
374 NULL,
375 &CsrSrvName,
376 &hCsrSrv);
377
378 /* Now get the Server to Server routine */
379 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
380 Status = LdrGetProcedureAddress(hCsrSrv,
381 &CsrServerRoutineName,
382 0L,
383 (PVOID*)&CsrServerApiRoutine);
384
385 /* Use the local heap as port heap */
386 CsrPortHeap = RtlGetProcessHeap();
387
388 /* Tell the caller we're inside the server */
389 *ServerToServerCall = InsideCsrProcess;
390 return STATUS_SUCCESS;
391 }
392
393 /* Now check if connection info is given */
394 if (ConnectionInfo)
395 {
396 /* Well, we're defintely in a client now */
397 InsideCsrProcess = FALSE;
398
399 /* Do we have a connection to CSR yet? */
400 if (!CsrApiPort)
401 {
402 /* No, set it up now */
403 if (!NT_SUCCESS(Status = CsrpConnectToServer(ObjectDirectory)))
404 {
405 /* Failed */
406 DPRINT1("Failure to connect to CSR\n");
407 return Status;
408 }
409 }
410
411 /* Setup the connect message header */
412 ClientConnect->ServerId = ServerId;
413 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
414
415 /* Setup a buffer for the connection info */
416 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
417 if (CaptureBuffer == NULL)
418 {
419 return STATUS_INSUFFICIENT_RESOURCES;
420 }
421
422 /* Allocate a pointer for the connection info*/
423 CsrAllocateMessagePointer(CaptureBuffer,
424 ClientConnect->ConnectionInfoSize,
425 &ClientConnect->ConnectionInfo);
426
427 /* Copy the data into the buffer */
428 RtlMoveMemory(ClientConnect->ConnectionInfo,
429 ConnectionInfo,
430 ClientConnect->ConnectionInfoSize);
431
432 /* Return the allocated length */
433 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
434
435 /* Call CSR */
436 Status = CsrClientCallServer(&ApiMessage,
437 CaptureBuffer,
438 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
439 sizeof(CSR_CLIENT_CONNECT));
440 }
441 else
442 {
443 /* No connection info, just return */
444 Status = STATUS_SUCCESS;
445 }
446
447 /* Let the caller know if this was server to server */
448 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
449 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
450 return Status;
451 }
452
453 /* EOF */