[NTDLL]
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: lib/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *******************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS ********************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22
23 typedef NTSTATUS
24 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
25 IN PPORT_MESSAGE Reply);
26
27 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
28
29 #define UNICODE_PATH_SEP L"\\"
30
31 /* FUNCTIONS ******************************************************************/
32
33 /*
34 * @implemented
35 */
36 HANDLE
37 NTAPI
38 CsrGetProcessId(VOID)
39 {
40 return CsrProcessId;
41 }
42
43 /*
44 * @implemented
45 */
46 NTSTATUS
47 NTAPI
48 CsrClientCallServer(IN OUT PCSR_API_MESSAGE ApiMessage,
49 IN OUT PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
50 IN CSR_API_NUMBER ApiNumber,
51 IN ULONG DataLength)
52 {
53 NTSTATUS Status;
54 ULONG i;
55
56 /* Fill out the Port Message Header. */
57 ApiMessage->Header.u2.ZeroInit = 0;
58 ApiMessage->Header.u1.s1.TotalLength =
59 FIELD_OFFSET(CSR_API_MESSAGE, Data) + DataLength;
60 ApiMessage->Header.u1.s1.DataLength =
61 ApiMessage->Header.u1.s1.TotalLength - sizeof(PORT_MESSAGE);
62
63 /* Fill out the CSR Header. */
64 ApiMessage->ApiNumber = ApiNumber;
65 ApiMessage->CsrCaptureData = NULL;
66
67 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
68 ApiNumber,
69 ApiMessage->Header.u1.s1.DataLength,
70 ApiMessage->Header.u1.s1.TotalLength);
71
72 /* Check if we are already inside a CSR Server. */
73 if (!InsideCsrProcess)
74 {
75 /* Check if we got a Capture Buffer. */
76 if (CaptureBuffer)
77 {
78 /*
79 * We have to convert from our local (client) view
80 * to the remote (server) view.
81 */
82 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
83 ((ULONG_PTR)CaptureBuffer + CsrPortMemoryDelta);
84
85 /* Lock the buffer. */
86 CaptureBuffer->BufferEnd = NULL;
87
88 /*
89 * Each client pointer inside the CSR message is converted into
90 * a server pointer, and each pointer to these message pointers
91 * is converted into an offset.
92 */
93 for (i = 0 ; i < CaptureBuffer->PointerCount ; ++i)
94 {
95 if (CaptureBuffer->PointerOffsetsArray[i] != 0)
96 {
97 *(PULONG_PTR)CaptureBuffer->PointerOffsetsArray[i] += CsrPortMemoryDelta;
98 CaptureBuffer->PointerOffsetsArray[i] -= (ULONG_PTR)ApiMessage;
99 }
100 }
101 }
102
103 /* Send the LPC Message. */
104 Status = NtRequestWaitReplyPort(CsrApiPort,
105 &ApiMessage->Header,
106 &ApiMessage->Header);
107
108 /* Check if we got a Capture Buffer. */
109 if (CaptureBuffer)
110 {
111 /*
112 * We have to convert back from the remote (server) view
113 * to our local (client) view.
114 */
115 ApiMessage->CsrCaptureData = (PCSR_CAPTURE_BUFFER)
116 ((ULONG_PTR)ApiMessage->CsrCaptureData - CsrPortMemoryDelta);
117
118 /*
119 * Convert back the offsets into pointers to CSR message
120 * pointers, and convert back these message server pointers
121 * into client pointers.
122 */
123 for (i = 0 ; i < CaptureBuffer->PointerCount ; ++i)
124 {
125 if (CaptureBuffer->PointerOffsetsArray[i] != 0)
126 {
127 CaptureBuffer->PointerOffsetsArray[i] += (ULONG_PTR)ApiMessage;
128 *(PULONG_PTR)CaptureBuffer->PointerOffsetsArray[i] -= CsrPortMemoryDelta;
129 }
130 }
131 }
132
133 /* Check for success. */
134 if (!NT_SUCCESS(Status))
135 {
136 /* We failed. Overwrite the return value with the failure. */
137 DPRINT1("LPC Failed: %lx\n", Status);
138 ApiMessage->Status = Status;
139 }
140 }
141 else
142 {
143 /* This is a server-to-server call. Save our CID and do a direct call. */
144 DPRINT1("Next gen server-to-server call\n");
145
146 /* We check this equality inside CsrValidateMessageBuffer. */
147 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
148
149 Status = CsrServerApiRoutine(&ApiMessage->Header,
150 &ApiMessage->Header);
151
152 /* Check for success. */
153 if (!NT_SUCCESS(Status))
154 {
155 /* We failed. Overwrite the return value with the failure. */
156 ApiMessage->Status = Status;
157 }
158 }
159
160 /* Return the CSR Result. */
161 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
162 return ApiMessage->Status;
163 }
164
165 NTSTATUS
166 NTAPI
167 CsrpConnectToServer(IN PWSTR ObjectDirectory)
168 {
169 ULONG PortNameLength;
170 UNICODE_STRING PortName;
171 LARGE_INTEGER CsrSectionViewSize;
172 NTSTATUS Status;
173 HANDLE CsrSectionHandle;
174 PORT_VIEW LpcWrite;
175 REMOTE_PORT_VIEW LpcRead;
176 SECURITY_QUALITY_OF_SERVICE SecurityQos;
177 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
178 PSID SystemSid = NULL;
179 CSR_CONNECTION_INFO ConnectionInfo;
180 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
181
182 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
183
184 /* Binary compatibility with MS KERNEL32 */
185 if (NULL == ObjectDirectory)
186 {
187 ObjectDirectory = L"\\Windows";
188 }
189
190 /* Calculate the total port name size */
191 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
192 sizeof(CSR_PORT_NAME);
193
194 /* Set the port name */
195 PortName.Length = 0;
196 PortName.MaximumLength = PortNameLength;
197
198 /* Allocate a buffer for it */
199 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, PortNameLength);
200 if (PortName.Buffer == NULL)
201 {
202 return STATUS_INSUFFICIENT_RESOURCES;
203 }
204
205 /* Create the name */
206 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
207 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
208 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
209
210 /* Create a section for the port memory */
211 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
212 Status = NtCreateSection(&CsrSectionHandle,
213 SECTION_ALL_ACCESS,
214 NULL,
215 &CsrSectionViewSize,
216 PAGE_READWRITE,
217 SEC_COMMIT,
218 NULL);
219 if (!NT_SUCCESS(Status))
220 {
221 DPRINT1("Failure allocating CSR Section\n");
222 return Status;
223 }
224
225 /* Set up the port view structures to match them with the section */
226 LpcWrite.Length = sizeof(PORT_VIEW);
227 LpcWrite.SectionHandle = CsrSectionHandle;
228 LpcWrite.SectionOffset = 0;
229 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
230 LpcWrite.ViewBase = 0;
231 LpcWrite.ViewRemoteBase = 0;
232 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
233 LpcRead.ViewSize = 0;
234 LpcRead.ViewBase = 0;
235
236 /* Setup the QoS */
237 SecurityQos.ImpersonationLevel = SecurityImpersonation;
238 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
239 SecurityQos.EffectiveOnly = TRUE;
240
241 /* Setup the connection info */
242 ConnectionInfo.Version = 0x10000;
243
244 /* Create a SID for us */
245 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
246 1,
247 SECURITY_LOCAL_SYSTEM_RID,
248 0,
249 0,
250 0,
251 0,
252 0,
253 0,
254 0,
255 &SystemSid);
256 if (!NT_SUCCESS(Status))
257 {
258 /* Failure */
259 DPRINT1("Couldn't allocate SID\n");
260 NtClose(CsrSectionHandle);
261 return Status;
262 }
263
264 /* Connect to the port */
265 Status = NtSecureConnectPort(&CsrApiPort,
266 &PortName,
267 &SecurityQos,
268 &LpcWrite,
269 SystemSid,
270 &LpcRead,
271 NULL,
272 &ConnectionInfo,
273 &ConnectionInfoLength);
274 NtClose(CsrSectionHandle);
275 if (!NT_SUCCESS(Status))
276 {
277 /* Failure */
278 DPRINT1("Couldn't connect to CSR port\n");
279 return Status;
280 }
281
282 /* Save the delta between the sections, for capture usage later */
283 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
284 (ULONG_PTR)LpcWrite.ViewBase;
285
286 /* Save the Process */
287 CsrProcessId = ConnectionInfo.ProcessId;
288
289 /* Save CSR Section data */
290 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
291 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
292 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
293
294 /* Create the port heap */
295 CsrPortHeap = RtlCreateHeap(0,
296 LpcWrite.ViewBase,
297 LpcWrite.ViewSize,
298 PAGE_SIZE,
299 0,
300 0);
301 if (CsrPortHeap == NULL)
302 {
303 /* Failure */
304 DPRINT1("Couldn't create heap for CSR port\n");
305 NtClose(CsrApiPort);
306 CsrApiPort = NULL;
307 return STATUS_INSUFFICIENT_RESOURCES;
308 }
309
310 /* Return success */
311 return STATUS_SUCCESS;
312 }
313
314 /*
315 * @implemented
316 */
317 NTSTATUS
318 NTAPI
319 CsrClientConnectToServer(IN PWSTR ObjectDirectory,
320 IN ULONG ServerId,
321 IN PVOID ConnectionInfo,
322 IN OUT PULONG ConnectionInfoSize,
323 OUT PBOOLEAN ServerToServerCall)
324 {
325 NTSTATUS Status;
326 PIMAGE_NT_HEADERS NtHeader;
327 UNICODE_STRING CsrSrvName;
328 HANDLE hCsrSrv;
329 ANSI_STRING CsrServerRoutineName;
330 CSR_API_MESSAGE ApiMessage;
331 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
332 PCSR_CAPTURE_BUFFER CaptureBuffer;
333
334 /* Validate the Connection Info */
335 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
336 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
337 {
338 DPRINT1("Connection info given, but no length\n");
339 return STATUS_INVALID_PARAMETER;
340 }
341
342 /* Check if we're inside a CSR Process */
343 if (InsideCsrProcess)
344 {
345 /* Tell the client that we're already inside CSR */
346 if (ServerToServerCall) *ServerToServerCall = TRUE;
347 return STATUS_SUCCESS;
348 }
349
350 /*
351 * We might be in a CSR Process but not know it, if this is the first call.
352 * So let's find out.
353 */
354 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
355 {
356 /* The image isn't valid */
357 DPRINT1("Invalid image\n");
358 return STATUS_INVALID_IMAGE_FORMAT;
359 }
360 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
361
362 /* Now we can check if we are inside or not */
363 if (InsideCsrProcess)
364 {
365 /* We're inside, so let's find csrsrv */
366 DPRINT1("Next-GEN CSRSS support\n");
367 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
368 Status = LdrGetDllHandle(NULL,
369 NULL,
370 &CsrSrvName,
371 &hCsrSrv);
372
373 /* Now get the Server to Server routine */
374 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
375 Status = LdrGetProcedureAddress(hCsrSrv,
376 &CsrServerRoutineName,
377 0L,
378 (PVOID*)&CsrServerApiRoutine);
379
380 /* Use the local heap as port heap */
381 CsrPortHeap = RtlGetProcessHeap();
382
383 /* Tell the caller we're inside the server */
384 *ServerToServerCall = InsideCsrProcess;
385 return STATUS_SUCCESS;
386 }
387
388 /* Now check if connection info is given */
389 if (ConnectionInfo)
390 {
391 /* Well, we're defintely in a client now */
392 InsideCsrProcess = FALSE;
393
394 /* Do we have a connection to CSR yet? */
395 if (!CsrApiPort)
396 {
397 /* No, set it up now */
398 if (!NT_SUCCESS(Status = CsrpConnectToServer(ObjectDirectory)))
399 {
400 /* Failed */
401 DPRINT1("Failure to connect to CSR\n");
402 return Status;
403 }
404 }
405
406 /* Setup the connect message header */
407 ClientConnect->ServerId = ServerId;
408 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
409
410 /* Setup a buffer for the connection info */
411 CaptureBuffer = CsrAllocateCaptureBuffer(1, ClientConnect->ConnectionInfoSize);
412 if (CaptureBuffer == NULL)
413 {
414 return STATUS_INSUFFICIENT_RESOURCES;
415 }
416
417 /* Capture the connection info data */
418 CsrCaptureMessageBuffer(CaptureBuffer,
419 ConnectionInfo,
420 ClientConnect->ConnectionInfoSize,
421 &ClientConnect->ConnectionInfo);
422
423 /* Return the allocated length */
424 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
425
426 /* Call CSR */
427 Status = CsrClientCallServer(&ApiMessage,
428 CaptureBuffer,
429 CSR_CREATE_API_NUMBER(CSRSRV_SERVERDLL_INDEX, CsrpClientConnect),
430 sizeof(CSR_CLIENT_CONNECT));
431 }
432 else
433 {
434 /* No connection info, just return */
435 Status = STATUS_SUCCESS;
436 }
437
438 /* Let the caller know if this was server to server */
439 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
440 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
441 return Status;
442 }
443
444 /* EOF */