4c2f4616a48325a3db978052e492f953d7d65179
[reactos.git] / reactos / lib / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: lib/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *****************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS *******************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22 BOOLEAN UsingOldCsr = TRUE;
23
24 typedef NTSTATUS
25 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
26 IN PPORT_MESSAGE Reply);
27
28 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
29
30 #define UNICODE_PATH_SEP L"\\"
31 #define CSR_PORT_NAME L"ApiPort"
32
33 /* FUNCTIONS *****************************************************************/
34
35 /*
36 * @implemented
37 */
38 HANDLE
39 NTAPI
40 CsrGetProcessId(VOID)
41 {
42 return CsrProcessId;
43 }
44
45 /*
46 * @implemented
47 */
48 NTSTATUS
49 NTAPI
50 CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,
51 PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
52 CSR_API_NUMBER ApiNumber,
53 ULONG RequestLength)
54 {
55 NTSTATUS Status;
56 ULONG PointerCount;
57 PULONG_PTR Pointers;
58 ULONG_PTR CurrentPointer;
59 DPRINT("CsrClientCallServer\n");
60
61 /* Fill out the Port Message Header */
62 ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);
63 ApiMessage->Header.u1.s1.TotalLength = RequestLength;
64
65 /* Fill out the CSR Header */
66 ApiMessage->Type = ApiNumber;
67 //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR
68 ApiMessage->CsrCaptureData = NULL;
69
70 DPRINT("API: %x, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
71 ApiNumber,
72 ApiMessage->Header.u1.s1.DataLength,
73 ApiMessage->Header.u1.s1.TotalLength);
74
75 /* Check if we are already inside a CSR Server */
76 if (!InsideCsrProcess)
77 {
78 /* Check if we got a a Capture Buffer */
79 if (CaptureBuffer)
80 {
81 /* We have to convert from our local view to the remote view */
82 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
83 CsrPortMemoryDelta);
84
85 /* Lock the buffer */
86 CaptureBuffer->BufferEnd = 0;
87
88 /* Get the pointer information */
89 PointerCount = CaptureBuffer->PointerCount;
90 Pointers = CaptureBuffer->PointerArray;
91
92 /* Loop through every pointer and convert it */
93 DPRINT("PointerCount: %lx\n", PointerCount);
94 while (PointerCount--)
95 {
96 /* Get this pointer and check if it's valid */
97 DPRINT("Array Address: %p. This pointer: %p. Data: %p\n",
98 &Pointers, Pointers, *Pointers);
99 if ((CurrentPointer = *Pointers++))
100 {
101 /* Update it */
102 DPRINT("CurrentPointer: %p.\n", *(PULONG_PTR)CurrentPointer);
103 *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
104 Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
105 DPRINT("CurrentPointer: %p.\n", *(PULONG_PTR)CurrentPointer);
106 }
107 }
108 }
109
110 /* Send the LPC Message */
111 Status = NtRequestWaitReplyPort(CsrApiPort,
112 &ApiMessage->Header,
113 &ApiMessage->Header);
114
115 /* Check if we got a a Capture Buffer */
116 if (CaptureBuffer)
117 {
118 /* We have to convert from the remote view to our remote view */
119 DPRINT("Reconverting CaptureBuffer\n");
120 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
121 ApiMessage->CsrCaptureData -
122 CsrPortMemoryDelta);
123
124 /* Get the pointer information */
125 PointerCount = CaptureBuffer->PointerCount;
126 Pointers = CaptureBuffer->PointerArray;
127
128 /* Loop through every pointer and convert it */
129 while (PointerCount--)
130 {
131 /* Get this pointer and check if it's valid */
132 if ((CurrentPointer = *Pointers++))
133 {
134 /* Update it */
135 CurrentPointer += (ULONG_PTR)ApiMessage;
136 Pointers[-1] = CurrentPointer;
137 *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
138 }
139 }
140 }
141
142 /* Check for success */
143 if (!NT_SUCCESS(Status))
144 {
145 /* We failed. Overwrite the return value with the failure */
146 DPRINT1("LPC Failed: %lx\n", Status);
147 ApiMessage->Status = Status;
148 }
149 }
150 else
151 {
152 /* This is a server-to-server call. Save our CID and do a direct call */
153 DbgBreakPoint();
154 ApiMessage->Header.ClientId = NtCurrentTeb()->Cid;
155 Status = CsrServerApiRoutine(&ApiMessage->Header,
156 &ApiMessage->Header);
157
158 /* Check for success */
159 if (!NT_SUCCESS(Status))
160 {
161 /* We failed. Overwrite the return value with the failure */
162 ApiMessage->Status = Status;
163 }
164 }
165
166 /* Return the CSR Result */
167 DPRINT("Got back: %x\n", ApiMessage->Status);
168 return ApiMessage->Status;
169 }
170
171 NTSTATUS
172 NTAPI
173 CsrConnectToServer(IN PWSTR ObjectDirectory)
174 {
175 ULONG PortNameLength;
176 UNICODE_STRING PortName;
177 LARGE_INTEGER CsrSectionViewSize;
178 NTSTATUS Status;
179 HANDLE CsrSectionHandle;
180 PORT_VIEW LpcWrite;
181 REMOTE_PORT_VIEW LpcRead;
182 SECURITY_QUALITY_OF_SERVICE SecurityQos;
183 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
184 PSID SystemSid = NULL;
185 CSR_CONNECTION_INFO ConnectionInfo;
186 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
187
188 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
189
190 /* Binary compatibility with MS KERNEL32 */
191 if (NULL == ObjectDirectory)
192 {
193 ObjectDirectory = L"\\Windows";
194 }
195
196 /* Calculate the total port name size */
197 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
198 sizeof(CSR_PORT_NAME);
199
200 /* Set the port name */
201 PortName.Length = 0;
202 PortName.MaximumLength = PortNameLength;
203
204 /* Allocate a buffer for it */
205 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
206
207 /* Create the name */
208 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
209 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
210 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
211
212 /* Create a section for the port memory */
213 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
214 Status = NtCreateSection(&CsrSectionHandle,
215 SECTION_ALL_ACCESS,
216 NULL,
217 &CsrSectionViewSize,
218 PAGE_READWRITE,
219 SEC_COMMIT,
220 NULL);
221 if (!NT_SUCCESS(Status))
222 {
223 DPRINT1("Failure allocating CSR Section\n");
224 return Status;
225 }
226
227 /* Set up the port view structures to match them with the section */
228 LpcWrite.Length = sizeof(PORT_VIEW);
229 LpcWrite.SectionHandle = CsrSectionHandle;
230 LpcWrite.SectionOffset = 0;
231 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
232 LpcWrite.ViewBase = 0;
233 LpcWrite.ViewRemoteBase = 0;
234 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
235 LpcRead.ViewSize = 0;
236 LpcRead.ViewBase = 0;
237
238 /* Setup the QoS */
239 SecurityQos.ImpersonationLevel = SecurityImpersonation;
240 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
241 SecurityQos.EffectiveOnly = TRUE;
242
243 /* Setup the connection info */
244 ConnectionInfo.Version = 0x10000;
245
246 /* Create a SID for us */
247 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
248 1,
249 SECURITY_LOCAL_SYSTEM_RID,
250 0,
251 0,
252 0,
253 0,
254 0,
255 0,
256 0,
257 &SystemSid);
258
259 /* Connect to the port */
260 Status = NtSecureConnectPort(&CsrApiPort,
261 &PortName,
262 &SecurityQos,
263 &LpcWrite,
264 SystemSid,
265 &LpcRead,
266 NULL,
267 &ConnectionInfo,
268 &ConnectionInfoLength);
269 NtClose(CsrSectionHandle);
270 if (!NT_SUCCESS(Status))
271 {
272 /* Failure */
273 DPRINT1("Couldn't connect to CSR port\n");
274 return Status;
275 }
276
277 /* Save the delta between the sections, for capture usage later */
278 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
279 (ULONG_PTR)LpcWrite.ViewBase;
280
281 /* Save the Process */
282 CsrProcessId = ConnectionInfo.ProcessId;
283
284 /* Save CSR Section data */
285 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
286 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
287 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
288
289 /* Create the port heap */
290 CsrPortHeap = RtlCreateHeap(0,
291 LpcWrite.ViewBase,
292 LpcWrite.ViewSize,
293 PAGE_SIZE,
294 0,
295 0);
296
297 /* Return success */
298 return STATUS_SUCCESS;
299 }
300
301 /*
302 * @implemented
303 */
304 NTSTATUS
305 NTAPI
306 CsrClientConnectToServer(PWSTR ObjectDirectory,
307 ULONG ServerId,
308 PVOID ConnectionInfo,
309 PULONG ConnectionInfoSize,
310 PBOOLEAN ServerToServerCall)
311 {
312 NTSTATUS Status;
313 PIMAGE_NT_HEADERS NtHeader;
314 UNICODE_STRING CsrSrvName;
315 HANDLE hCsrSrv;
316 ANSI_STRING CsrServerRoutineName;
317 PCSR_CAPTURE_BUFFER CaptureBuffer;
318 CSR_API_MESSAGE RosApiMessage;
319 CSR_API_MESSAGE2 ApiMessage;
320 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;
321
322 /* Validate the Connection Info */
323 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
324 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
325 {
326 DPRINT1("Connection info given, but no length\n");
327 return STATUS_INVALID_PARAMETER;
328 }
329
330 /* Check if we're inside a CSR Process */
331 if (InsideCsrProcess)
332 {
333 /* Tell the client that we're already inside CSR */
334 if (ServerToServerCall) *ServerToServerCall = TRUE;
335 return STATUS_SUCCESS;
336 }
337
338 /*
339 * We might be in a CSR Process but not know it, if this is the first call.
340 * So let's find out.
341 */
342 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
343 {
344 /* The image isn't valid */
345 DPRINT1("Invalid image\n");
346 return STATUS_INVALID_IMAGE_FORMAT;
347 }
348 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
349
350 /* Now we can check if we are inside or not */
351 if (InsideCsrProcess && !UsingOldCsr)
352 {
353 /* We're inside, so let's find csrsrv */
354 DbgBreakPoint();
355 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
356 Status = LdrGetDllHandle(NULL,
357 NULL,
358 &CsrSrvName,
359 &hCsrSrv);
360 RtlFreeUnicodeString(&CsrSrvName);
361
362 /* Now get the Server to Server routine */
363 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
364 Status = LdrGetProcedureAddress(hCsrSrv,
365 &CsrServerRoutineName,
366 0L,
367 (PVOID*)&CsrServerApiRoutine);
368
369 /* Use the local heap as port heap */
370 CsrPortHeap = RtlGetProcessHeap();
371
372 /* Tell the caller we're inside the server */
373 *ServerToServerCall = InsideCsrProcess;
374 return STATUS_SUCCESS;
375 }
376
377 /* Now check if connection info is given */
378 if (ConnectionInfo)
379 {
380 /* Well, we're defintely in a client now */
381 InsideCsrProcess = FALSE;
382
383 /* Do we have a connection to CSR yet? */
384 if (!CsrApiPort)
385 {
386 /* No, set it up now */
387 if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))
388 {
389 /* Failed */
390 DPRINT1("Failure to connect to CSR\n");
391 return Status;
392 }
393 }
394
395 /* Setup the connect message header */
396 ClientConnect->ServerId = ServerId;
397 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
398
399 /* Setup a buffer for the connection info */
400 CaptureBuffer = CsrAllocateCaptureBuffer(1,
401 ClientConnect->ConnectionInfoSize);
402
403 /* Allocate a pointer for the connection info*/
404 CsrAllocateMessagePointer(CaptureBuffer,
405 ClientConnect->ConnectionInfoSize,
406 &ClientConnect->ConnectionInfo);
407
408 /* Copy the data into the buffer */
409 RtlMoveMemory(ClientConnect->ConnectionInfo,
410 ConnectionInfo,
411 ClientConnect->ConnectionInfoSize);
412
413 /* Return the allocated length */
414 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
415
416 /* Call CSR */
417 #if 0
418 Status = CsrClientCallServer(&ApiMessage,
419 CaptureBuffer,
420 CSR_MAKE_OPCODE(CsrSrvClientConnect,
421 CSR_SRV_DLL),
422 sizeof(CSR_CLIENT_CONNECT));
423 #endif
424 Status = CsrClientCallServer(&RosApiMessage,
425 NULL,
426 MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),
427 sizeof(CSR_API_MESSAGE));
428 }
429 else
430 {
431 /* No connection info, just return */
432 Status = STATUS_SUCCESS;
433 }
434
435 /* Let the caller know if this was server to server */
436 DPRINT("Status was: %lx. Are we in server: %lx\n", Status, InsideCsrProcess);
437 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
438 return Status;
439 }
440
441 /* EOF */