Create the AHCI branch for Aman's work
[reactos.git] / ntoskrnl / lpc / connect.c
1 /*
2 * PROJECT: ReactOS Kernel
3 * LICENSE: GPL - See COPYING in the top level directory
4 * FILE: ntoskrnl/lpc/connect.c
5 * PURPOSE: Local Procedure Call: Connection Management
6 * PROGRAMMERS: Alex Ionescu (alex.ionescu@reactos.org)
7 */
8
9 /* INCLUDES ******************************************************************/
10
11 #include <ntoskrnl.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* PRIVATE FUNCTIONS *********************************************************/
16
17 PVOID
18 NTAPI
19 LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
20 IN OUT PLPCP_CONNECTION_MESSAGE *ConnectMessage,
21 IN PETHREAD CurrentThread)
22 {
23 PVOID SectionToMap;
24 PLPCP_MESSAGE ReplyMessage;
25
26 /* Acquire the LPC lock */
27 KeAcquireGuardedMutex(&LpcpLock);
28
29 /* Check if the reply chain is not empty */
30 if (!IsListEmpty(&CurrentThread->LpcReplyChain))
31 {
32 /* Remove this entry and re-initialize it */
33 RemoveEntryList(&CurrentThread->LpcReplyChain);
34 InitializeListHead(&CurrentThread->LpcReplyChain);
35 }
36
37 /* Check if there's a reply message */
38 ReplyMessage = LpcpGetMessageFromThread(CurrentThread);
39 if (ReplyMessage)
40 {
41 /* Get the message */
42 *Message = ReplyMessage;
43
44 /* Check if it's got messages */
45 if (!IsListEmpty(&ReplyMessage->Entry))
46 {
47 /* Clear the list */
48 RemoveEntryList(&ReplyMessage->Entry);
49 InitializeListHead(&ReplyMessage->Entry);
50 }
51
52 /* Clear message data */
53 CurrentThread->LpcReceivedMessageId = 0;
54 CurrentThread->LpcReplyMessage = NULL;
55
56 /* Get the connection message and clear the section */
57 *ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(ReplyMessage + 1);
58 SectionToMap = (*ConnectMessage)->SectionToMap;
59 (*ConnectMessage)->SectionToMap = NULL;
60 }
61 else
62 {
63 /* No message to return */
64 *Message = NULL;
65 SectionToMap = NULL;
66 }
67
68 /* Release the lock and return the section */
69 KeReleaseGuardedMutex(&LpcpLock);
70 return SectionToMap;
71 }
72
73 /* PUBLIC FUNCTIONS **********************************************************/
74
75 /*
76 * @implemented
77 */
78 NTSTATUS
79 NTAPI
80 NtSecureConnectPort(OUT PHANDLE PortHandle,
81 IN PUNICODE_STRING PortName,
82 IN PSECURITY_QUALITY_OF_SERVICE Qos,
83 IN OUT PPORT_VIEW ClientView OPTIONAL,
84 IN PSID ServerSid OPTIONAL,
85 IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL,
86 OUT PULONG MaxMessageLength OPTIONAL,
87 IN OUT PVOID ConnectionInformation OPTIONAL,
88 IN OUT PULONG ConnectionInformationLength OPTIONAL)
89 {
90 ULONG ConnectionInfoLength = 0;
91 PLPCP_PORT_OBJECT Port, ClientPort;
92 KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
93 NTSTATUS Status = STATUS_SUCCESS;
94 HANDLE Handle;
95 PVOID SectionToMap;
96 PLPCP_MESSAGE Message;
97 PLPCP_CONNECTION_MESSAGE ConnectMessage;
98 PETHREAD Thread = PsGetCurrentThread();
99 ULONG PortMessageLength;
100 LARGE_INTEGER SectionOffset;
101 PTOKEN Token;
102 PTOKEN_USER TokenUserInfo;
103 PAGED_CODE();
104 LPCTRACE(LPC_CONNECT_DEBUG,
105 "Name: %wZ. Qos: %p. Views: %p/%p. Sid: %p\n",
106 PortName,
107 Qos,
108 ClientView,
109 ServerView,
110 ServerSid);
111
112 /* Validate client view */
113 if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
114 {
115 /* Fail */
116 return STATUS_INVALID_PARAMETER;
117 }
118
119 /* Validate server view */
120 if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
121 {
122 /* Fail */
123 return STATUS_INVALID_PARAMETER;
124 }
125
126 /* Check if caller sent connection information length */
127 if (ConnectionInformationLength)
128 {
129 /* Retrieve the input length */
130 ConnectionInfoLength = *ConnectionInformationLength;
131 }
132
133 /* Get the port */
134 Status = ObReferenceObjectByName(PortName,
135 0,
136 NULL,
137 PORT_CONNECT,
138 LpcPortObjectType,
139 PreviousMode,
140 NULL,
141 (PVOID *)&Port);
142 if (!NT_SUCCESS(Status))
143 {
144 DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status);
145 return Status;
146 }
147
148 /* This has to be a connection port */
149 if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
150 {
151 /* It isn't, so fail */
152 ObDereferenceObject(Port);
153 return STATUS_INVALID_PORT_HANDLE;
154 }
155
156 /* Check if we have a SID */
157 if (ServerSid)
158 {
159 /* Make sure that we have a server */
160 if (Port->ServerProcess)
161 {
162 /* Get its token and query user information */
163 Token = PsReferencePrimaryToken(Port->ServerProcess);
164 //Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
165 // FIXME: Need SeQueryInformationToken
166 Status = STATUS_SUCCESS;
167 TokenUserInfo = ExAllocatePoolWithTag(PagedPool, sizeof(TOKEN_USER), TAG_SE);
168 TokenUserInfo->User.Sid = ServerSid;
169 PsDereferencePrimaryToken(Token);
170
171 /* Check for success */
172 if (NT_SUCCESS(Status))
173 {
174 /* Compare the SIDs */
175 if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid))
176 {
177 /* Fail */
178 Status = STATUS_SERVER_SID_MISMATCH;
179 }
180
181 /* Free token information */
182 ExFreePoolWithTag(TokenUserInfo, TAG_SE);
183 }
184 }
185 else
186 {
187 /* Invalid SID */
188 Status = STATUS_SERVER_SID_MISMATCH;
189 }
190
191 /* Check if SID failed */
192 if (!NT_SUCCESS(Status))
193 {
194 /* Quit */
195 ObDereferenceObject(Port);
196 return Status;
197 }
198 }
199
200 /* Create the client port */
201 Status = ObCreateObject(PreviousMode,
202 LpcPortObjectType,
203 NULL,
204 PreviousMode,
205 NULL,
206 sizeof(LPCP_PORT_OBJECT),
207 0,
208 0,
209 (PVOID *)&ClientPort);
210 if (!NT_SUCCESS(Status))
211 {
212 /* Failed, dereference the server port and return */
213 ObDereferenceObject(Port);
214 return Status;
215 }
216
217 /* Setup the client port */
218 RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
219 ClientPort->Flags = LPCP_CLIENT_PORT;
220 ClientPort->ConnectionPort = Port;
221 ClientPort->MaxMessageLength = Port->MaxMessageLength;
222 ClientPort->SecurityQos = *Qos;
223 InitializeListHead(&ClientPort->LpcReplyChainHead);
224 InitializeListHead(&ClientPort->LpcDataInfoChainHead);
225
226 /* Check if we have dynamic security */
227 if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
228 {
229 /* Remember that */
230 ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
231 }
232 else
233 {
234 /* Create our own client security */
235 Status = SeCreateClientSecurity(Thread,
236 Qos,
237 FALSE,
238 &ClientPort->StaticSecurity);
239 if (!NT_SUCCESS(Status))
240 {
241 /* Security failed, dereference and return */
242 ObDereferenceObject(ClientPort);
243 return Status;
244 }
245 }
246
247 /* Initialize the port queue */
248 Status = LpcpInitializePortQueue(ClientPort);
249 if (!NT_SUCCESS(Status))
250 {
251 /* Failed */
252 ObDereferenceObject(ClientPort);
253 return Status;
254 }
255
256 /* Check if we have a client view */
257 if (ClientView)
258 {
259 /* Get the section handle */
260 Status = ObReferenceObjectByHandle(ClientView->SectionHandle,
261 SECTION_MAP_READ |
262 SECTION_MAP_WRITE,
263 MmSectionObjectType,
264 PreviousMode,
265 (PVOID*)&SectionToMap,
266 NULL);
267 if (!NT_SUCCESS(Status))
268 {
269 /* Fail */
270 ObDereferenceObject(Port);
271 return Status;
272 }
273
274 /* Set the section offset */
275 SectionOffset.QuadPart = ClientView->SectionOffset;
276
277 /* Map it */
278 Status = MmMapViewOfSection(SectionToMap,
279 PsGetCurrentProcess(),
280 &ClientPort->ClientSectionBase,
281 0,
282 0,
283 &SectionOffset,
284 &ClientView->ViewSize,
285 ViewUnmap,
286 0,
287 PAGE_READWRITE);
288
289 /* Update the offset */
290 ClientView->SectionOffset = SectionOffset.LowPart;
291
292 /* Check for failure */
293 if (!NT_SUCCESS(Status))
294 {
295 /* Fail */
296 ObDereferenceObject(SectionToMap);
297 ObDereferenceObject(Port);
298 return Status;
299 }
300
301 /* Update the base */
302 ClientView->ViewBase = ClientPort->ClientSectionBase;
303
304 /* Reference and remember the process */
305 ClientPort->MappingProcess = PsGetCurrentProcess();
306 ObReferenceObject(ClientPort->MappingProcess);
307 }
308 else
309 {
310 /* No section */
311 SectionToMap = NULL;
312 }
313
314 /* Normalize connection information */
315 if (ConnectionInfoLength > Port->MaxConnectionInfoLength)
316 {
317 /* Use the port's maximum allowed value */
318 ConnectionInfoLength = Port->MaxConnectionInfoLength;
319 }
320
321 /* Allocate a message from the port zone */
322 Message = LpcpAllocateFromPortZone();
323 if (!Message)
324 {
325 /* Fail if we couldn't allocate a message */
326 if (SectionToMap) ObDereferenceObject(SectionToMap);
327 ObDereferenceObject(ClientPort);
328 return STATUS_NO_MEMORY;
329 }
330
331 /* Set pointer to the connection message and fill in the CID */
332 ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
333 Message->Request.ClientId = Thread->Cid;
334
335 /* Check if we have a client view */
336 if (ClientView)
337 {
338 /* Set the view size */
339 Message->Request.ClientViewSize = ClientView->ViewSize;
340
341 /* Copy the client view and clear the server view */
342 RtlCopyMemory(&ConnectMessage->ClientView,
343 ClientView,
344 sizeof(PORT_VIEW));
345 RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
346 }
347 else
348 {
349 /* Set the size to 0 and clear the connect message */
350 Message->Request.ClientViewSize = 0;
351 RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE));
352 }
353
354 /* Set the section and client port. Port is NULL for now */
355 ConnectMessage->ClientPort = NULL;
356 ConnectMessage->SectionToMap = SectionToMap;
357
358 /* Set the data for the connection request message */
359 Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength +
360 sizeof(LPCP_CONNECTION_MESSAGE);
361 Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
362 Message->Request.u1.s1.DataLength;
363 Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST;
364
365 /* Check if we have connection information */
366 if (ConnectionInformation)
367 {
368 /* Copy it in */
369 RtlCopyMemory(ConnectMessage + 1,
370 ConnectionInformation,
371 ConnectionInfoLength);
372 }
373
374 /* Acquire the port lock */
375 KeAcquireGuardedMutex(&LpcpLock);
376
377 /* Check if someone already deleted the port name */
378 if (Port->Flags & LPCP_NAME_DELETED)
379 {
380 /* Fail the request */
381 Status = STATUS_OBJECT_NAME_NOT_FOUND;
382 }
383 else
384 {
385 /* Associate no thread yet */
386 Message->RepliedToThread = NULL;
387
388 /* Generate the Message ID and set it */
389 Message->Request.MessageId = LpcpNextMessageId++;
390 if (!LpcpNextMessageId) LpcpNextMessageId = 1;
391 Thread->LpcReplyMessageId = Message->Request.MessageId;
392
393 /* Insert the message into the queue and thread chain */
394 InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
395 InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
396 Thread->LpcReplyMessage = Message;
397
398 /* Now we can finally reference the client port and link it*/
399 ObReferenceObject(ClientPort);
400 ConnectMessage->ClientPort = ClientPort;
401
402 /* Enter a critical region */
403 KeEnterCriticalRegion();
404 }
405
406 /* Add another reference to the port */
407 ObReferenceObject(Port);
408
409 /* Release the lock */
410 KeReleaseGuardedMutex(&LpcpLock);
411
412 /* Check for success */
413 if (NT_SUCCESS(Status))
414 {
415 LPCTRACE(LPC_CONNECT_DEBUG,
416 "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
417 Message,
418 ConnectMessage,
419 Port,
420 ClientPort,
421 Status);
422
423 /* If this is a waitable port, set the event */
424 if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
425 1,
426 FALSE);
427
428 /* Release the queue semaphore and leave the critical region */
429 LpcpCompleteWait(Port->MsgQueue.Semaphore);
430 KeLeaveCriticalRegion();
431
432 /* Now wait for a reply */
433 LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
434 }
435
436 /* Check for failure */
437 if (!NT_SUCCESS(Status)) goto Cleanup;
438
439 /* Free the connection message */
440 SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
441
442 /* Check if we got a message back */
443 if (Message)
444 {
445 /* Check for new return length */
446 if ((Message->Request.u1.s1.DataLength -
447 sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength)
448 {
449 /* Set new normalized connection length */
450 ConnectionInfoLength = Message->Request.u1.s1.DataLength -
451 sizeof(LPCP_CONNECTION_MESSAGE);
452 }
453
454 /* Check if we had connection information */
455 if (ConnectionInformation)
456 {
457 /* Check if we had a length pointer */
458 if (ConnectionInformationLength)
459 {
460 /* Return the length */
461 *ConnectionInformationLength = ConnectionInfoLength;
462 }
463
464 /* Return the connection information */
465 RtlCopyMemory(ConnectionInformation,
466 ConnectMessage + 1,
467 ConnectionInfoLength );
468 }
469
470 /* Make sure we had a connected port */
471 if (ClientPort->ConnectedPort)
472 {
473 /* Get the message length before the port might get killed */
474 PortMessageLength = Port->MaxMessageLength;
475
476 /* Insert the client port */
477 Status = ObInsertObject(ClientPort,
478 NULL,
479 PORT_ALL_ACCESS,
480 0,
481 (PVOID *)NULL,
482 &Handle);
483 if (NT_SUCCESS(Status))
484 {
485 /* Return the handle */
486 *PortHandle = Handle;
487 LPCTRACE(LPC_CONNECT_DEBUG,
488 "Handle: %p. Length: %lx\n",
489 Handle,
490 PortMessageLength);
491
492 /* Check if maximum length was requested */
493 if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
494
495 /* Check if we had a client view */
496 if (ClientView)
497 {
498 /* Copy it back */
499 RtlCopyMemory(ClientView,
500 &ConnectMessage->ClientView,
501 sizeof(PORT_VIEW));
502 }
503
504 /* Check if we had a server view */
505 if (ServerView)
506 {
507 /* Copy it back */
508 RtlCopyMemory(ServerView,
509 &ConnectMessage->ServerView,
510 sizeof(REMOTE_PORT_VIEW));
511 }
512 }
513 }
514 else
515 {
516 /* No connection port, we failed */
517 if (SectionToMap) ObDereferenceObject(SectionToMap);
518
519 /* Acquire the lock */
520 KeAcquireGuardedMutex(&LpcpLock);
521
522 /* Check if it's because the name got deleted */
523 if (!(ClientPort->ConnectionPort) ||
524 (Port->Flags & LPCP_NAME_DELETED))
525 {
526 /* Set the correct status */
527 Status = STATUS_OBJECT_NAME_NOT_FOUND;
528 }
529 else
530 {
531 /* Otherwise, the caller refused us */
532 Status = STATUS_PORT_CONNECTION_REFUSED;
533 }
534
535 /* Release the lock */
536 KeReleaseGuardedMutex(&LpcpLock);
537
538 /* Kill the port */
539 ObDereferenceObject(ClientPort);
540 }
541
542 /* Free the message */
543 LpcpFreeToPortZone(Message, 0);
544 }
545 else
546 {
547 /* No reply message, fail */
548 if (SectionToMap) ObDereferenceObject(SectionToMap);
549 ObDereferenceObject(ClientPort);
550 Status = STATUS_PORT_CONNECTION_REFUSED;
551 }
552
553 /* Return status */
554 ObDereferenceObject(Port);
555 return Status;
556
557 Cleanup:
558 /* We failed, free the message */
559 SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
560
561 /* Check if the semaphore got signaled */
562 if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
563 {
564 /* Wait on it */
565 KeWaitForSingleObject(&Thread->LpcReplySemaphore,
566 WrExecutive,
567 KernelMode,
568 FALSE,
569 NULL);
570 }
571
572 /* Check if we had a message and free it */
573 if (Message) LpcpFreeToPortZone(Message, 0);
574
575 /* Dereference other objects */
576 if (SectionToMap) ObDereferenceObject(SectionToMap);
577 ObDereferenceObject(ClientPort);
578
579 /* Return status */
580 ObDereferenceObject(Port);
581 return Status;
582 }
583
584 /*
585 * @implemented
586 */
587 NTSTATUS
588 NTAPI
589 NtConnectPort(OUT PHANDLE PortHandle,
590 IN PUNICODE_STRING PortName,
591 IN PSECURITY_QUALITY_OF_SERVICE Qos,
592 IN PPORT_VIEW ClientView,
593 IN PREMOTE_PORT_VIEW ServerView,
594 OUT PULONG MaxMessageLength,
595 IN PVOID ConnectionInformation,
596 OUT PULONG ConnectionInformationLength)
597 {
598 /* Call the newer API */
599 return NtSecureConnectPort(PortHandle,
600 PortName,
601 Qos,
602 ClientView,
603 NULL,
604 ServerView,
605 MaxMessageLength,
606 ConnectionInformation,
607 ConnectionInformationLength);
608 }
609
610 /* EOF */