[LWIP]
[reactos.git] / reactos / lib / drivers / lwip / src / rostcp.c
1 #include "lwip/sys.h"
2 #include "lwip/tcpip.h"
3
4 #include "rosip.h"
5
6 #include <debug.h>
7
/* Printable names for TCP connection states, listed in the order
 * CLOSED .. TIME_WAIT (11 entries). NOTE(review): presumably indexed by
 * lwIP's tcp_state value for debug output; it is not referenced in this
 * file - confirm its users before changing the order. */
static const char * const tcp_state_str[] = {
    "CLOSED",     "LISTEN",      "SYN_SENT",
    "SYN_RCVD",   "ESTABLISHED", "FIN_WAIT_1",
    "FIN_WAIT_2", "CLOSE_WAIT",  "CLOSING",
    "LAST_ACK",   "TIME_WAIT"
};
21
/* The way that lwIP does multi-threading is really not ideal for our purposes, but
 * we had best go along with it unless we want another unstable TCP library. lwIP uses
 * a thread called the "tcpip thread", which is the only one allowed to call raw API
 * functions. Because of this, each of our LibTCP* functions queues a request for a
 * callback on the "tcpip thread", which then runs the matching LibTCP*Callback function.
 * Yes, this is a lot of unnecessary thread switching and it could definitely be faster,
 * but I don't want to go messing around in lwIP because I have no desire to create
 * another mess like oskittcp. */
29
30 extern KEVENT TerminationEvent;
31 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
32 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
33
34 static
35 void
36 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
37 {
38 PLIST_ENTRY Entry;
39 PQUEUE_ENTRY qp = NULL;
40
41 ReferenceObject(Connection);
42
43 while (!IsListEmpty(&Connection->PacketQueue))
44 {
45 Entry = RemoveHeadList(&Connection->PacketQueue);
46 qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
47
48 /* We're in the tcpip thread here so this is safe */
49 pbuf_free(qp->p);
50
51 ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
52 }
53
54 DereferenceObject(Connection);
55 }
56
57 void LibTCPEnqueuePacket(PCONNECTION_ENDPOINT Connection, struct pbuf *p)
58 {
59 PQUEUE_ENTRY qp;
60
61 qp = (PQUEUE_ENTRY)ExAllocateFromNPagedLookasideList(&QueueEntryLookasideList);
62 qp->p = p;
63 qp->Offset = 0;
64
65 ExInterlockedInsertTailList(&Connection->PacketQueue, &qp->ListEntry, &Connection->Lock);
66 }
67
68 PQUEUE_ENTRY LibTCPDequeuePacket(PCONNECTION_ENDPOINT Connection)
69 {
70 PLIST_ENTRY Entry;
71 PQUEUE_ENTRY qp = NULL;
72
73 if (IsListEmpty(&Connection->PacketQueue)) return NULL;
74
75 Entry = RemoveHeadList(&Connection->PacketQueue);
76
77 qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
78
79 return qp;
80 }
81
82 NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHAR RecvBuffer, UINT RecvLen, UINT *Received)
83 {
84 PQUEUE_ENTRY qp;
85 struct pbuf* p;
86 NTSTATUS Status;
87 UINT ReadLength, PayloadLength;
88 KIRQL OldIrql;
89 PUCHAR Payload;
90
91 (*Received) = 0;
92
93 LockObject(Connection, &OldIrql);
94
95 if (!IsListEmpty(&Connection->PacketQueue))
96 {
97 while ((qp = LibTCPDequeuePacket(Connection)) != NULL)
98 {
99 p = qp->p;
100
101 /* Calculate the payload first */
102 Payload = p->payload;
103 Payload += qp->Offset;
104 PayloadLength = p->len;
105 PayloadLength -= qp->Offset;
106
107 /* Check if we're reading the whole buffer */
108 ReadLength = MIN(PayloadLength, RecvLen);
109 if (ReadLength != PayloadLength)
110 {
111 /* Save this one for later */
112 qp->Offset += ReadLength;
113 InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
114 qp = NULL;
115 }
116
117 UnlockObject(Connection, OldIrql);
118
119 /* Return to a lower IRQL because the receive buffer may be pageable memory */
120 RtlCopyMemory(RecvBuffer,
121 Payload,
122 ReadLength);
123
124 LockObject(Connection, &OldIrql);
125
126 /* Update trackers */
127 RecvLen -= ReadLength;
128 RecvBuffer += ReadLength;
129 (*Received) += ReadLength;
130
131 if (qp != NULL)
132 {
133 /* Use this special pbuf free callback function because we're outside tcpip thread */
134 pbuf_free_callback(qp->p);
135
136 ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
137 }
138 else
139 {
140 /* If we get here, it means we've filled the buffer */
141 ASSERT(RecvLen == 0);
142 }
143
144 Status = STATUS_SUCCESS;
145
146 if (!RecvLen)
147 break;
148 }
149 }
150 else
151 {
152 if (Connection->ReceiveShutdown)
153 Status = STATUS_SUCCESS;
154 else
155 Status = STATUS_PENDING;
156 }
157
158 UnlockObject(Connection, OldIrql);
159
160 return Status;
161 }
162
163 static
164 BOOLEAN
165 WaitForEventSafely(PRKEVENT Event)
166 {
167 PVOID WaitObjects[] = {Event, &TerminationEvent};
168
169 if (KeWaitForMultipleObjects(2,
170 WaitObjects,
171 WaitAny,
172 Executive,
173 KernelMode,
174 FALSE,
175 NULL,
176 NULL) == STATUS_WAIT_0)
177 {
178 /* Signalled by the caller's event */
179 return TRUE;
180 }
181 else /* if KeWaitForMultipleObjects() == STATUS_WAIT_1 */
182 {
183 /* Signalled by our termination event */
184 return FALSE;
185 }
186 }
187
188 static
189 err_t
190 InternalSendEventHandler(void *arg, PTCP_PCB pcb, const u16_t space)
191 {
192 /* Make sure the socket didn't get closed */
193 if (!arg) return ERR_OK;
194
195 TCPSendEventHandler(arg, space);
196
197 return ERR_OK;
198 }
199
200 static
201 err_t
202 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
203 {
204 PCONNECTION_ENDPOINT Connection = arg;
205 struct pbuf *pb;
206 ULONG RecvLen;
207
208 /* Make sure the socket didn't get closed */
209 if (!arg)
210 {
211 if (p)
212 pbuf_free(p);
213
214 return ERR_OK;
215 }
216
217 if (p)
218 {
219 pb = p;
220 RecvLen = 0;
221 while (pb != NULL)
222 {
223 /* Enqueue this buffer */
224 LibTCPEnqueuePacket(Connection, pb);
225 RecvLen += pb->len;
226
227 /* Advance and unchain the buffer */
228 pb = pbuf_dechain(pb);;
229 }
230
231 tcp_recved(pcb, RecvLen);
232
233 TCPRecvEventHandler(arg);
234 }
235 else if (err == ERR_OK)
236 {
237 /* Complete pending reads with 0 bytes to indicate a graceful closure,
238 * but note that send is still possible in this state so we don't close the
239 * whole socket here (by calling tcp_close()) as that would violate TCP specs
240 */
241 Connection->ReceiveShutdown = TRUE;
242 TCPFinEventHandler(arg, ERR_OK);
243 }
244
245 return ERR_OK;
246 }
247
248 static
249 err_t
250 InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
251 {
252 /* Make sure the socket didn't get closed */
253 if (!arg)
254 return ERR_ABRT;
255
256 TCPAcceptEventHandler(arg, newpcb);
257
258 /* Set in LibTCPAccept (called from TCPAcceptEventHandler) */
259 if (newpcb->callback_arg)
260 return ERR_OK;
261 else
262 return ERR_ABRT;
263 }
264
265 static
266 err_t
267 InternalConnectEventHandler(void *arg, PTCP_PCB pcb, const err_t err)
268 {
269 /* Make sure the socket didn't get closed */
270 if (!arg)
271 return ERR_OK;
272
273 TCPConnectEventHandler(arg, err);
274
275 return ERR_OK;
276 }
277
278 static
279 void
280 InternalErrorEventHandler(void *arg, const err_t err)
281 {
282 /* Make sure the socket didn't get closed */
283 if (!arg) return;
284
285 TCPFinEventHandler(arg, err);
286 }
287
288 static
289 void
290 LibTCPSocketCallback(void *arg)
291 {
292 struct lwip_callback_msg *msg = arg;
293
294 ASSERT(msg);
295
296 msg->Output.Socket.NewPcb = tcp_new();
297
298 if (msg->Output.Socket.NewPcb)
299 {
300 tcp_arg(msg->Output.Socket.NewPcb, msg->Input.Socket.Arg);
301 tcp_err(msg->Output.Socket.NewPcb, InternalErrorEventHandler);
302 }
303
304 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
305 }
306
307 struct tcp_pcb *
308 LibTCPSocket(void *arg)
309 {
310 struct lwip_callback_msg *msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
311 struct tcp_pcb *ret;
312
313 if (msg)
314 {
315 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
316 msg->Input.Socket.Arg = arg;
317
318 tcpip_callback_with_block(LibTCPSocketCallback, msg, 1);
319
320 if (WaitForEventSafely(&msg->Event))
321 ret = msg->Output.Socket.NewPcb;
322 else
323 ret = NULL;
324
325 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
326
327 return ret;
328 }
329
330 return NULL;
331 }
332
333 static
334 void
335 LibTCPBindCallback(void *arg)
336 {
337 struct lwip_callback_msg *msg = arg;
338 PTCP_PCB pcb = msg->Input.Bind.Connection->SocketContext;
339
340 ASSERT(msg);
341
342 if (!msg->Input.Bind.Connection->SocketContext)
343 {
344 msg->Output.Bind.Error = ERR_CLSD;
345 goto done;
346 }
347
348 /* We're guaranteed that the local address is valid to bind at this point */
349 pcb->so_options |= SOF_REUSEADDR;
350
351 msg->Output.Bind.Error = tcp_bind(pcb,
352 msg->Input.Bind.IpAddress,
353 ntohs(msg->Input.Bind.Port));
354
355 done:
356 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
357 }
358
359 err_t
360 LibTCPBind(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
361 {
362 struct lwip_callback_msg *msg;
363 err_t ret;
364
365 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
366 if (msg)
367 {
368 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
369 msg->Input.Bind.Connection = Connection;
370 msg->Input.Bind.IpAddress = ipaddr;
371 msg->Input.Bind.Port = port;
372
373 tcpip_callback_with_block(LibTCPBindCallback, msg, 1);
374
375 if (WaitForEventSafely(&msg->Event))
376 ret = msg->Output.Bind.Error;
377 else
378 ret = ERR_CLSD;
379
380 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
381
382 return ret;
383 }
384
385 return ERR_MEM;
386 }
387
388 static
389 void
390 LibTCPListenCallback(void *arg)
391 {
392 struct lwip_callback_msg *msg = arg;
393
394 ASSERT(msg);
395
396 if (!msg->Input.Listen.Connection->SocketContext)
397 {
398 msg->Output.Listen.NewPcb = NULL;
399 goto done;
400 }
401
402 msg->Output.Listen.NewPcb = tcp_listen_with_backlog((PTCP_PCB)msg->Input.Listen.Connection->SocketContext, msg->Input.Listen.Backlog);
403
404 if (msg->Output.Listen.NewPcb)
405 {
406 tcp_accept(msg->Output.Listen.NewPcb, InternalAcceptEventHandler);
407 }
408
409 done:
410 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
411 }
412
413 PTCP_PCB
414 LibTCPListen(PCONNECTION_ENDPOINT Connection, const u8_t backlog)
415 {
416 struct lwip_callback_msg *msg;
417 PTCP_PCB ret;
418
419 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
420 if (msg)
421 {
422 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
423 msg->Input.Listen.Connection = Connection;
424 msg->Input.Listen.Backlog = backlog;
425
426 tcpip_callback_with_block(LibTCPListenCallback, msg, 1);
427
428 if (WaitForEventSafely(&msg->Event))
429 ret = msg->Output.Listen.NewPcb;
430 else
431 ret = NULL;
432
433 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
434
435 return ret;
436 }
437
438 return NULL;
439 }
440
441 static
442 void
443 LibTCPSendCallback(void *arg)
444 {
445 struct lwip_callback_msg *msg = arg;
446
447 ASSERT(msg);
448
449 if (!msg->Input.Send.Connection->SocketContext)
450 {
451 msg->Output.Send.Error = ERR_CLSD;
452 goto done;
453 }
454
455 if (msg->Input.Send.Connection->SendShutdown)
456 {
457 msg->Output.Send.Error = ERR_CLSD;
458 goto done;
459 }
460
461 msg->Output.Send.Error = tcp_write((PTCP_PCB)msg->Input.Send.Connection->SocketContext,
462 msg->Input.Send.Data,
463 msg->Input.Send.DataLength,
464 TCP_WRITE_FLAG_COPY);
465 if (msg->Output.Send.Error == ERR_MEM)
466 {
467 /* No buffer space so return pending */
468 msg->Output.Send.Error = ERR_INPROGRESS;
469 }
470 else if (msg->Output.Send.Error == ERR_OK)
471 {
472 /* Queued successfully so try to send it */
473 tcp_output((PTCP_PCB)msg->Input.Send.Connection->SocketContext);
474 }
475
476 done:
477 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
478 }
479
480 err_t
481 LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, const int safe)
482 {
483 err_t ret;
484 struct lwip_callback_msg *msg;
485
486 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
487 if (msg)
488 {
489 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
490 msg->Input.Send.Connection = Connection;
491 msg->Input.Send.Data = dataptr;
492 msg->Input.Send.DataLength = len;
493
494 if (safe)
495 LibTCPSendCallback(msg);
496 else
497 tcpip_callback_with_block(LibTCPSendCallback, msg, 1);
498
499 if (WaitForEventSafely(&msg->Event))
500 ret = msg->Output.Send.Error;
501 else
502 ret = ERR_CLSD;
503
504 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
505
506 return ret;
507 }
508
509 return ERR_MEM;
510 }
511
512 static
513 void
514 LibTCPConnectCallback(void *arg)
515 {
516 struct lwip_callback_msg *msg = arg;
517 err_t Error;
518
519 ASSERT(arg);
520
521 if (!msg->Input.Connect.Connection->SocketContext)
522 {
523 msg->Output.Connect.Error = ERR_CLSD;
524 goto done;
525 }
526
527 tcp_recv((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalRecvEventHandler);
528 tcp_sent((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalSendEventHandler);
529
530 Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
531 msg->Input.Connect.IpAddress, ntohs(msg->Input.Connect.Port),
532 InternalConnectEventHandler);
533
534 msg->Output.Connect.Error = Error == ERR_OK ? ERR_INPROGRESS : Error;
535
536 done:
537 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
538 }
539
540 err_t
541 LibTCPConnect(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
542 {
543 struct lwip_callback_msg *msg;
544 err_t ret;
545
546 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
547 if (msg)
548 {
549 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
550 msg->Input.Connect.Connection = Connection;
551 msg->Input.Connect.IpAddress = ipaddr;
552 msg->Input.Connect.Port = port;
553
554 tcpip_callback_with_block(LibTCPConnectCallback, msg, 1);
555
556 if (WaitForEventSafely(&msg->Event))
557 {
558 ret = msg->Output.Connect.Error;
559 }
560 else
561 ret = ERR_CLSD;
562
563 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
564
565 return ret;
566 }
567
568 return ERR_MEM;
569 }
570
571 static
572 void
573 LibTCPShutdownCallback(void *arg)
574 {
575 struct lwip_callback_msg *msg = arg;
576 PTCP_PCB pcb = msg->Input.Shutdown.Connection->SocketContext;
577
578 if (!msg->Input.Shutdown.Connection->SocketContext)
579 {
580 msg->Output.Shutdown.Error = ERR_CLSD;
581 goto done;
582 }
583
584 if (pcb->state == CLOSE_WAIT)
585 {
586 /* This case actually results in a socket closure later (lwIP bug?) */
587 msg->Input.Shutdown.Connection->SocketContext = NULL;
588 }
589
590 msg->Output.Shutdown.Error = tcp_shutdown(pcb, msg->Input.Shutdown.shut_rx, msg->Input.Shutdown.shut_tx);
591 if (msg->Output.Shutdown.Error)
592 {
593 msg->Input.Shutdown.Connection->SocketContext = pcb;
594 }
595 else
596 {
597 if (msg->Input.Shutdown.shut_rx)
598 msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
599
600 if (msg->Input.Shutdown.shut_tx)
601 msg->Input.Shutdown.Connection->SendShutdown = TRUE;
602 }
603
604 done:
605 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
606 }
607
608 err_t
609 LibTCPShutdown(PCONNECTION_ENDPOINT Connection, const int shut_rx, const int shut_tx)
610 {
611 struct lwip_callback_msg *msg;
612 err_t ret;
613
614 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
615 if (msg)
616 {
617 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
618
619 msg->Input.Shutdown.Connection = Connection;
620 msg->Input.Shutdown.shut_rx = shut_rx;
621 msg->Input.Shutdown.shut_tx = shut_tx;
622
623 tcpip_callback_with_block(LibTCPShutdownCallback, msg, 1);
624
625 if (WaitForEventSafely(&msg->Event))
626 ret = msg->Output.Shutdown.Error;
627 else
628 ret = ERR_CLSD;
629
630 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
631
632 return ret;
633 }
634
635 return ERR_MEM;
636 }
637
638 static
639 void
640 LibTCPCloseCallback(void *arg)
641 {
642 struct lwip_callback_msg *msg = arg;
643 PTCP_PCB pcb = msg->Input.Close.Connection->SocketContext;
644
645 /* Empty the queue even if we're already "closed" */
646 LibTCPEmptyQueue(msg->Input.Close.Connection);
647
648 if (!msg->Input.Close.Connection->SocketContext)
649 {
650 msg->Output.Close.Error = ERR_OK;
651 goto done;
652 }
653
654 /* Clear the PCB pointer */
655 msg->Input.Close.Connection->SocketContext = NULL;
656
657 switch (pcb->state)
658 {
659 case CLOSED:
660 case LISTEN:
661 case SYN_SENT:
662 msg->Output.Close.Error = tcp_close(pcb);
663
664 if (!msg->Output.Close.Error && msg->Input.Close.Callback)
665 TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
666 break;
667
668 default:
669 if (msg->Input.Close.Connection->SendShutdown &&
670 msg->Input.Close.Connection->ReceiveShutdown)
671 {
672 /* Abort the connection */
673 tcp_abort(pcb);
674
675 /* Aborts always succeed */
676 msg->Output.Close.Error = ERR_OK;
677 }
678 else
679 {
680 /* Start the graceful close process (or send RST for pending data) */
681 msg->Output.Close.Error = tcp_close(pcb);
682 }
683 break;
684 }
685
686 if (msg->Output.Close.Error)
687 {
688 /* Restore the PCB pointer */
689 msg->Input.Close.Connection->SocketContext = pcb;
690 }
691
692 done:
693 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
694 }
695
696 err_t
697 LibTCPClose(PCONNECTION_ENDPOINT Connection, const int safe, const int callback)
698 {
699 err_t ret;
700 struct lwip_callback_msg *msg;
701
702 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
703 if (msg)
704 {
705 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
706
707 msg->Input.Close.Connection = Connection;
708 msg->Input.Close.Callback = callback;
709
710 if (safe)
711 LibTCPCloseCallback(msg);
712 else
713 tcpip_callback_with_block(LibTCPCloseCallback, msg, 1);
714
715 if (WaitForEventSafely(&msg->Event))
716 ret = msg->Output.Close.Error;
717 else
718 ret = ERR_CLSD;
719
720 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
721
722 return ret;
723 }
724
725 return ERR_MEM;
726 }
727
728 void
729 LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg)
730 {
731 ASSERT(arg);
732
733 tcp_arg(pcb, NULL);
734 tcp_recv(pcb, InternalRecvEventHandler);
735 tcp_sent(pcb, InternalSendEventHandler);
736 tcp_err(pcb, InternalErrorEventHandler);
737 tcp_arg(pcb, arg);
738
739 tcp_accepted(listen_pcb);
740 }
741
742 err_t
743 LibTCPGetHostName(PTCP_PCB pcb, struct ip_addr *const ipaddr, u16_t *const port)
744 {
745 if (!pcb)
746 return ERR_CLSD;
747
748 *ipaddr = pcb->local_ip;
749 *port = pcb->local_port;
750
751 return ERR_OK;
752 }
753
754 err_t
755 LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr * const ipaddr, u16_t * const port)
756 {
757 if (!pcb)
758 return ERR_CLSD;
759
760 *ipaddr = pcb->remote_ip;
761 *port = pcb->remote_port;
762
763 return ERR_OK;
764 }