[NTDLL]
[reactos.git] / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *****************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS *******************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22
23 typedef NTSTATUS
24 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
25 IN PPORT_MESSAGE Reply);
26
27 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
28
29 #define UNICODE_PATH_SEP L"\\"
30
31 /* FUNCTIONS *****************************************************************/
32
33 /*
34 * @implemented
35 */
36 HANDLE
37 NTAPI
38 CsrGetProcessId(VOID)
39 {
40 return CsrProcessId;
41 }
42
43 /*
44 * @implemented
45 */
46 NTSTATUS
47 NTAPI
48 CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,
49 PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
50 CSR_API_NUMBER ApiNumber,
51 ULONG RequestLength)
52 {
53 NTSTATUS Status;
54 ULONG PointerCount;
55 PULONG_PTR Pointers;
56 ULONG_PTR CurrentPointer;
57 DPRINT("CsrClientCallServer\n");
58
59 /* Fill out the Port Message Header */
60 ApiMessage->Header.u2.ZeroInit = 0;
61 ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);
62 ApiMessage->Header.u1.s1.TotalLength = RequestLength;
63
64 /* Fill out the CSR Header */
65 ApiMessage->ApiNumber = ApiNumber;
66 ApiMessage->CsrCaptureData = NULL;
67
68 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
69 ApiNumber,
70 ApiMessage->Header.u1.s1.DataLength,
71 ApiMessage->Header.u1.s1.TotalLength);
72
73 /* Check if we are already inside a CSR Server */
74 if (!InsideCsrProcess)
75 {
76 /* Check if we got a a Capture Buffer */
77 if (CaptureBuffer)
78 {
79 /* We have to convert from our local view to the remote view */
80 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
81 CsrPortMemoryDelta);
82
83 /* Lock the buffer */
84 CaptureBuffer->BufferEnd = 0;
85
86 /* Get the pointer information */
87 PointerCount = CaptureBuffer->PointerCount;
88 Pointers = CaptureBuffer->PointerArray;
89
90 /* Loop through every pointer and convert it */
91 DPRINT("PointerCount: %lx\n", PointerCount);
92 while (PointerCount--)
93 {
94 /* Get this pointer and check if it's valid */
95 DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n",
96 &Pointers, Pointers, *Pointers);
97 if ((CurrentPointer = *Pointers++))
98 {
99 /* Update it */
100 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
101 *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
102 Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
103 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
104 }
105 }
106 }
107
108 /* Send the LPC Message */
109 Status = NtRequestWaitReplyPort(CsrApiPort,
110 &ApiMessage->Header,
111 &ApiMessage->Header);
112
113 /* Check if we got a a Capture Buffer */
114 if (CaptureBuffer)
115 {
116 /* We have to convert from the remote view to our remote view */
117 DPRINT("Reconverting CaptureBuffer\n");
118 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
119 ApiMessage->CsrCaptureData -
120 CsrPortMemoryDelta);
121
122 /* Get the pointer information */
123 PointerCount = CaptureBuffer->PointerCount;
124 Pointers = CaptureBuffer->PointerArray;
125
126 /* Loop through every pointer and convert it */
127 while (PointerCount--)
128 {
129 /* Get this pointer and check if it's valid */
130 if ((CurrentPointer = *Pointers++))
131 {
132 /* Update it */
133 CurrentPointer += (ULONG_PTR)ApiMessage;
134 Pointers[-1] = CurrentPointer;
135 *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
136 }
137 }
138 }
139
140 /* Check for success */
141 if (!NT_SUCCESS(Status))
142 {
143 /* We failed. Overwrite the return value with the failure */
144 DPRINT1("LPC Failed: %lx\n", Status);
145 ApiMessage->Status = Status;
146 }
147 }
148 else
149 {
150 /* This is a server-to-server call. Save our CID and do a direct call */
151 DPRINT1("Next gen server-to-server call\n");
152 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
153 Status = CsrServerApiRoutine(&ApiMessage->Header,
154 &ApiMessage->Header);
155
156 /* Check for success */
157 if (!NT_SUCCESS(Status))
158 {
159 /* We failed. Overwrite the return value with the failure */
160 ApiMessage->Status = Status;
161 }
162 }
163
164 /* Return the CSR Result */
165 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
166 return ApiMessage->Status;
167 }
168
169 NTSTATUS
170 NTAPI
171 CsrConnectToServer(IN PWSTR ObjectDirectory)
172 {
173 ULONG PortNameLength;
174 UNICODE_STRING PortName;
175 LARGE_INTEGER CsrSectionViewSize;
176 NTSTATUS Status;
177 HANDLE CsrSectionHandle;
178 PORT_VIEW LpcWrite;
179 REMOTE_PORT_VIEW LpcRead;
180 SECURITY_QUALITY_OF_SERVICE SecurityQos;
181 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
182 PSID SystemSid = NULL;
183 CSR_CONNECTION_INFO ConnectionInfo;
184 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
185
186 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
187
188 /* Binary compatibility with MS KERNEL32 */
189 if (NULL == ObjectDirectory)
190 {
191 ObjectDirectory = L"\\Windows";
192 }
193
194 /* Calculate the total port name size */
195 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
196 sizeof(CSR_PORT_NAME);
197
198 /* Set the port name */
199 PortName.Length = 0;
200 PortName.MaximumLength = PortNameLength;
201
202 /* Allocate a buffer for it */
203 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
204 if (PortName.Buffer == NULL)
205 {
206 return STATUS_INSUFFICIENT_RESOURCES;
207 }
208
209 /* Create the name */
210 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
211 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
212 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
213
214 /* Create a section for the port memory */
215 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
216 Status = NtCreateSection(&CsrSectionHandle,
217 SECTION_ALL_ACCESS,
218 NULL,
219 &CsrSectionViewSize,
220 PAGE_READWRITE,
221 SEC_COMMIT,
222 NULL);
223 if (!NT_SUCCESS(Status))
224 {
225 DPRINT1("Failure allocating CSR Section\n");
226 return Status;
227 }
228
229 /* Set up the port view structures to match them with the section */
230 LpcWrite.Length = sizeof(PORT_VIEW);
231 LpcWrite.SectionHandle = CsrSectionHandle;
232 LpcWrite.SectionOffset = 0;
233 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
234 LpcWrite.ViewBase = 0;
235 LpcWrite.ViewRemoteBase = 0;
236 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
237 LpcRead.ViewSize = 0;
238 LpcRead.ViewBase = 0;
239
240 /* Setup the QoS */
241 SecurityQos.ImpersonationLevel = SecurityImpersonation;
242 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
243 SecurityQos.EffectiveOnly = TRUE;
244
245 /* Setup the connection info */
246 ConnectionInfo.Version = 0x10000;
247
248 /* Create a SID for us */
249 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
250 1,
251 SECURITY_LOCAL_SYSTEM_RID,
252 0,
253 0,
254 0,
255 0,
256 0,
257 0,
258 0,
259 &SystemSid);
260 if (!NT_SUCCESS(Status))
261 {
262 /* Failure */
263 DPRINT1("Couldn't allocate SID\n");
264 NtClose(CsrSectionHandle);
265 return Status;
266 }
267
268 /* Connect to the port */
269 Status = NtSecureConnectPort(&CsrApiPort,
270 &PortName,
271 &SecurityQos,
272 &LpcWrite,
273 SystemSid,
274 &LpcRead,
275 NULL,
276 &ConnectionInfo,
277 &ConnectionInfoLength);
278 NtClose(CsrSectionHandle);
279 if (!NT_SUCCESS(Status))
280 {
281 /* Failure */
282 DPRINT1("Couldn't connect to CSR port\n");
283 return Status;
284 }
285
286 /* Save the delta between the sections, for capture usage later */
287 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
288 (ULONG_PTR)LpcWrite.ViewBase;
289
290 /* Save the Process */
291 CsrProcessId = ConnectionInfo.ProcessId;
292
293 /* Save CSR Section data */
294 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
295 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
296 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
297
298 /* Create the port heap */
299 CsrPortHeap = RtlCreateHeap(0,
300 LpcWrite.ViewBase,
301 LpcWrite.ViewSize,
302 PAGE_SIZE,
303 0,
304 0);
305 if (CsrPortHeap == NULL)
306 {
307 NtClose(CsrApiPort);
308 CsrApiPort = NULL;
309 return STATUS_INSUFFICIENT_RESOURCES;
310 }
311
312 /* Return success */
313 return STATUS_SUCCESS;
314 }
315
316 /*
317 * @implemented
318 */
319 NTSTATUS
320 NTAPI
321 CsrClientConnectToServer(PWSTR ObjectDirectory,
322 ULONG ServerId,
323 PVOID ConnectionInfo,
324 PULONG ConnectionInfoSize,
325 PBOOLEAN ServerToServerCall)
326 {
327 NTSTATUS Status;
328 PIMAGE_NT_HEADERS NtHeader;
329 UNICODE_STRING CsrSrvName;
330 HANDLE hCsrSrv;
331 ANSI_STRING CsrServerRoutineName;
332 PCSR_CAPTURE_BUFFER CaptureBuffer;
333 CSR_API_MESSAGE ApiMessage;
334 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.Data.CsrClientConnect;
335
336 /* Validate the Connection Info */
337 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
338 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
339 {
340 DPRINT1("Connection info given, but no length\n");
341 return STATUS_INVALID_PARAMETER;
342 }
343
344 /* Check if we're inside a CSR Process */
345 if (InsideCsrProcess)
346 {
347 /* Tell the client that we're already inside CSR */
348 if (ServerToServerCall) *ServerToServerCall = TRUE;
349 return STATUS_SUCCESS;
350 }
351
352 /*
353 * We might be in a CSR Process but not know it, if this is the first call.
354 * So let's find out.
355 */
356 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
357 {
358 /* The image isn't valid */
359 DPRINT1("Invalid image\n");
360 return STATUS_INVALID_IMAGE_FORMAT;
361 }
362 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
363
364 /* Now we can check if we are inside or not */
365 if (InsideCsrProcess)
366 {
367 /* We're inside, so let's find csrsrv */
368 DPRINT1("Next-GEN CSRSS support\n");
369 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
370 Status = LdrGetDllHandle(NULL,
371 NULL,
372 &CsrSrvName,
373 &hCsrSrv);
374
375 /* Now get the Server to Server routine */
376 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
377 Status = LdrGetProcedureAddress(hCsrSrv,
378 &CsrServerRoutineName,
379 0L,
380 (PVOID*)&CsrServerApiRoutine);
381
382 /* Use the local heap as port heap */
383 CsrPortHeap = RtlGetProcessHeap();
384
385 /* Tell the caller we're inside the server */
386 *ServerToServerCall = InsideCsrProcess;
387 return STATUS_SUCCESS;
388 }
389
390 /* Now check if connection info is given */
391 if (ConnectionInfo)
392 {
393 /* Well, we're defintely in a client now */
394 InsideCsrProcess = FALSE;
395
396 /* Do we have a connection to CSR yet? */
397 if (!CsrApiPort)
398 {
399 /* No, set it up now */
400 if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))
401 {
402 /* Failed */
403 DPRINT1("Failure to connect to CSR\n");
404 return Status;
405 }
406 }
407
408 /* Setup the connect message header */
409 ClientConnect->ServerId = ServerId;
410 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
411
412 /* Setup a buffer for the connection info */
413 CaptureBuffer = CsrAllocateCaptureBuffer(1,
414 ClientConnect->ConnectionInfoSize);
415 if (CaptureBuffer == NULL)
416 {
417 return STATUS_INSUFFICIENT_RESOURCES;
418 }
419
420 /* Allocate a pointer for the connection info*/
421 CsrAllocateMessagePointer(CaptureBuffer,
422 ClientConnect->ConnectionInfoSize,
423 &ClientConnect->ConnectionInfo);
424
425 /* Copy the data into the buffer */
426 RtlMoveMemory(ClientConnect->ConnectionInfo,
427 ConnectionInfo,
428 ClientConnect->ConnectionInfoSize);
429
430 /* Return the allocated length */
431 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
432
433 /* Call CSR */
434 Status = CsrClientCallServer(&ApiMessage,
435 CaptureBuffer,
436 CSR_CREATE_API_NUMBER(CSR_SRV_DLL, CsrpClientConnect),
437 sizeof(CSR_CLIENT_CONNECT));
438 /*
439 Status = CsrClientCallServer(&ApiMessage,
440 CaptureBuffer,
441 CSR_CREATE_API_NUMBER(CSR_NATIVE, CONNECT_PROCESS),
442 sizeof(CSR_API_MESSAGE));
443 */
444 }
445 else
446 {
447 /* No connection info, just return */
448 Status = STATUS_SUCCESS;
449 }
450
451 /* Let the caller know if this was server to server */
452 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
453 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
454 return Status;
455 }
456
457 /* EOF */