Sync with trunk r64222.
[reactos.git] / ntoskrnl / lpc / connect.c
1 /*
2 * PROJECT: ReactOS Kernel
3 * LICENSE: GPL - See COPYING in the top level directory
4 * FILE: ntoskrnl/lpc/connect.c
5 * PURPOSE: Local Procedure Call: Connection Management
6 * PROGRAMMERS: Alex Ionescu (alex.ionescu@reactos.org)
7 */
8
9 /* INCLUDES ******************************************************************/
10
11 #include <ntoskrnl.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* PRIVATE FUNCTIONS *********************************************************/
16
17 PVOID
18 NTAPI
19 LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
20 IN OUT PLPCP_CONNECTION_MESSAGE *ConnectMessage,
21 IN PETHREAD CurrentThread)
22 {
23 PVOID SectionToMap;
24 PLPCP_MESSAGE ReplyMessage;
25
26 /* Acquire the LPC lock */
27 KeAcquireGuardedMutex(&LpcpLock);
28
29 /* Check if the reply chain is not empty */
30 if (!IsListEmpty(&CurrentThread->LpcReplyChain))
31 {
32 /* Remove this entry and re-initialize it */
33 RemoveEntryList(&CurrentThread->LpcReplyChain);
34 InitializeListHead(&CurrentThread->LpcReplyChain);
35 }
36
37 /* Check if there's a reply message */
38 ReplyMessage = LpcpGetMessageFromThread(CurrentThread);
39 if (ReplyMessage)
40 {
41 /* Get the message */
42 *Message = ReplyMessage;
43
44 /* Check if it's got messages */
45 if (!IsListEmpty(&ReplyMessage->Entry))
46 {
47 /* Clear the list */
48 RemoveEntryList(&ReplyMessage->Entry);
49 InitializeListHead(&ReplyMessage->Entry);
50 }
51
52 /* Clear message data */
53 CurrentThread->LpcReceivedMessageId = 0;
54 CurrentThread->LpcReplyMessage = NULL;
55
56 /* Get the connection message and clear the section */
57 *ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(ReplyMessage + 1);
58 SectionToMap = (*ConnectMessage)->SectionToMap;
59 (*ConnectMessage)->SectionToMap = NULL;
60 }
61 else
62 {
63 /* No message to return */
64 *Message = NULL;
65 SectionToMap = NULL;
66 }
67
68 /* Release the lock and return the section */
69 KeReleaseGuardedMutex(&LpcpLock);
70 return SectionToMap;
71 }
72
73 /* PUBLIC FUNCTIONS **********************************************************/
74
75 /*
76 * @implemented
77 */
78 NTSTATUS
79 NTAPI
80 NtSecureConnectPort(OUT PHANDLE PortHandle,
81 IN PUNICODE_STRING PortName,
82 IN PSECURITY_QUALITY_OF_SERVICE Qos,
83 IN OUT PPORT_VIEW ClientView OPTIONAL,
84 IN PSID ServerSid OPTIONAL,
85 IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL,
86 OUT PULONG MaxMessageLength OPTIONAL,
87 IN OUT PVOID ConnectionInformation OPTIONAL,
88 IN OUT PULONG ConnectionInformationLength OPTIONAL)
89 {
90 ULONG ConnectionInfoLength = 0;
91 PLPCP_PORT_OBJECT Port, ClientPort;
92 KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
93 NTSTATUS Status = STATUS_SUCCESS;
94 HANDLE Handle;
95 PVOID SectionToMap;
96 PLPCP_MESSAGE Message;
97 PLPCP_CONNECTION_MESSAGE ConnectMessage;
98 PETHREAD Thread = PsGetCurrentThread();
99 ULONG PortMessageLength;
100 LARGE_INTEGER SectionOffset;
101 PTOKEN Token;
102 PTOKEN_USER TokenUserInfo;
103 PAGED_CODE();
104 LPCTRACE(LPC_CONNECT_DEBUG,
105 "Name: %wZ. Qos: %p. Views: %p/%p. Sid: %p\n",
106 PortName,
107 Qos,
108 ClientView,
109 ServerView,
110 ServerSid);
111
112 /* Validate client view */
113 if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
114 {
115 /* Fail */
116 return STATUS_INVALID_PARAMETER;
117 }
118
119 /* Validate server view */
120 if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
121 {
122 /* Fail */
123 return STATUS_INVALID_PARAMETER;
124 }
125
126 /* Check if caller sent connection information length */
127 if (ConnectionInformationLength)
128 {
129 /* Retrieve the input length */
130 ConnectionInfoLength = *ConnectionInformationLength;
131 }
132
133 /* Get the port */
134 Status = ObReferenceObjectByName(PortName,
135 0,
136 NULL,
137 PORT_CONNECT,
138 LpcPortObjectType,
139 PreviousMode,
140 NULL,
141 (PVOID *)&Port);
142 if (!NT_SUCCESS(Status)) return Status;
143
144 /* This has to be a connection port */
145 if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
146 {
147 /* It isn't, so fail */
148 ObDereferenceObject(Port);
149 return STATUS_INVALID_PORT_HANDLE;
150 }
151
152 /* Check if we have a SID */
153 if (ServerSid)
154 {
155 /* Make sure that we have a server */
156 if (Port->ServerProcess)
157 {
158 /* Get its token and query user information */
159 Token = PsReferencePrimaryToken(Port->ServerProcess);
160 //Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
161 // FIXME: Need SeQueryInformationToken
162 Status = STATUS_SUCCESS;
163 TokenUserInfo = ExAllocatePool(PagedPool, sizeof(TOKEN_USER));
164 TokenUserInfo->User.Sid = ServerSid;
165 PsDereferencePrimaryToken(Token);
166
167 /* Check for success */
168 if (NT_SUCCESS(Status))
169 {
170 /* Compare the SIDs */
171 if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid))
172 {
173 /* Fail */
174 Status = STATUS_SERVER_SID_MISMATCH;
175 }
176
177 /* Free token information */
178 ExFreePool(TokenUserInfo);
179 }
180 }
181 else
182 {
183 /* Invalid SID */
184 Status = STATUS_SERVER_SID_MISMATCH;
185 }
186
187 /* Check if SID failed */
188 if (!NT_SUCCESS(Status))
189 {
190 /* Quit */
191 ObDereferenceObject(Port);
192 return Status;
193 }
194 }
195
196 /* Create the client port */
197 Status = ObCreateObject(PreviousMode,
198 LpcPortObjectType,
199 NULL,
200 PreviousMode,
201 NULL,
202 sizeof(LPCP_PORT_OBJECT),
203 0,
204 0,
205 (PVOID *)&ClientPort);
206 if (!NT_SUCCESS(Status))
207 {
208 /* Failed, dereference the server port and return */
209 ObDereferenceObject(Port);
210 return Status;
211 }
212
213 /* Setup the client port */
214 RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
215 ClientPort->Flags = LPCP_CLIENT_PORT;
216 ClientPort->ConnectionPort = Port;
217 ClientPort->MaxMessageLength = Port->MaxMessageLength;
218 ClientPort->SecurityQos = *Qos;
219 InitializeListHead(&ClientPort->LpcReplyChainHead);
220 InitializeListHead(&ClientPort->LpcDataInfoChainHead);
221
222 /* Check if we have dynamic security */
223 if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
224 {
225 /* Remember that */
226 ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
227 }
228 else
229 {
230 /* Create our own client security */
231 Status = SeCreateClientSecurity(Thread,
232 Qos,
233 FALSE,
234 &ClientPort->StaticSecurity);
235 if (!NT_SUCCESS(Status))
236 {
237 /* Security failed, dereference and return */
238 ObDereferenceObject(ClientPort);
239 return Status;
240 }
241 }
242
243 /* Initialize the port queue */
244 Status = LpcpInitializePortQueue(ClientPort);
245 if (!NT_SUCCESS(Status))
246 {
247 /* Failed */
248 ObDereferenceObject(ClientPort);
249 return Status;
250 }
251
252 /* Check if we have a client view */
253 if (ClientView)
254 {
255 /* Get the section handle */
256 Status = ObReferenceObjectByHandle(ClientView->SectionHandle,
257 SECTION_MAP_READ |
258 SECTION_MAP_WRITE,
259 MmSectionObjectType,
260 PreviousMode,
261 (PVOID*)&SectionToMap,
262 NULL);
263 if (!NT_SUCCESS(Status))
264 {
265 /* Fail */
266 ObDereferenceObject(Port);
267 return Status;
268 }
269
270 /* Set the section offset */
271 SectionOffset.QuadPart = ClientView->SectionOffset;
272
273 /* Map it */
274 Status = MmMapViewOfSection(SectionToMap,
275 PsGetCurrentProcess(),
276 &ClientPort->ClientSectionBase,
277 0,
278 0,
279 &SectionOffset,
280 &ClientView->ViewSize,
281 ViewUnmap,
282 0,
283 PAGE_READWRITE);
284
285 /* Update the offset */
286 ClientView->SectionOffset = SectionOffset.LowPart;
287
288 /* Check for failure */
289 if (!NT_SUCCESS(Status))
290 {
291 /* Fail */
292 ObDereferenceObject(SectionToMap);
293 ObDereferenceObject(Port);
294 return Status;
295 }
296
297 /* Update the base */
298 ClientView->ViewBase = ClientPort->ClientSectionBase;
299
300 /* Reference and remember the process */
301 ClientPort->MappingProcess = PsGetCurrentProcess();
302 ObReferenceObject(ClientPort->MappingProcess);
303 }
304 else
305 {
306 /* No section */
307 SectionToMap = NULL;
308 }
309
310 /* Normalize connection information */
311 if (ConnectionInfoLength > Port->MaxConnectionInfoLength)
312 {
313 /* Use the port's maximum allowed value */
314 ConnectionInfoLength = Port->MaxConnectionInfoLength;
315 }
316
317 /* Allocate a message from the port zone */
318 Message = LpcpAllocateFromPortZone();
319 if (!Message)
320 {
321 /* Fail if we couldn't allocate a message */
322 if (SectionToMap) ObDereferenceObject(SectionToMap);
323 ObDereferenceObject(ClientPort);
324 return STATUS_NO_MEMORY;
325 }
326
327 /* Set pointer to the connection message and fill in the CID */
328 ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
329 Message->Request.ClientId = Thread->Cid;
330
331 /* Check if we have a client view */
332 if (ClientView)
333 {
334 /* Set the view size */
335 Message->Request.ClientViewSize = ClientView->ViewSize;
336
337 /* Copy the client view and clear the server view */
338 RtlCopyMemory(&ConnectMessage->ClientView,
339 ClientView,
340 sizeof(PORT_VIEW));
341 RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
342 }
343 else
344 {
345 /* Set the size to 0 and clear the connect message */
346 Message->Request.ClientViewSize = 0;
347 RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE));
348 }
349
350 /* Set the section and client port. Port is NULL for now */
351 ConnectMessage->ClientPort = NULL;
352 ConnectMessage->SectionToMap = SectionToMap;
353
354 /* Set the data for the connection request message */
355 Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
356 sizeof(LPCP_CONNECTION_MESSAGE);
357 Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
358 Message->Request.u1.s1.DataLength;
359 Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST;
360
361 /* Check if we have connection information */
362 if (ConnectionInformation)
363 {
364 /* Copy it in */
365 RtlCopyMemory(ConnectMessage + 1,
366 ConnectionInformation,
367 ConnectionInfoLength);
368 }
369
370 /* Acquire the port lock */
371 KeAcquireGuardedMutex(&LpcpLock);
372
373 /* Check if someone already deleted the port name */
374 if (Port->Flags & LPCP_NAME_DELETED)
375 {
376 /* Fail the request */
377 Status = STATUS_OBJECT_NAME_NOT_FOUND;
378 }
379 else
380 {
381 /* Associate no thread yet */
382 Message->RepliedToThread = NULL;
383
384 /* Generate the Message ID and set it */
385 Message->Request.MessageId = LpcpNextMessageId++;
386 if (!LpcpNextMessageId) LpcpNextMessageId = 1;
387 Thread->LpcReplyMessageId = Message->Request.MessageId;
388
389 /* Insert the message into the queue and thread chain */
390 InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
391 InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
392 Thread->LpcReplyMessage = Message;
393
394 /* Now we can finally reference the client port and link it*/
395 ObReferenceObject(ClientPort);
396 ConnectMessage->ClientPort = ClientPort;
397
398 /* Enter a critical region */
399 KeEnterCriticalRegion();
400 }
401
402 /* Add another reference to the port */
403 ObReferenceObject(Port);
404
405 /* Release the lock */
406 KeReleaseGuardedMutex(&LpcpLock);
407
408 /* Check for success */
409 if (NT_SUCCESS(Status))
410 {
411 LPCTRACE(LPC_CONNECT_DEBUG,
412 "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
413 Message,
414 ConnectMessage,
415 Port,
416 ClientPort,
417 Status);
418
419 /* If this is a waitable port, set the event */
420 if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
421 1,
422 FALSE);
423
424 /* Release the queue semaphore and leave the critical region */
425 LpcpCompleteWait(Port->MsgQueue.Semaphore);
426 KeLeaveCriticalRegion();
427
428 /* Now wait for a reply */
429 LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
430 }
431
432 /* Check for failure */
433 if (!NT_SUCCESS(Status)) goto Cleanup;
434
435 /* Free the connection message */
436 SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
437
438 /* Check if we got a message back */
439 if (Message)
440 {
441 /* Check for new return length */
442 if ((Message->Request.u1.s1.DataLength -
443 sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength)
444 {
445 /* Set new normalized connection length */
446 ConnectionInfoLength = Message->Request.u1.s1.DataLength -
447 sizeof(LPCP_CONNECTION_MESSAGE);
448 }
449
450 /* Check if we had connection information */
451 if (ConnectionInformation)
452 {
453 /* Check if we had a length pointer */
454 if (ConnectionInformationLength)
455 {
456 /* Return the length */
457 *ConnectionInformationLength = ConnectionInfoLength;
458 }
459
460 /* Return the connection information */
461 RtlCopyMemory(ConnectionInformation,
462 ConnectMessage + 1,
463 ConnectionInfoLength );
464 }
465
466 /* Make sure we had a connected port */
467 if (ClientPort->ConnectedPort)
468 {
469 /* Get the message length before the port might get killed */
470 PortMessageLength = Port->MaxMessageLength;
471
472 /* Insert the client port */
473 Status = ObInsertObject(ClientPort,
474 NULL,
475 PORT_ALL_ACCESS,
476 0,
477 (PVOID *)NULL,
478 &Handle);
479 if (NT_SUCCESS(Status))
480 {
481 /* Return the handle */
482 *PortHandle = Handle;
483 LPCTRACE(LPC_CONNECT_DEBUG,
484 "Handle: %p. Length: %lx\n",
485 Handle,
486 PortMessageLength);
487
488 /* Check if maximum length was requested */
489 if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
490
491 /* Check if we had a client view */
492 if (ClientView)
493 {
494 /* Copy it back */
495 RtlCopyMemory(ClientView,
496 &ConnectMessage->ClientView,
497 sizeof(PORT_VIEW));
498 }
499
500 /* Check if we had a server view */
501 if (ServerView)
502 {
503 /* Copy it back */
504 RtlCopyMemory(ServerView,
505 &ConnectMessage->ServerView,
506 sizeof(REMOTE_PORT_VIEW));
507 }
508 }
509 }
510 else
511 {
512 /* No connection port, we failed */
513 if (SectionToMap) ObDereferenceObject(SectionToMap);
514
515 /* Acquire the lock */
516 KeAcquireGuardedMutex(&LpcpLock);
517
518 /* Check if it's because the name got deleted */
519 if (!(ClientPort->ConnectionPort) ||
520 (Port->Flags & LPCP_NAME_DELETED))
521 {
522 /* Set the correct status */
523 Status = STATUS_OBJECT_NAME_NOT_FOUND;
524 }
525 else
526 {
527 /* Otherwise, the caller refused us */
528 Status = STATUS_PORT_CONNECTION_REFUSED;
529 }
530
531 /* Release the lock */
532 KeReleaseGuardedMutex(&LpcpLock);
533
534 /* Kill the port */
535 ObDereferenceObject(ClientPort);
536 }
537
538 /* Free the message */
539 LpcpFreeToPortZone(Message, 0);
540 }
541 else
542 {
543 /* No reply message, fail */
544 if (SectionToMap) ObDereferenceObject(SectionToMap);
545 ObDereferenceObject(ClientPort);
546 Status = STATUS_PORT_CONNECTION_REFUSED;
547 }
548
549 /* Return status */
550 ObDereferenceObject(Port);
551 return Status;
552
553 Cleanup:
554 /* We failed, free the message */
555 SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
556
557 /* Check if the semaphore got signaled */
558 if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
559 {
560 /* Wait on it */
561 KeWaitForSingleObject(&Thread->LpcReplySemaphore,
562 WrExecutive,
563 KernelMode,
564 FALSE,
565 NULL);
566 }
567
568 /* Check if we had a message and free it */
569 if (Message) LpcpFreeToPortZone(Message, 0);
570
571 /* Dereference other objects */
572 if (SectionToMap) ObDereferenceObject(SectionToMap);
573 ObDereferenceObject(ClientPort);
574
575 /* Return status */
576 ObDereferenceObject(Port);
577 return Status;
578 }
579
580 /*
581 * @implemented
582 */
583 NTSTATUS
584 NTAPI
585 NtConnectPort(OUT PHANDLE PortHandle,
586 IN PUNICODE_STRING PortName,
587 IN PSECURITY_QUALITY_OF_SERVICE Qos,
588 IN PPORT_VIEW ClientView,
589 IN PREMOTE_PORT_VIEW ServerView,
590 OUT PULONG MaxMessageLength,
591 IN PVOID ConnectionInformation,
592 OUT PULONG ConnectionInformationLength)
593 {
594 /* Call the newer API */
595 return NtSecureConnectPort(PortHandle,
596 PortName,
597 Qos,
598 ClientView,
599 NULL,
600 ServerView,
601 MaxMessageLength,
602 ConnectionInformation,
603 ConnectionInformationLength);
604 }
605
606 /* EOF */