Check for failed allocations. Spotted by Martin Bealby.
[reactos.git] / reactos / lib / ntdll / csr / connect.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS kernel
4 * FILE: lib/ntdll/csr/connect.c
5 * PURPOSE: Routines for connecting and calling CSR
6 * PROGRAMMER: Alex Ionescu (alex@relsoft.net)
7 */
8
9 /* INCLUDES *****************************************************************/
10
11 #include <ntdll.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* GLOBALS *******************************************************************/
16
17 HANDLE CsrApiPort;
18 HANDLE CsrProcessId;
19 HANDLE CsrPortHeap;
20 ULONG_PTR CsrPortMemoryDelta;
21 BOOLEAN InsideCsrProcess = FALSE;
22 BOOLEAN UsingOldCsr = TRUE;
23
24 typedef NTSTATUS
25 (NTAPI *PCSR_SERVER_API_ROUTINE)(IN PPORT_MESSAGE Request,
26 IN PPORT_MESSAGE Reply);
27
28 PCSR_SERVER_API_ROUTINE CsrServerApiRoutine;
29
30 #define UNICODE_PATH_SEP L"\\"
31 #define CSR_PORT_NAME L"ApiPort"
32
33 /* FUNCTIONS *****************************************************************/
34
35 /*
36 * @implemented
37 */
38 HANDLE
39 NTAPI
40 CsrGetProcessId(VOID)
41 {
42 return CsrProcessId;
43 }
44
45 /*
46 * @implemented
47 */
48 NTSTATUS
49 NTAPI
50 CsrClientCallServer(PCSR_API_MESSAGE ApiMessage,
51 PCSR_CAPTURE_BUFFER CaptureBuffer OPTIONAL,
52 CSR_API_NUMBER ApiNumber,
53 ULONG RequestLength)
54 {
55 NTSTATUS Status;
56 ULONG PointerCount;
57 PULONG_PTR Pointers;
58 ULONG_PTR CurrentPointer;
59 DPRINT("CsrClientCallServer\n");
60
61 /* Fill out the Port Message Header */
62 ApiMessage->Header.u1.s1.DataLength = RequestLength - sizeof(PORT_MESSAGE);
63 ApiMessage->Header.u1.s1.TotalLength = RequestLength;
64
65 /* Fill out the CSR Header */
66 ApiMessage->Type = ApiNumber;
67 //ApiMessage->Opcode = ApiNumber; <- Activate with new CSR
68 ApiMessage->CsrCaptureData = NULL;
69
70 DPRINT("API: %lx, u1.s1.DataLength: %x, u1.s1.TotalLength: %x\n",
71 ApiNumber,
72 ApiMessage->Header.u1.s1.DataLength,
73 ApiMessage->Header.u1.s1.TotalLength);
74
75 /* Check if we are already inside a CSR Server */
76 if (!InsideCsrProcess)
77 {
78 /* Check if we got a a Capture Buffer */
79 if (CaptureBuffer)
80 {
81 /* We have to convert from our local view to the remote view */
82 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)CaptureBuffer +
83 CsrPortMemoryDelta);
84
85 /* Lock the buffer */
86 CaptureBuffer->BufferEnd = 0;
87
88 /* Get the pointer information */
89 PointerCount = CaptureBuffer->PointerCount;
90 Pointers = CaptureBuffer->PointerArray;
91
92 /* Loop through every pointer and convert it */
93 DPRINT("PointerCount: %lx\n", PointerCount);
94 while (PointerCount--)
95 {
96 /* Get this pointer and check if it's valid */
97 DPRINT("Array Address: %p. This pointer: %p. Data: %lx\n",
98 &Pointers, Pointers, *Pointers);
99 if ((CurrentPointer = *Pointers++))
100 {
101 /* Update it */
102 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
103 *(PULONG_PTR)CurrentPointer += CsrPortMemoryDelta;
104 Pointers[-1] = CurrentPointer - (ULONG_PTR)ApiMessage;
105 DPRINT("CurrentPointer: %lx.\n", *(PULONG_PTR)CurrentPointer);
106 }
107 }
108 }
109
110 /* Send the LPC Message */
111 Status = NtRequestWaitReplyPort(CsrApiPort,
112 &ApiMessage->Header,
113 &ApiMessage->Header);
114
115 /* Check if we got a a Capture Buffer */
116 if (CaptureBuffer)
117 {
118 /* We have to convert from the remote view to our remote view */
119 DPRINT("Reconverting CaptureBuffer\n");
120 ApiMessage->CsrCaptureData = (PVOID)((ULONG_PTR)
121 ApiMessage->CsrCaptureData -
122 CsrPortMemoryDelta);
123
124 /* Get the pointer information */
125 PointerCount = CaptureBuffer->PointerCount;
126 Pointers = CaptureBuffer->PointerArray;
127
128 /* Loop through every pointer and convert it */
129 while (PointerCount--)
130 {
131 /* Get this pointer and check if it's valid */
132 if ((CurrentPointer = *Pointers++))
133 {
134 /* Update it */
135 CurrentPointer += (ULONG_PTR)ApiMessage;
136 Pointers[-1] = CurrentPointer;
137 *(PULONG_PTR)CurrentPointer -= CsrPortMemoryDelta;
138 }
139 }
140 }
141
142 /* Check for success */
143 if (!NT_SUCCESS(Status))
144 {
145 /* We failed. Overwrite the return value with the failure */
146 DPRINT1("LPC Failed: %lx\n", Status);
147 ApiMessage->Status = Status;
148 }
149 }
150 else
151 {
152 /* This is a server-to-server call. Save our CID and do a direct call */
153 DbgBreakPoint();
154 ApiMessage->Header.ClientId = NtCurrentTeb()->Cid;
155 Status = CsrServerApiRoutine(&ApiMessage->Header,
156 &ApiMessage->Header);
157
158 /* Check for success */
159 if (!NT_SUCCESS(Status))
160 {
161 /* We failed. Overwrite the return value with the failure */
162 ApiMessage->Status = Status;
163 }
164 }
165
166 /* Return the CSR Result */
167 DPRINT("Got back: 0x%lx\n", ApiMessage->Status);
168 return ApiMessage->Status;
169 }
170
171 NTSTATUS
172 NTAPI
173 CsrConnectToServer(IN PWSTR ObjectDirectory)
174 {
175 ULONG PortNameLength;
176 UNICODE_STRING PortName;
177 LARGE_INTEGER CsrSectionViewSize;
178 NTSTATUS Status;
179 HANDLE CsrSectionHandle;
180 PORT_VIEW LpcWrite;
181 REMOTE_PORT_VIEW LpcRead;
182 SECURITY_QUALITY_OF_SERVICE SecurityQos;
183 SID_IDENTIFIER_AUTHORITY NtSidAuthority = {SECURITY_NT_AUTHORITY};
184 PSID SystemSid = NULL;
185 CSR_CONNECTION_INFO ConnectionInfo;
186 ULONG ConnectionInfoLength = sizeof(CSR_CONNECTION_INFO);
187
188 DPRINT("%s(%S)\n", __FUNCTION__, ObjectDirectory);
189
190 /* Binary compatibility with MS KERNEL32 */
191 if (NULL == ObjectDirectory)
192 {
193 ObjectDirectory = L"\\Windows";
194 }
195
196 /* Calculate the total port name size */
197 PortNameLength = ((wcslen(ObjectDirectory) + 1) * sizeof(WCHAR)) +
198 sizeof(CSR_PORT_NAME);
199
200 /* Set the port name */
201 PortName.Length = 0;
202 PortName.MaximumLength = PortNameLength;
203
204 /* Allocate a buffer for it */
205 PortName.Buffer = RtlAllocateHeap(RtlGetProcessHeap(), 0, PortNameLength);
206 if (PortName.Buffer == NULL)
207 {
208 return STATUS_INSUFFICIENT_RESOURCES;
209 }
210
211 /* Create the name */
212 RtlAppendUnicodeToString(&PortName, ObjectDirectory );
213 RtlAppendUnicodeToString(&PortName, UNICODE_PATH_SEP);
214 RtlAppendUnicodeToString(&PortName, CSR_PORT_NAME);
215
216 /* Create a section for the port memory */
217 CsrSectionViewSize.QuadPart = CSR_CSRSS_SECTION_SIZE;
218 Status = NtCreateSection(&CsrSectionHandle,
219 SECTION_ALL_ACCESS,
220 NULL,
221 &CsrSectionViewSize,
222 PAGE_READWRITE,
223 SEC_COMMIT,
224 NULL);
225 if (!NT_SUCCESS(Status))
226 {
227 DPRINT1("Failure allocating CSR Section\n");
228 return Status;
229 }
230
231 /* Set up the port view structures to match them with the section */
232 LpcWrite.Length = sizeof(PORT_VIEW);
233 LpcWrite.SectionHandle = CsrSectionHandle;
234 LpcWrite.SectionOffset = 0;
235 LpcWrite.ViewSize = CsrSectionViewSize.u.LowPart;
236 LpcWrite.ViewBase = 0;
237 LpcWrite.ViewRemoteBase = 0;
238 LpcRead.Length = sizeof(REMOTE_PORT_VIEW);
239 LpcRead.ViewSize = 0;
240 LpcRead.ViewBase = 0;
241
242 /* Setup the QoS */
243 SecurityQos.ImpersonationLevel = SecurityImpersonation;
244 SecurityQos.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
245 SecurityQos.EffectiveOnly = TRUE;
246
247 /* Setup the connection info */
248 ConnectionInfo.Version = 0x10000;
249
250 /* Create a SID for us */
251 Status = RtlAllocateAndInitializeSid(&NtSidAuthority,
252 1,
253 SECURITY_LOCAL_SYSTEM_RID,
254 0,
255 0,
256 0,
257 0,
258 0,
259 0,
260 0,
261 &SystemSid);
262 if (!NT_SUCCESS(Status))
263 {
264 /* Failure */
265 DPRINT1("Couldn't allocate SID\n");
266 NtClose(CsrSectionHandle);
267 return Status;
268 }
269
270 /* Connect to the port */
271 Status = NtSecureConnectPort(&CsrApiPort,
272 &PortName,
273 &SecurityQos,
274 &LpcWrite,
275 SystemSid,
276 &LpcRead,
277 NULL,
278 &ConnectionInfo,
279 &ConnectionInfoLength);
280 NtClose(CsrSectionHandle);
281 if (!NT_SUCCESS(Status))
282 {
283 /* Failure */
284 DPRINT1("Couldn't connect to CSR port\n");
285 return Status;
286 }
287
288 /* Save the delta between the sections, for capture usage later */
289 CsrPortMemoryDelta = (ULONG_PTR)LpcWrite.ViewRemoteBase -
290 (ULONG_PTR)LpcWrite.ViewBase;
291
292 /* Save the Process */
293 CsrProcessId = ConnectionInfo.ProcessId;
294
295 /* Save CSR Section data */
296 NtCurrentPeb()->ReadOnlySharedMemoryBase = ConnectionInfo.SharedSectionBase;
297 NtCurrentPeb()->ReadOnlySharedMemoryHeap = ConnectionInfo.SharedSectionHeap;
298 NtCurrentPeb()->ReadOnlyStaticServerData = ConnectionInfo.SharedSectionData;
299
300 /* Create the port heap */
301 CsrPortHeap = RtlCreateHeap(0,
302 LpcWrite.ViewBase,
303 LpcWrite.ViewSize,
304 PAGE_SIZE,
305 0,
306 0);
307 if (CsrPortHeap == NULL)
308 {
309 NtClose(CsrApiPort);
310 CsrApiPort = NULL;
311 return STATUS_INSUFFICIENT_RESOURCES;
312 }
313
314 /* Return success */
315 return STATUS_SUCCESS;
316 }
317
318 /*
319 * @implemented
320 */
321 NTSTATUS
322 NTAPI
323 CsrClientConnectToServer(PWSTR ObjectDirectory,
324 ULONG ServerId,
325 PVOID ConnectionInfo,
326 PULONG ConnectionInfoSize,
327 PBOOLEAN ServerToServerCall)
328 {
329 NTSTATUS Status;
330 PIMAGE_NT_HEADERS NtHeader;
331 UNICODE_STRING CsrSrvName;
332 HANDLE hCsrSrv;
333 ANSI_STRING CsrServerRoutineName;
334 PCSR_CAPTURE_BUFFER CaptureBuffer;
335 CSR_API_MESSAGE RosApiMessage;
336 CSR_API_MESSAGE2 ApiMessage;
337 PCSR_CLIENT_CONNECT ClientConnect = &ApiMessage.ClientConnect;
338
339 /* Validate the Connection Info */
340 DPRINT("CsrClientConnectToServer: %lx %p\n", ServerId, ConnectionInfo);
341 if (ConnectionInfo && (!ConnectionInfoSize || !*ConnectionInfoSize))
342 {
343 DPRINT1("Connection info given, but no length\n");
344 return STATUS_INVALID_PARAMETER;
345 }
346
347 /* Check if we're inside a CSR Process */
348 if (InsideCsrProcess)
349 {
350 /* Tell the client that we're already inside CSR */
351 if (ServerToServerCall) *ServerToServerCall = TRUE;
352 return STATUS_SUCCESS;
353 }
354
355 /*
356 * We might be in a CSR Process but not know it, if this is the first call.
357 * So let's find out.
358 */
359 if (!(NtHeader = RtlImageNtHeader(NtCurrentPeb()->ImageBaseAddress)))
360 {
361 /* The image isn't valid */
362 DPRINT1("Invalid image\n");
363 return STATUS_INVALID_IMAGE_FORMAT;
364 }
365 InsideCsrProcess = (NtHeader->OptionalHeader.Subsystem == IMAGE_SUBSYSTEM_NATIVE);
366
367 /* Now we can check if we are inside or not */
368 if (InsideCsrProcess && !UsingOldCsr)
369 {
370 /* We're inside, so let's find csrsrv */
371 DbgBreakPoint();
372 RtlInitUnicodeString(&CsrSrvName, L"csrsrv");
373 Status = LdrGetDllHandle(NULL,
374 NULL,
375 &CsrSrvName,
376 &hCsrSrv);
377 RtlFreeUnicodeString(&CsrSrvName);
378
379 /* Now get the Server to Server routine */
380 RtlInitAnsiString(&CsrServerRoutineName, "CsrCallServerFromServer");
381 Status = LdrGetProcedureAddress(hCsrSrv,
382 &CsrServerRoutineName,
383 0L,
384 (PVOID*)&CsrServerApiRoutine);
385
386 /* Use the local heap as port heap */
387 CsrPortHeap = RtlGetProcessHeap();
388
389 /* Tell the caller we're inside the server */
390 *ServerToServerCall = InsideCsrProcess;
391 return STATUS_SUCCESS;
392 }
393
394 /* Now check if connection info is given */
395 if (ConnectionInfo)
396 {
397 /* Well, we're defintely in a client now */
398 InsideCsrProcess = FALSE;
399
400 /* Do we have a connection to CSR yet? */
401 if (!CsrApiPort)
402 {
403 /* No, set it up now */
404 if (!NT_SUCCESS(Status = CsrConnectToServer(ObjectDirectory)))
405 {
406 /* Failed */
407 DPRINT1("Failure to connect to CSR\n");
408 return Status;
409 }
410 }
411
412 /* Setup the connect message header */
413 ClientConnect->ServerId = ServerId;
414 ClientConnect->ConnectionInfoSize = *ConnectionInfoSize;
415
416 /* Setup a buffer for the connection info */
417 CaptureBuffer = CsrAllocateCaptureBuffer(1,
418 ClientConnect->ConnectionInfoSize);
419 if (CaptureBuffer == NULL)
420 {
421 return STATUS_INSUFFICIENT_RESOURCES;
422 }
423
424 /* Allocate a pointer for the connection info*/
425 CsrAllocateMessagePointer(CaptureBuffer,
426 ClientConnect->ConnectionInfoSize,
427 &ClientConnect->ConnectionInfo);
428
429 /* Copy the data into the buffer */
430 RtlMoveMemory(ClientConnect->ConnectionInfo,
431 ConnectionInfo,
432 ClientConnect->ConnectionInfoSize);
433
434 /* Return the allocated length */
435 *ConnectionInfoSize = ClientConnect->ConnectionInfoSize;
436
437 /* Call CSR */
438 #if 0
439 Status = CsrClientCallServer(&ApiMessage,
440 CaptureBuffer,
441 CSR_MAKE_OPCODE(CsrSrvClientConnect,
442 CSR_SRV_DLL),
443 sizeof(CSR_CLIENT_CONNECT));
444 #endif
445 Status = CsrClientCallServer(&RosApiMessage,
446 NULL,
447 MAKE_CSR_API(CONNECT_PROCESS, CSR_NATIVE),
448 sizeof(CSR_API_MESSAGE));
449 }
450 else
451 {
452 /* No connection info, just return */
453 Status = STATUS_SUCCESS;
454 }
455
456 /* Let the caller know if this was server to server */
457 DPRINT("Status was: 0x%lx. Are we in server: 0x%x\n", Status, InsideCsrProcess);
458 if (ServerToServerCall) *ServerToServerCall = InsideCsrProcess;
459 return Status;
460 }
461
462 /* EOF */