adding winetest from the vendor drop for usp10.dll
[reactos.git] / reactos / ntoskrnl / lpc / ntlpc / connect.c
1 /*
2 * PROJECT: ReactOS Kernel
3 * LICENSE: GPL - See COPYING in the top level directory
4 * FILE: ntoskrnl/lpc/ntlpc/connect.c
5 * PURPOSE: Local Procedure Call: Connection Management
6 * PROGRAMMERS: Alex Ionescu (alex.ionescu@reactos.org)
7 */
8
9 /* INCLUDES ******************************************************************/
10
11 #include <ntoskrnl.h>
12 #include "lpc.h"
13 #define NDEBUG
14 #include <internal/debug.h>
15
16 /* PRIVATE FUNCTIONS *********************************************************/
17
18 PVOID
19 NTAPI
20 LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
21 IN OUT PLPCP_CONNECTION_MESSAGE *ConnectMessage,
22 IN PETHREAD CurrentThread)
23 {
24 PVOID SectionToMap;
25
26 /* Acquire the LPC lock */
27 KeAcquireGuardedMutex(&LpcpLock);
28
29 /* Check if the reply chain is not empty */
30 if (!IsListEmpty(&CurrentThread->LpcReplyChain))
31 {
32 /* Remove this entry and re-initialize it */
33 RemoveEntryList(&CurrentThread->LpcReplyChain);
34 InitializeListHead(&CurrentThread->LpcReplyChain);
35 }
36
37 /* Check if there's a reply message */
38 if (CurrentThread->LpcReplyMessage)
39 {
40 /* Get the message */
41 *Message = CurrentThread->LpcReplyMessage;
42
43 /* Clear message data */
44 CurrentThread->LpcReceivedMessageId = 0;
45 CurrentThread->LpcReplyMessage = NULL;
46
47 /* Get the connection message and clear the section */
48 *ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(*Message + 1);
49 SectionToMap = (*ConnectMessage)->SectionToMap;
50 (*ConnectMessage)->SectionToMap = NULL;
51 }
52 else
53 {
54 /* No message to return */
55 *Message = NULL;
56 SectionToMap = NULL;
57 }
58
59 /* Release the lock and return the section */
60 KeReleaseGuardedMutex(&LpcpLock);
61 return SectionToMap;
62 }
63
64 /* PUBLIC FUNCTIONS **********************************************************/
65
66 /*
67 * @implemented
68 */
69 NTSTATUS
70 NTAPI
71 NtSecureConnectPort(OUT PHANDLE PortHandle,
72 IN PUNICODE_STRING PortName,
73 IN PSECURITY_QUALITY_OF_SERVICE Qos,
74 IN OUT PPORT_VIEW ClientView OPTIONAL,
75 IN PSID ServerSid OPTIONAL,
76 IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL,
77 OUT PULONG MaxMessageLength OPTIONAL,
78 IN OUT PVOID ConnectionInformation OPTIONAL,
79 IN OUT PULONG ConnectionInformationLength OPTIONAL)
80 {
81 ULONG ConnectionInfoLength = 0;
82 PLPCP_PORT_OBJECT Port, ClientPort;
83 KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
84 NTSTATUS Status = STATUS_SUCCESS;
85 HANDLE Handle;
86 PVOID SectionToMap;
87 PLPCP_MESSAGE Message;
88 PLPCP_CONNECTION_MESSAGE ConnectMessage;
89 PETHREAD Thread = PsGetCurrentThread();
90 ULONG PortMessageLength;
91 LARGE_INTEGER SectionOffset;
92 PTOKEN Token;
93 PTOKEN_USER TokenUserInfo;
94 PAGED_CODE();
95 LPCTRACE(LPC_CONNECT_DEBUG,
96 "Name: %wZ. Qos: %p. Views: %p/%p. Sid: %p\n",
97 PortName,
98 Qos,
99 ClientView,
100 ServerView,
101 ServerSid);
102
103 /* Validate client view */
104 if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
105 {
106 /* Fail */
107 return STATUS_INVALID_PARAMETER;
108 }
109
110 /* Validate server view */
111 if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
112 {
113 /* Fail */
114 return STATUS_INVALID_PARAMETER;
115 }
116
117 /* Check if caller sent connection information length */
118 if (ConnectionInformationLength)
119 {
120 /* Retrieve the input length */
121 ConnectionInfoLength = *ConnectionInformationLength;
122 }
123
124 /* Get the port */
125 Status = ObReferenceObjectByName(PortName,
126 0,
127 NULL,
128 PORT_ALL_ACCESS,
129 LpcPortObjectType,
130 PreviousMode,
131 NULL,
132 (PVOID *)&Port);
133 if (!NT_SUCCESS(Status)) return Status;
134
135 /* This has to be a connection port */
136 if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
137 {
138 /* It isn't, so fail */
139 ObDereferenceObject(Port);
140 return STATUS_INVALID_PORT_HANDLE;
141 }
142
143 /* Check if we have a SID */
144 if (ServerSid)
145 {
146 /* Make sure that we have a server */
147 if (Port->ServerProcess)
148 {
149 /* Get its token and query user information */
150 Token = PsReferencePrimaryToken(Port->ServerProcess);
151 //Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
152 // FIXME: Need SeQueryInformationToken
153 Status = STATUS_SUCCESS;
154 TokenUserInfo = ExAllocatePool(PagedPool, sizeof(TOKEN_USER));
155 TokenUserInfo->User.Sid = ServerSid;
156 PsDereferencePrimaryToken(Token);
157
158 /* Check for success */
159 if (NT_SUCCESS(Status))
160 {
161 /* Compare the SIDs */
162 if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid))
163 {
164 /* Fail */
165 Status = STATUS_SERVER_SID_MISMATCH;
166 }
167
168 /* Free token information */
169 ExFreePool(TokenUserInfo);
170 }
171 }
172 else
173 {
174 /* Invalid SID */
175 Status = STATUS_SERVER_SID_MISMATCH;
176 }
177
178 /* Check if SID failed */
179 if (!NT_SUCCESS(Status))
180 {
181 /* Quit */
182 ObDereferenceObject(Port);
183 return Status;
184 }
185 }
186
187 /* Create the client port */
188 Status = ObCreateObject(PreviousMode,
189 LpcPortObjectType,
190 NULL,
191 PreviousMode,
192 NULL,
193 sizeof(LPCP_PORT_OBJECT),
194 0,
195 0,
196 (PVOID *)&ClientPort);
197 if (!NT_SUCCESS(Status))
198 {
199 /* Failed, dereference the server port and return */
200 ObDereferenceObject(Port);
201 return Status;
202 }
203
204 /* Setup the client port */
205 RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
206 ClientPort->Flags = LPCP_CLIENT_PORT;
207 ClientPort->ConnectionPort = Port;
208 ClientPort->MaxMessageLength = Port->MaxMessageLength;
209 ClientPort->SecurityQos = *Qos;
210 InitializeListHead(&ClientPort->LpcReplyChainHead);
211 InitializeListHead(&ClientPort->LpcDataInfoChainHead);
212
213 /* Check if we have dynamic security */
214 if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
215 {
216 /* Remember that */
217 ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
218 }
219 else
220 {
221 /* Create our own client security */
222 Status = SeCreateClientSecurity(Thread,
223 Qos,
224 FALSE,
225 &ClientPort->StaticSecurity);
226 if (!NT_SUCCESS(Status))
227 {
228 /* Security failed, dereference and return */
229 ObDereferenceObject(ClientPort);
230 return Status;
231 }
232 }
233
234 /* Initialize the port queue */
235 Status = LpcpInitializePortQueue(ClientPort);
236 if (!NT_SUCCESS(Status))
237 {
238 /* Failed */
239 ObDereferenceObject(ClientPort);
240 return Status;
241 }
242
243 /* Check if we have a client view */
244 if (ClientView)
245 {
246 /* Get the section handle */
247 Status = ObReferenceObjectByHandle(ClientView->SectionHandle,
248 SECTION_MAP_READ |
249 SECTION_MAP_WRITE,
250 MmSectionObjectType,
251 PreviousMode,
252 (PVOID*)&SectionToMap,
253 NULL);
254 if (!NT_SUCCESS(Status))
255 {
256 /* Fail */
257 ObDereferenceObject(Port);
258 return Status;
259 }
260
261 /* Set the section offset */
262 SectionOffset.QuadPart = ClientView->SectionOffset;
263
264 /* Map it */
265 Status = MmMapViewOfSection(SectionToMap,
266 PsGetCurrentProcess(),
267 &Port->ClientSectionBase,
268 0,
269 0,
270 &SectionOffset,
271 &ClientView->ViewSize,
272 ViewUnmap,
273 0,
274 PAGE_READWRITE);
275
276 /* Update the offset */
277 ClientView->SectionOffset = SectionOffset.LowPart;
278
279 /* Check for failure */
280 if (!NT_SUCCESS(Status))
281 {
282 /* Fail */
283 ObDereferenceObject(SectionToMap);
284 ObDereferenceObject(Port);
285 return Status;
286 }
287
288 /* Update the base */
289 ClientView->ViewBase = Port->ClientSectionBase;
290 }
291 else
292 {
293 /* No section */
294 SectionToMap = NULL;
295 }
296
297 /* Normalize connection information */
298 if (ConnectionInfoLength > Port->MaxConnectionInfoLength)
299 {
300 /* Use the port's maximum allowed value */
301 ConnectionInfoLength = Port->MaxConnectionInfoLength;
302 }
303
304 /* Allocate a message from the port zone */
305 Message = LpcpAllocateFromPortZone();
306 if (!Message)
307 {
308 /* Fail if we couldn't allocate a message */
309 if (SectionToMap) ObDereferenceObject(SectionToMap);
310 ObDereferenceObject(ClientPort);
311 return STATUS_NO_MEMORY;
312 }
313
314 /* Set pointer to the connection message and fill in the CID */
315 ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
316 Message->Request.ClientId = Thread->Cid;
317
318 /* Check if we have a client view */
319 if (ClientView)
320 {
321 /* Set the view size */
322 Message->Request.ClientViewSize = ClientView->ViewSize;
323
324 /* Copy the client view and clear the server view */
325 RtlMoveMemory(&ConnectMessage->ClientView,
326 ClientView,
327 sizeof(PORT_VIEW));
328 RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
329 }
330 else
331 {
332 /* Set the size to 0 and clear the connect message */
333 Message->Request.ClientViewSize = 0;
334 RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE));
335 }
336
337 /* Set the section and client port. Port is NULL for now */
338 ConnectMessage->ClientPort = NULL;
339 ConnectMessage->SectionToMap = SectionToMap;
340
341 /* Set the data for the connection request message */
342 Message->Request.u1.s1.DataLength = sizeof(LPCP_CONNECTION_MESSAGE) +
343 ConnectionInfoLength;
344 Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
345 Message->Request.u1.s1.DataLength;
346 Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST;
347
348 /* Check if we have connection information */
349 if (ConnectionInformation)
350 {
351 /* Copy it in */
352 RtlMoveMemory(ConnectMessage + 1,
353 ConnectionInformation,
354 ConnectionInfoLength);
355 }
356
357 /* Acquire the port lock */
358 KeAcquireGuardedMutex(&LpcpLock);
359
360 /* Check if someone already deleted the port name */
361 if (Port->Flags & LPCP_NAME_DELETED)
362 {
363 /* Fail the request */
364 KeReleaseGuardedMutex(&LpcpLock);
365 Status = STATUS_OBJECT_NAME_NOT_FOUND;
366 goto Cleanup;
367 }
368
369 /* Associate no thread yet */
370 Message->RepliedToThread = NULL;
371
372 /* Generate the Message ID and set it */
373 Message->Request.MessageId = LpcpNextMessageId++;
374 if (!LpcpNextMessageId) LpcpNextMessageId = 1;
375 Thread->LpcReplyMessageId = Message->Request.MessageId;
376
377 /* Insert the message into the queue and thread chain */
378 InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
379 InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
380 Thread->LpcReplyMessage = Message;
381
382 /* Now we can finally reference the client port and link it*/
383 ObReferenceObject(ClientPort);
384 ConnectMessage->ClientPort = ClientPort;
385
386 /* Release the lock */
387 KeReleaseGuardedMutex(&LpcpLock);
388 LPCTRACE(LPC_CONNECT_DEBUG,
389 "Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
390 Message,
391 ConnectMessage,
392 Port,
393 ClientPort,
394 Status);
395
396 /* If this is a waitable port, set the event */
397 if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
398 1,
399 FALSE);
400
401 /* Release the queue semaphore */
402 LpcpCompleteWait(Port->MsgQueue.Semaphore);
403
404 /* Now wait for a reply */
405 LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
406
407 /* Check if our wait ended in success */
408 if (Status != STATUS_SUCCESS) goto Cleanup;
409
410 /* Free the connection message */
411 SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
412
413 /* Check if we got a message back */
414 if (Message)
415 {
416 /* Check for new return length */
417 if ((Message->Request.u1.s1.DataLength -
418 sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength)
419 {
420 /* Set new normalized connection length */
421 ConnectionInfoLength = Message->Request.u1.s1.DataLength -
422 sizeof(LPCP_CONNECTION_MESSAGE);
423 }
424
425 /* Check if we had connection information */
426 if (ConnectionInformation)
427 {
428 /* Check if we had a length pointer */
429 if (ConnectionInformationLength)
430 {
431 /* Return the length */
432 *ConnectionInformationLength = ConnectionInfoLength;
433 }
434
435 /* Return the connection information */
436 RtlMoveMemory(ConnectionInformation,
437 ConnectMessage + 1,
438 ConnectionInfoLength );
439 }
440
441 /* Make sure we had a connected port */
442 if (ClientPort->ConnectedPort)
443 {
444 /* Get the message length before the port might get killed */
445 PortMessageLength = Port->MaxMessageLength;
446
447 /* Insert the client port */
448 Status = ObInsertObject(ClientPort,
449 NULL,
450 PORT_ALL_ACCESS,
451 0,
452 (PVOID *)NULL,
453 &Handle);
454 if (NT_SUCCESS(Status))
455 {
456 /* Return the handle */
457 *PortHandle = Handle;
458 LPCTRACE(LPC_CONNECT_DEBUG,
459 "Handle: %lx. Length: %lx\n",
460 Handle,
461 PortMessageLength);
462
463 /* Check if maximum length was requested */
464 if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
465
466 /* Check if we had a client view */
467 if (ClientView)
468 {
469 /* Copy it back */
470 RtlMoveMemory(ClientView,
471 &ConnectMessage->ClientView,
472 sizeof(PORT_VIEW));
473 }
474
475 /* Check if we had a server view */
476 if (ServerView)
477 {
478 /* Copy it back */
479 RtlMoveMemory(ServerView,
480 &ConnectMessage->ServerView,
481 sizeof(REMOTE_PORT_VIEW));
482 }
483 }
484 }
485 else
486 {
487 /* No connection port, we failed */
488 if (SectionToMap) ObDereferenceObject(SectionToMap);
489
490 /* Check if it's because the name got deleted */
491 if (Port->Flags & LPCP_NAME_DELETED)
492 {
493 /* Set the correct status */
494 Status = STATUS_OBJECT_NAME_NOT_FOUND;
495 }
496 else
497 {
498 /* Otherwise, the caller refused us */
499 Status = STATUS_PORT_CONNECTION_REFUSED;
500 }
501
502 /* Kill the port */
503 ObDereferenceObject(ClientPort);
504 }
505
506 /* Free the message */
507 LpcpFreeToPortZone(Message, FALSE);
508 return Status;
509 }
510
511 /* No reply message, fail */
512 if (SectionToMap) ObDereferenceObject(SectionToMap);
513 ObDereferenceObject(ClientPort);
514 return STATUS_PORT_CONNECTION_REFUSED;
515
516 Cleanup:
517 /* We failed, free the message */
518 SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
519
520 /* Check if the semaphore got signaled */
521 if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
522 {
523 /* Wait on it */
524 KeWaitForSingleObject(&Thread->LpcReplySemaphore,
525 KernelMode,
526 Executive,
527 FALSE,
528 NULL);
529 }
530
531 /* Check if we had a message and free it */
532 if (Message) LpcpFreeToPortZone(Message, FALSE);
533
534 /* Dereference other objects */
535 if (SectionToMap) ObDereferenceObject(SectionToMap);
536 ObDereferenceObject(ClientPort);
537
538 /* Return status */
539 return Status;
540 }
541
542 /*
543 * @implemented
544 */
545 NTSTATUS
546 NTAPI
547 NtConnectPort(OUT PHANDLE PortHandle,
548 IN PUNICODE_STRING PortName,
549 IN PSECURITY_QUALITY_OF_SERVICE Qos,
550 IN PPORT_VIEW ClientView,
551 IN PREMOTE_PORT_VIEW ServerView,
552 OUT PULONG MaxMessageLength,
553 IN PVOID ConnectionInformation,
554 OUT PULONG ConnectionInformationLength)
555 {
556 /* Call the newer API */
557 return NtSecureConnectPort(PortHandle,
558 PortName,
559 Qos,
560 ClientView,
561 NULL,
562 ServerView,
563 MaxMessageLength,
564 ConnectionInformation,
565 ConnectionInformationLength);
566 }
567
568 /* EOF */