[uxtheme]
[reactos.git] / reactos / dll / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: dll/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *****************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS *******************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22 BOOLEAN UsingOldCsr = TRUE;
23
24 typedef NTSTATUS
25 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
26 IN PPORT_MESSAGE Reply);
27
28 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
29
30 #define UNICODE_PATH_SEP L"\\"
31 #define CSR_PORT_NAME L"ApiPort"
32
33 /* FUNCTIONS *****************************************************************/
34
35 /*
36 * @implemented
37 */
38 HANDLE
39 NTAPI
40 CsrGetProcessId(VOID)
41 {
42 return CsrProcessId;
43 }
44
45 /*
46 * @implemented
47 */
48 NTSTATUS
49 NTAPI
50 CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,
51 PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
52 CSR_API_NUMBER ApiNumber,
53 ULONG RequestLength)
54 {
55 NTSTATUS Status;
56 ULONG PointerCount;
57 PULONG_PTR Pointers;
58 ULONG_PTR CurrentPointer;
59 DPRINT("CsrClientCallServer\n");
60
61 /* Fill out the Port Message Header */
62 ApiMessage->Header.u2.ZeroInit = 0;
63 ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);
64 ApiMessage->Header.u1.s1.TotalLength = RequestLength;
65
66 /* Fill out the CSR Header */
67 ApiMessage->Type = ApiNumber;
68 //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR
69 ApiMessage->CsrCaptureData = NULL;
70
71 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
72 ApiNumber,
73 ApiMessage->Header.u1.s1.DataLength,
74 ApiMessage->Header.u1.s1.TotalLength);
75
76 /* Check if we are already inside a CSR Server */
77 if (!InsideCsrProcess)
78 {
79 /* Check if we got a a Capture Buffer */
80 if (CaptureBuffer)
81 {
82 /* We have to convert from our local view to the remote view */
83 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
84 CsrPortMemoryDelta);
85
86 /* Lock the buffer */
87 CaptureBuffer->BufferEnd = 0;
88
89 /* Get the pointer information */
90 PointerCount = CaptureBuffer->PointerCount;
91 Pointers = CaptureBuffer->PointerArray;
92
93 /* Loop through every pointer and convert it */
94 DPRINT("PointerCount: %lx\n", PointerCount);
95 while (PointerCount--)
96 {
97 /* Get this pointer and check if it's valid */
98 DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n",
99 &Pointers, Pointers, *Pointers);
100 if ((CurrentPointer = *Pointers++))
101 {
102 /* Update it */
103 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
104 *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
105 Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
106 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
107 }
108 }
109 }
110
111 /* Send the LPC Message */
112 Status = NtRequestWaitReplyPort(CsrApiPort,
113 &ApiMessage->Header,
114 &ApiMessage->Header);
115
116 /* Check if we got a a Capture Buffer */
117 if (CaptureBuffer)
118 {
119 /* We have to convert from the remote view to our remote view */
120 DPRINT("Reconverting CaptureBuffer\n");
121 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
122 ApiMessage->CsrCaptureData -
123 CsrPortMemoryDelta);
124
125 /* Get the pointer information */
126 PointerCount = CaptureBuffer->PointerCount;
127 Pointers = CaptureBuffer->PointerArray;
128
129 /* Loop through every pointer and convert it */
130 while (PointerCount--)
131 {
132 /* Get this pointer and check if it's valid */
133 if ((CurrentPointer = *Pointers++))
134 {
135 /* Update it */
136 CurrentPointer += (ULONG_PTR)ApiMessage;
137 Pointers[-1] = CurrentPointer;
138 *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
139 }
140 }
141 }
142
143 /* Check for success */
144 if (!NT_SUCCESS(Status))
145 {
146 /* We failed. Overwrite the return value with the failure */
147 DPRINT1("LPC Failed: %lx\n", Status);
148 ApiMessage->Status = Status;
149 }
150 }
151 else
152 {
153 /* This is a server-to-server call. Save our CID and do a direct call */
154 DbgBreakPoint();
155 ApiMessage->Header.ClientId = NtCurrentTeb()->ClientId;
156 Status = CsrServerApiRoutine(&ApiMessage->Header,
157 &ApiMessage->Header);
158
159 /* Check for success */
160 if (!NT_SUCCESS(Status))
161 {
162 /* We failed. Overwrite the return value with the failure */
163 ApiMessage->Status = Status;
164 }
165 }
166
167 /* Return the CSR Result */
168 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
169 return ApiMessage->Status;
170 }
171
172 NTSTATUS
173 NTAPI
174 CsrConnectToServer(IN PWSTR ObjectDirectory)
175 {
176 ULONG PortNameLength;
177 UNICODE_STRING PortName;
178 LARGE_INTEGER CsrSectionViewSize;
179 NTSTATUS Status;
180 HANDLE CsrSectionHandle;
181 PORT_VIEW LpcWrite;
182 REMOTE_PORT_VIEW LpcRead;
183 SECURITY_QUALITY_OF_SERVICE SecurityQos;
184 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
185 PSID SystemSid = NULL;
186 CSR_CONNECTION_INFO ConnectionInfo;
187 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
188
189 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
190
191 /* Binary compatibility with MS KERNEL32 */
192 if (NULL == ObjectDirectory)
193 {
194 ObjectDirectory = L"\\Windows";
195 }
196
197 /* Calculate the total port name size */
198 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
199 sizeof(CSR_PORT_NAME);
200
201 /* Set the port name */
202 PortName.Length = 0;
203 PortName.MaximumLength = PortNameLength;
204
205 /* Allocate a buffer for it */
206 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
207 if (PortName.Buffer == NULL)
208 {
209 return STATUS_INSUFFICIENT_RESOURCES;
210 }
211
212 /* Create the name */
213 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
214 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
215 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
216
217 /* Create a section for the port memory */
218 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
219 Status = NtCreateSection(&CsrSectionHandle,
220 SECTION_ALL_ACCESS,
221 NULL,
222 &CsrSectionViewSize,
223 PAGE_READWRITE,
224 SEC_COMMIT,
225 NULL);
226 if (!NT_SUCCESS(Status))
227 {
228 DPRINT1("Failure allocating CSR Section\n");
229 return Status;
230 }
231
232 /* Set up the port view structures to match them with the section */
233 LpcWrite.Length = sizeof(PORT_VIEW);
234 LpcWrite.SectionHandle = CsrSectionHandle;
235 LpcWrite.SectionOffset = 0;
236 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
237 LpcWrite.ViewBase = 0;
238 LpcWrite.ViewRemoteBase = 0;
239 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
240 LpcRead.ViewSize = 0;
241 LpcRead.ViewBase = 0;
242
243 /* Setup the QoS */
244 SecurityQos.ImpersonationLevel = SecurityImpersonation;
245 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
246 SecurityQos.EffectiveOnly = TRUE;
247
248 /* Setup the connection info */
249 ConnectionInfo.Version = 0x10000;
250
251 /* Create a SID for us */
252 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
253 1,
254 SECURITY_LOCAL_SYSTEM_RID,
255 0,
256 0,
257 0,
258 0,
259 0,
260 0,
261 0,
262 &SystemSid);
263 if (!NT_SUCCESS(Status))
264 {
265 /* Failure */
266 DPRINT1("Couldn't allocate SID\n");
267 NtClose(CsrSectionHandle);
268 return Status;
269 }
270
271 /* Connect to the port */
272 Status = NtSecureConnectPort(&CsrApiPort,
273 &PortName,
274 &SecurityQos,
275 &LpcWrite,
276 SystemSid,
277 &LpcRead,
278 NULL,
279 &ConnectionInfo,
280 &ConnectionInfoLength);
281 NtClose(CsrSectionHandle);
282 if (!NT_SUCCESS(Status))
283 {
284 /* Failure */
285 DPRINT1("Couldn't connect to CSR port\n");
286 return Status;
287 }
288
289 /* Save the delta between the sections, for capture usage later */
290 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
291 (ULONG_PTR)LpcWrite.ViewBase;
292
293 /* Save the Process */
294 CsrProcessId = ConnectionInfo.ProcessId;
295
296 /* Save CSR Section data */
297 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
298 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
299 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
300
301 /* Create the port heap */
302 CsrPortHeap = RtlCreateHeap(0,
303 LpcWrite.ViewBase,
304 LpcWrite.ViewSize,
305 PAGE_SIZE,
306 0,
307 0);
308 if (CsrPortHeap == NULL)
309 {
310 NtClose(CsrApiPort);
311 CsrApiPort = NULL;
312 return STATUS_INSUFFICIENT_RESOURCES;
313 }
314
315 /* Return success */
316 return STATUS_SUCCESS;
317 }
318
319 /*
320 * @implemented
321 */
322 NTSTATUS
323 NTAPI
324 CsrClientConnectToServer(PWSTR ObjectDirectory,
325 ULONG ServerId,
326 PVOID ConnectionInfo,
327 PULONG ConnectionInfoSize,
328 PBOOLEAN ServerToServerCall)
329 {
330 NTSTATUS Status;
331 PIMAGE_NT_HEADERS NtHeader;
332 UNICODE_STRING CsrSrvName;
333 HANDLE hCsrSrv;
334 ANSI_STRING CsrServerRoutineName;
335 PCSR_CAPTURE_BUFFER CaptureBuffer;
336 CSR_API_MESSAGE RosApiMessage;
337 CSR_API_MESSAGE2 ApiMessage;
338 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;
339
340 /* Validate the Connection Info */
341 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
342 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
343 {
344 DPRINT1("Connection info given, but no length\n");
345 return STATUS_INVALID_PARAMETER;
346 }
347
348 /* Check if we're inside a CSR Process */
349 if (InsideCsrProcess)
350 {
351 /* Tell the client that we're already inside CSR */
352 if (ServerToServerCall) *ServerToServerCall = TRUE;
353 return STATUS_SUCCESS;
354 }
355
356 /*
357 * We might be in a CSR Process but not know it, if this is the first call.
358 * So let's find out.
359 */
360 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
361 {
362 /* The image isn't valid */
363 DPRINT1("Invalid image\n");
364 return STATUS_INVALID_IMAGE_FORMAT;
365 }
366 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
367
368 /* Now we can check if we are inside or not */
369 if (InsideCsrProcess && !UsingOldCsr)
370 {
371 /* We're inside, so let's find csrsrv */
372 DbgBreakPoint();
373 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
374 Status = LdrGetDllHandle(NULL,
375 NULL,
376 &CsrSrvName,
377 &hCsrSrv);
378
379 /* Now get the Server to Server routine */
380 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
381 Status = LdrGetProcedureAddress(hCsrSrv,
382 &CsrServerRoutineName,
383 0L,
384 (PVOID*)&CsrServerApiRoutine);
385
386 /* Use the local heap as port heap */
387 CsrPortHeap = RtlGetProcessHeap();
388
389 /* Tell the caller we're inside the server */
390 *ServerToServerCall = InsideCsrProcess;
391 return STATUS_SUCCESS;
392 }
393
394 /* Now check if connection info is given */
395 if (ConnectionInfo)
396 {
397 /* Well, we're defintely in a client now */
398 InsideCsrProcess = FALSE;
399
400 /* Do we have a connection to CSR yet? */
401 if (!CsrApiPort)
402 {
403 /* No, set it up now */
404 if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))
405 {
406 /* Failed */
407 DPRINT1("Failure to connect to CSR\n");
408 return Status;
409 }
410 }
411
412 /* Setup the connect message header */
413 ClientConnect->ServerId = ServerId;
414 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
415
416 /* Setup a buffer for the connection info */
417 CaptureBuffer = CsrAllocateCaptureBuffer(1,
418 ClientConnect->ConnectionInfoSize);
419 if (CaptureBuffer == NULL)
420 {
421 return STATUS_INSUFFICIENT_RESOURCES;
422 }
423
424 /* Allocate a pointer for the connection info*/
425 CsrAllocateMessagePointer(CaptureBuffer,
426 ClientConnect->ConnectionInfoSize,
427 &ClientConnect->ConnectionInfo);
428
429 /* Copy the data into the buffer */
430 RtlMoveMemory(ClientConnect->ConnectionInfo,
431 ConnectionInfo,
432 ClientConnect->ConnectionInfoSize);
433
434 /* Return the allocated length */
435 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
436
437 /* Call CSR */
438 #if 0
439 Status = CsrClientCallServer(&ApiMessage,
440 CaptureBuffer,
441 CSR_MAKE_OPCODE(CsrSrvClientConnect,
442 CSR_SRV_DLL),
443 sizeof(CSR_CLIENT_CONNECT));
444 #endif
445 Status = CsrClientCallServer(&RosApiMessage,
446 NULL,
447 MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),
448 sizeof(CSR_API_MESSAGE));
449 }
450 else
451 {
452 /* No connection info, just return */
453 Status = STATUS_SUCCESS;
454 }
455
456 /* Let the caller know if this was server to server */
457 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
458 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
459 return Status;
460 }
461
462 /* EOF */