- Fix multiple LPC race conditions.
[reactos.git] / reactos / ntoskrnl / lpc / send.c
1 /*
2 * PROJECT: ReactOS Kernel
3 * LICENSE: GPL - See COPYING in the top level directory
4 * FILE: ntoskrnl/lpc/send.c
5 * PURPOSE: Local Procedure Call: Sending (Requests)
6 * PROGRAMMERS: Alex Ionescu (alex.ionescu@reactos.org)
7 */
8
9 /* INCLUDES ******************************************************************/
10
11 #include <ntoskrnl.h>
12 #define NDEBUG
13 #include <debug.h>
14
15 /* PUBLIC FUNCTIONS **********************************************************/
16
17 /*
18 * @implemented
19 */
20 NTSTATUS
21 NTAPI
22 LpcRequestPort(IN PVOID PortObject,
23 IN PPORT_MESSAGE LpcMessage)
24 {
25 PLPCP_PORT_OBJECT Port = PortObject, QueuePort, ConnectionPort = NULL;
26 ULONG MessageType;
27 PLPCP_MESSAGE Message;
28 KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
29 PAGED_CODE();
30 LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", Port, LpcMessage);
31
32 /* Check if this is a non-datagram message */
33 if (LpcMessage->u2.s2.Type)
34 {
35 /* Get the message type */
36 MessageType = LpcpGetMessageType(LpcMessage);
37
38 /* Validate it */
39 if ((MessageType < LPC_DATAGRAM) || (MessageType > LPC_CLIENT_DIED))
40 {
41 /* Fail */
42 return STATUS_INVALID_PARAMETER;
43 }
44
45 /* Mark this as a kernel-mode message only if we really came from it */
46 if ((PreviousMode == KernelMode) &&
47 (LpcMessage->u2.s2.Type & LPC_KERNELMODE_MESSAGE))
48 {
49 /* We did, this is a kernel mode message */
50 MessageType |= LPC_KERNELMODE_MESSAGE;
51 }
52 }
53 else
54 {
55 /* This is a datagram */
56 MessageType = LPC_DATAGRAM;
57 }
58
59 /* Can't have data information on this type of call */
60 if (LpcMessage->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
61
62 /* Validate message sizes */
63 if ((LpcMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
64 (LpcMessage->u1.s1.TotalLength <= LpcMessage->u1.s1.DataLength))
65 {
66 /* Fail */
67 return STATUS_PORT_MESSAGE_TOO_LONG;
68 }
69
70 /* Allocate a new message */
71 Message = LpcpAllocateFromPortZone();
72 if (!Message) return STATUS_NO_MEMORY;
73
74 /* Clear the context */
75 Message->RepliedToThread = NULL;
76 Message->PortContext = NULL;
77
78 /* Copy the message */
79 LpcpMoveMessage(&Message->Request,
80 LpcMessage,
81 LpcMessage + 1,
82 MessageType,
83 &PsGetCurrentThread()->Cid);
84
85 /* Acquire the LPC lock */
86 KeAcquireGuardedMutex(&LpcpLock);
87
88 /* Check if this is anything but a connection port */
89 if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
90 {
91 /* The queue port is the connected port */
92 QueuePort = Port->ConnectedPort;
93 if (QueuePort)
94 {
95 /* Check if this is a client port */
96 if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
97 {
98 /* Then copy the context */
99 Message->PortContext = QueuePort->PortContext;
100 ConnectionPort = QueuePort = Port->ConnectionPort;
101 if (!ConnectionPort)
102 {
103 /* Fail */
104 LpcpFreeToPortZone(Message, 3);
105 return STATUS_PORT_DISCONNECTED;
106 }
107 }
108 else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
109 {
110 /* Any other kind of port, use the connection port */
111 ConnectionPort = QueuePort = Port->ConnectionPort;
112 if (!ConnectionPort)
113 {
114 /* Fail */
115 LpcpFreeToPortZone(Message, 3);
116 return STATUS_PORT_DISCONNECTED;
117 }
118 }
119
120 /* If we have a connection port, reference it */
121 if (ConnectionPort) ObReferenceObject(ConnectionPort);
122 }
123 }
124 else
125 {
126 /* For connection ports, use the port itself */
127 QueuePort = PortObject;
128 }
129
130 /* Make sure we have a port */
131 if (QueuePort)
132 {
133 /* Generate the Message ID and set it */
134 Message->Request.MessageId = LpcpNextMessageId++;
135 if (!LpcpNextMessageId) LpcpNextMessageId = 1;
136 Message->Request.CallbackId = 0;
137
138 /* No Message ID for the thread */
139 PsGetCurrentThread()->LpcReplyMessageId = 0;
140
141 /* Insert the message in our chain */
142 InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
143
144 /* Release the lock and release the semaphore */
145 KeEnterCriticalRegion();
146 KeReleaseGuardedMutex(&LpcpLock);
147 LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
148
149 /* If this is a waitable port, wake it up */
150 if (QueuePort->Flags & LPCP_WAITABLE_PORT)
151 {
152 /* Wake it */
153 KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
154 }
155
156 /* We're done */
157 KeLeaveCriticalRegion();
158 LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
159 if (ConnectionPort) ObDereferenceObject(ConnectionPort);
160 return STATUS_SUCCESS;
161 }
162
163 /* If we got here, then free the message and fail */
164 LpcpFreeToPortZone(Message, 3);
165 if (ConnectionPort) ObDereferenceObject(ConnectionPort);
166 return STATUS_PORT_DISCONNECTED;
167 }
168
169 /*
170 * @unimplemented
171 */
172 NTSTATUS
173 NTAPI
174 LpcRequestWaitReplyPort(IN PVOID Port,
175 IN PPORT_MESSAGE LpcMessageRequest,
176 OUT PPORT_MESSAGE LpcMessageReply)
177 {
178 UNIMPLEMENTED;
179 return STATUS_NOT_IMPLEMENTED;
180 }
181
182 /*
183 * @unimplemented
184 */
185 NTSTATUS
186 NTAPI
187 NtRequestPort(IN HANDLE PortHandle,
188 IN PPORT_MESSAGE LpcMessage)
189 {
190 UNIMPLEMENTED;
191 return STATUS_NOT_IMPLEMENTED;
192 }
193
194 /*
195 * @implemented
196 */
197 NTSTATUS
198 NTAPI
199 NtRequestWaitReplyPort(IN HANDLE PortHandle,
200 IN PPORT_MESSAGE LpcRequest,
201 IN OUT PPORT_MESSAGE LpcReply)
202 {
203 PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
204 KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
205 NTSTATUS Status;
206 PLPCP_MESSAGE Message;
207 PETHREAD Thread = PsGetCurrentThread();
208 BOOLEAN Callback;
209 PKSEMAPHORE Semaphore;
210 ULONG MessageType;
211 PAGED_CODE();
212 LPCTRACE(LPC_SEND_DEBUG,
213 "Handle: %lx. Messages: %p/%p. Type: %lx\n",
214 PortHandle,
215 LpcRequest,
216 LpcReply,
217 LpcpGetMessageType(LpcRequest));
218
219 /* Check if the thread is dying */
220 if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING;
221
222 /* Check if this is an LPC Request */
223 if (LpcpGetMessageType(LpcRequest) == LPC_REQUEST)
224 {
225 /* Then it's a callback */
226 Callback = TRUE;
227 }
228 else if (LpcpGetMessageType(LpcRequest))
229 {
230 /* This is a not kernel-mode message */
231 return STATUS_INVALID_PARAMETER;
232 }
233 else
234 {
235 /* This is a kernel-mode message without a callback */
236 LpcRequest->u2.s2.Type |= LPC_REQUEST;
237 Callback = FALSE;
238 }
239
240 /* Get the message type */
241 MessageType = LpcRequest->u2.s2.Type;
242
243 /* Validate the length */
244 if ((LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
245 LpcRequest->u1.s1.TotalLength)
246 {
247 /* Fail */
248 return STATUS_INVALID_PARAMETER;
249 }
250
251 /* Reference the object */
252 Status = ObReferenceObjectByHandle(PortHandle,
253 0,
254 LpcPortObjectType,
255 PreviousMode,
256 (PVOID*)&Port,
257 NULL);
258 if (!NT_SUCCESS(Status)) return Status;
259
260 /* Validate the message length */
261 if ((LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
262 (LpcRequest->u1.s1.TotalLength <= LpcRequest->u1.s1.DataLength))
263 {
264 /* Fail */
265 ObDereferenceObject(Port);
266 return STATUS_PORT_MESSAGE_TOO_LONG;
267 }
268
269 /* Allocate a message from the port zone */
270 Message = LpcpAllocateFromPortZone();
271 if (!Message)
272 {
273 /* Fail if we couldn't allocate a message */
274 ObDereferenceObject(Port);
275 return STATUS_NO_MEMORY;
276 }
277
278 /* Check if this is a callback */
279 if (Callback)
280 {
281 /* FIXME: TODO */
282 Semaphore = NULL; // we'd use the Thread Semaphore here
283 ASSERT(FALSE);
284 }
285 else
286 {
287 /* No callback, just copy the message */
288 LpcpMoveMessage(&Message->Request,
289 LpcRequest,
290 LpcRequest + 1,
291 MessageType,
292 &Thread->Cid);
293
294 /* Acquire the LPC lock */
295 KeAcquireGuardedMutex(&LpcpLock);
296
297 /* Right now clear the port context */
298 Message->PortContext = NULL;
299
300 /* Check if this is a not connection port */
301 if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
302 {
303 /* We want the connected port */
304 QueuePort = Port->ConnectedPort;
305 if (!QueuePort)
306 {
307 /* We have no connected port, fail */
308 LpcpFreeToPortZone(Message, 3);
309 ObDereferenceObject(Port);
310 return STATUS_PORT_DISCONNECTED;
311 }
312
313 /* This will be the rundown port */
314 ReplyPort = QueuePort;
315
316 /* Check if this is a communication port */
317 if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
318 {
319 /* Copy the port context and use the connection port */
320 Message->PortContext = QueuePort->PortContext;
321 ConnectionPort = QueuePort = Port->ConnectionPort;
322 if (!ConnectionPort)
323 {
324 /* Fail */
325 LpcpFreeToPortZone(Message, 3);
326 ObDereferenceObject(Port);
327 return STATUS_PORT_DISCONNECTED;
328 }
329 }
330 else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
331 LPCP_COMMUNICATION_PORT)
332 {
333 /* Use the connection port for anything but communication ports */
334 ConnectionPort = QueuePort = Port->ConnectionPort;
335 if (!ConnectionPort)
336 {
337 /* Fail */
338 LpcpFreeToPortZone(Message, 3);
339 ObDereferenceObject(Port);
340 return STATUS_PORT_DISCONNECTED;
341 }
342 }
343
344 /* Reference the connection port if it exists */
345 if (ConnectionPort) ObReferenceObject(ConnectionPort);
346 }
347 else
348 {
349 /* Otherwise, for a connection port, use the same port object */
350 QueuePort = ReplyPort = Port;
351 }
352
353 /* No reply thread */
354 Message->RepliedToThread = NULL;
355 Message->SenderPort = Port;
356
357 /* Generate the Message ID and set it */
358 Message->Request.MessageId = LpcpNextMessageId++;
359 if (!LpcpNextMessageId) LpcpNextMessageId = 1;
360 Message->Request.CallbackId = 0;
361
362 /* Set the message ID for our thread now */
363 Thread->LpcReplyMessageId = Message->Request.MessageId;
364 Thread->LpcReplyMessage = NULL;
365
366 /* Insert the message in our chain */
367 InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
368 InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
369 Thread->LpcWaitingOnPort = Port;
370
371 /* Release the lock and get the semaphore we'll use later */
372 KeEnterCriticalRegion();
373 KeReleaseGuardedMutex(&LpcpLock);
374 Semaphore = QueuePort->MsgQueue.Semaphore;
375
376 /* If this is a waitable port, wake it up */
377 if (QueuePort->Flags & LPCP_WAITABLE_PORT)
378 {
379 /* Wake it */
380 KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
381 }
382 }
383
384 /* Now release the semaphore */
385 LpcpCompleteWait(Semaphore);
386 KeLeaveCriticalRegion();
387
388 /* And let's wait for the reply */
389 LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);
390
391 /* Acquire the LPC lock */
392 KeAcquireGuardedMutex(&LpcpLock);
393
394 /* Get the LPC Message and clear our thread's reply data */
395 Message = Thread->LpcReplyMessage;
396 Thread->LpcReplyMessage = NULL;
397 Thread->LpcReplyMessageId = 0;
398
399 /* Check if we have anything on the reply chain*/
400 if (!IsListEmpty(&Thread->LpcReplyChain))
401 {
402 /* Remove this thread and reinitialize the list */
403 RemoveEntryList(&Thread->LpcReplyChain);
404 InitializeListHead(&Thread->LpcReplyChain);
405 }
406
407 /* Release the lock */
408 KeReleaseGuardedMutex(&LpcpLock);
409
410 /* Check if we got a reply */
411 if (Status == STATUS_SUCCESS)
412 {
413 /* Check if we have a valid message */
414 if (Message)
415 {
416 LPCTRACE(LPC_SEND_DEBUG,
417 "Reply Messages: %p/%p\n",
418 &Message->Request,
419 (&Message->Request) + 1);
420
421 /* Move the message */
422 LpcpMoveMessage(LpcReply,
423 &Message->Request,
424 (&Message->Request) + 1,
425 0,
426 NULL);
427
428 /* Check if this is an LPC request with data information */
429 if ((LpcpGetMessageType(&Message->Request) == LPC_REQUEST) &&
430 (Message->Request.u2.s2.DataInfoOffset))
431 {
432 /* Save the data information */
433 LpcpSaveDataInfoMessage(Port, Message, 0);
434 }
435 else
436 {
437 /* Otherwise, just free it */
438 LpcpFreeToPortZone(Message, 0);
439 }
440 }
441 else
442 {
443 /* We don't have a reply */
444 Status = STATUS_LPC_REPLY_LOST;
445 }
446 }
447 else
448 {
449 /* The wait failed, free the message */
450 if (Message) LpcpFreeToPortZone(Message, 0);
451 }
452
453 /* All done */
454 LPCTRACE(LPC_SEND_DEBUG,
455 "Port: %p. Status: %p\n",
456 Port,
457 Status);
458 ObDereferenceObject(Port);
459 if (ConnectionPort) ObDereferenceObject(ConnectionPort);
460 return Status;
461 }
462
463 /* EOF */