[LWIP]
[reactos.git] / reactos / lib / drivers / lwip / src / rostcp.c
1 #include "lwip/sys.h"
2 #include "lwip/tcpip.h"
3
4 #include "rosip.h"
5
6 #include <debug.h>
7
/*
 * Printable names for TCP connection states.
 * NOTE(review): the order appears to mirror lwIP's enum tcp_state (CLOSED == 0 ...
 * TIME_WAIT == 10) so that tcp_state_str[pcb->state] yields the state name; keep it
 * in sync with lwIP's tcp.h if that enum ever changes.
 */
static const char * const tcp_state_str[] =
{
    "CLOSED",      "LISTEN",     "SYN_SENT",   "SYN_RCVD",
    "ESTABLISHED", "FIN_WAIT_1", "FIN_WAIT_2", "CLOSE_WAIT",
    "CLOSING",     "LAST_ACK",   "TIME_WAIT"
};
21
/* The way that lwIP does multi-threading is really not ideal for our purposes, but
 * we had best go along with it unless we want another unstable TCP library. lwIP uses
 * a thread called the "tcpip thread", which is the only one allowed to call raw API
 * functions. Since this is the case, for each of our LibTCP* functions we queue a request
 * for a callback to the "tcpip thread", which calls our LibTCP*Callback functions. Yes, this is
 * a lot of unnecessary thread swapping and it could definitely be faster, but I don't want
 * to go messing around in lwIP because I have no desire to create another mess like oskittcp */
29
30 extern KEVENT TerminationEvent;
31 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
32 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
33
34 static
35 void
36 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
37 {
38 PLIST_ENTRY Entry;
39 PQUEUE_ENTRY qp = NULL;
40
41 ReferenceObject(Connection);
42
43 while (!IsListEmpty(&Connection->PacketQueue))
44 {
45 Entry = RemoveHeadList(&Connection->PacketQueue);
46 qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
47
48 /* We're in the tcpip thread here so this is safe */
49 pbuf_free(qp->p);
50
51 ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
52 }
53
54 DereferenceObject(Connection);
55 }
56
57 void LibTCPEnqueuePacket(PCONNECTION_ENDPOINT Connection, struct pbuf *p)
58 {
59 PQUEUE_ENTRY qp;
60
61 qp = (PQUEUE_ENTRY)ExAllocateFromNPagedLookasideList(&QueueEntryLookasideList);
62 qp->p = p;
63
64 ExInterlockedInsertTailList(&Connection->PacketQueue, &qp->ListEntry, &Connection->Lock);
65 }
66
67 PQUEUE_ENTRY LibTCPDequeuePacket(PCONNECTION_ENDPOINT Connection)
68 {
69 PLIST_ENTRY Entry;
70 PQUEUE_ENTRY qp = NULL;
71
72 if (IsListEmpty(&Connection->PacketQueue)) return NULL;
73
74 Entry = RemoveHeadList(&Connection->PacketQueue);
75
76 qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
77
78 return qp;
79 }
80
81 NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHAR RecvBuffer, UINT RecvLen, UINT *Received)
82 {
83 PQUEUE_ENTRY qp;
84 struct pbuf* p;
85 NTSTATUS Status = STATUS_PENDING;
86 UINT ReadLength, ExistingDataLength, SpaceLeft;
87 KIRQL OldIrql;
88
89 (*Received) = 0;
90 SpaceLeft = RecvLen;
91
92 LockObject(Connection, &OldIrql);
93
94 if (!IsListEmpty(&Connection->PacketQueue))
95 {
96 while ((qp = LibTCPDequeuePacket(Connection)) != NULL)
97 {
98 p = qp->p;
99 ExistingDataLength = (*Received);
100
101 Status = STATUS_SUCCESS;
102
103 ReadLength = MIN(p->tot_len, SpaceLeft);
104 if (ReadLength != p->tot_len)
105 {
106 if (ExistingDataLength)
107 {
108 /* The packet was too big but we used some data already so give it another shot later */
109 InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
110 break;
111 }
112 else
113 {
114 /* The packet is just too big to fit fully in our buffer, even when empty so
115 * return an informative status but still copy all the data we can fit.
116 */
117 Status = STATUS_BUFFER_OVERFLOW;
118 }
119 }
120
121 UnlockObject(Connection, OldIrql);
122
123 /* Return to a lower IRQL because the receive buffer may be pageable memory */
124 for (; (*Received) < ReadLength + ExistingDataLength; (*Received) += p->len, p = p->next)
125 {
126 RtlCopyMemory(RecvBuffer + (*Received), p->payload, p->len);
127 }
128
129 LockObject(Connection, &OldIrql);
130
131 SpaceLeft -= ReadLength;
132
133 /* Use this special pbuf free callback function because we're outside tcpip thread */
134 pbuf_free_callback(qp->p);
135
136 ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
137
138 if (!RecvLen)
139 break;
140
141 if (Status != STATUS_SUCCESS)
142 break;
143 }
144 }
145 else
146 {
147 if (Connection->ReceiveShutdown)
148 Status = STATUS_SUCCESS;
149 else
150 Status = STATUS_PENDING;
151 }
152
153 UnlockObject(Connection, OldIrql);
154
155 return Status;
156 }
157
158 static
159 BOOLEAN
160 WaitForEventSafely(PRKEVENT Event)
161 {
162 PVOID WaitObjects[] = {Event, &TerminationEvent};
163
164 if (KeWaitForMultipleObjects(2,
165 WaitObjects,
166 WaitAny,
167 Executive,
168 KernelMode,
169 FALSE,
170 NULL,
171 NULL) == STATUS_WAIT_0)
172 {
173 /* Signalled by the caller's event */
174 return TRUE;
175 }
176 else /* if KeWaitForMultipleObjects() == STATUS_WAIT_1 */
177 {
178 /* Signalled by our termination event */
179 return FALSE;
180 }
181 }
182
183 static
184 err_t
185 InternalSendEventHandler(void *arg, PTCP_PCB pcb, const u16_t space)
186 {
187 /* Make sure the socket didn't get closed */
188 if (!arg) return ERR_OK;
189
190 TCPSendEventHandler(arg, space);
191
192 return ERR_OK;
193 }
194
195 static
196 err_t
197 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
198 {
199 PCONNECTION_ENDPOINT Connection = arg;
200 u32_t len;
201
202 /* Make sure the socket didn't get closed */
203 if (!arg)
204 {
205 if (p)
206 pbuf_free(p);
207
208 return ERR_OK;
209 }
210
211 if (p)
212 {
213 len = TCPRecvEventHandler(arg, p);
214 if (len == p->tot_len)
215 {
216 tcp_recved(pcb, len);
217
218 pbuf_free(p);
219
220 return ERR_OK;
221 }
222 else if (len != 0)
223 {
224 DbgPrint("UNTESTED CASE: NOT ALL DATA TAKEN! EXTRA DATA MAY BE LOST!\n");
225
226 tcp_recved(pcb, len);
227
228 /* Possible memory leak of pbuf here? */
229
230 return ERR_OK;
231 }
232 else
233 {
234 LibTCPEnqueuePacket(Connection, p);
235
236 tcp_recved(pcb, p->tot_len);
237
238 return ERR_OK;
239 }
240 }
241 else if (err == ERR_OK)
242 {
243 /* Complete pending reads with 0 bytes to indicate a graceful closure,
244 * but note that send is still possible in this state so we don't close the
245 * whole socket here (by calling tcp_close()) as that would violate TCP specs
246 */
247 Connection->ReceiveShutdown = TRUE;
248 TCPFinEventHandler(arg, ERR_OK);
249 }
250
251 return ERR_OK;
252 }
253
254 static
255 err_t
256 InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
257 {
258 /* Make sure the socket didn't get closed */
259 if (!arg)
260 return ERR_ABRT;
261
262 TCPAcceptEventHandler(arg, newpcb);
263
264 /* Set in LibTCPAccept (called from TCPAcceptEventHandler) */
265 if (newpcb->callback_arg)
266 return ERR_OK;
267 else
268 return ERR_ABRT;
269 }
270
271 static
272 err_t
273 InternalConnectEventHandler(void *arg, PTCP_PCB pcb, const err_t err)
274 {
275 /* Make sure the socket didn't get closed */
276 if (!arg)
277 return ERR_OK;
278
279 TCPConnectEventHandler(arg, err);
280
281 return ERR_OK;
282 }
283
284 static
285 void
286 InternalErrorEventHandler(void *arg, const err_t err)
287 {
288 /* Make sure the socket didn't get closed */
289 if (!arg) return;
290
291 TCPFinEventHandler(arg, err);
292 }
293
294 static
295 void
296 LibTCPSocketCallback(void *arg)
297 {
298 struct lwip_callback_msg *msg = arg;
299
300 ASSERT(msg);
301
302 msg->Output.Socket.NewPcb = tcp_new();
303
304 if (msg->Output.Socket.NewPcb)
305 {
306 tcp_arg(msg->Output.Socket.NewPcb, msg->Input.Socket.Arg);
307 tcp_err(msg->Output.Socket.NewPcb, InternalErrorEventHandler);
308 }
309
310 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
311 }
312
313 struct tcp_pcb *
314 LibTCPSocket(void *arg)
315 {
316 struct lwip_callback_msg *msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
317 struct tcp_pcb *ret;
318
319 if (msg)
320 {
321 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
322 msg->Input.Socket.Arg = arg;
323
324 tcpip_callback_with_block(LibTCPSocketCallback, msg, 1);
325
326 if (WaitForEventSafely(&msg->Event))
327 ret = msg->Output.Socket.NewPcb;
328 else
329 ret = NULL;
330
331 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
332
333 return ret;
334 }
335
336 return NULL;
337 }
338
339 static
340 void
341 LibTCPBindCallback(void *arg)
342 {
343 struct lwip_callback_msg *msg = arg;
344
345 ASSERT(msg);
346
347 if (!msg->Input.Bind.Connection->SocketContext)
348 {
349 msg->Output.Bind.Error = ERR_CLSD;
350 goto done;
351 }
352
353 msg->Output.Bind.Error = tcp_bind((PTCP_PCB)msg->Input.Bind.Connection->SocketContext,
354 msg->Input.Bind.IpAddress,
355 ntohs(msg->Input.Bind.Port));
356
357 done:
358 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
359 }
360
361 err_t
362 LibTCPBind(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
363 {
364 struct lwip_callback_msg *msg;
365 err_t ret;
366
367 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
368 if (msg)
369 {
370 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
371 msg->Input.Bind.Connection = Connection;
372 msg->Input.Bind.IpAddress = ipaddr;
373 msg->Input.Bind.Port = port;
374
375 tcpip_callback_with_block(LibTCPBindCallback, msg, 1);
376
377 if (WaitForEventSafely(&msg->Event))
378 ret = msg->Output.Bind.Error;
379 else
380 ret = ERR_CLSD;
381
382 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
383
384 return ret;
385 }
386
387 return ERR_MEM;
388 }
389
390 static
391 void
392 LibTCPListenCallback(void *arg)
393 {
394 struct lwip_callback_msg *msg = arg;
395
396 ASSERT(msg);
397
398 if (!msg->Input.Listen.Connection->SocketContext)
399 {
400 msg->Output.Listen.NewPcb = NULL;
401 goto done;
402 }
403
404 msg->Output.Listen.NewPcb = tcp_listen_with_backlog((PTCP_PCB)msg->Input.Listen.Connection->SocketContext, msg->Input.Listen.Backlog);
405
406 if (msg->Output.Listen.NewPcb)
407 {
408 tcp_accept(msg->Output.Listen.NewPcb, InternalAcceptEventHandler);
409 }
410
411 done:
412 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
413 }
414
415 PTCP_PCB
416 LibTCPListen(PCONNECTION_ENDPOINT Connection, const u8_t backlog)
417 {
418 struct lwip_callback_msg *msg;
419 PTCP_PCB ret;
420
421 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
422 if (msg)
423 {
424 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
425 msg->Input.Listen.Connection = Connection;
426 msg->Input.Listen.Backlog = backlog;
427
428 tcpip_callback_with_block(LibTCPListenCallback, msg, 1);
429
430 if (WaitForEventSafely(&msg->Event))
431 ret = msg->Output.Listen.NewPcb;
432 else
433 ret = NULL;
434
435 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
436
437 return ret;
438 }
439
440 return NULL;
441 }
442
443 static
444 void
445 LibTCPSendCallback(void *arg)
446 {
447 struct lwip_callback_msg *msg = arg;
448
449 ASSERT(msg);
450
451 if (!msg->Input.Send.Connection->SocketContext)
452 {
453 msg->Output.Send.Error = ERR_CLSD;
454 goto done;
455 }
456
457 if (msg->Input.Send.Connection->SendShutdown)
458 {
459 msg->Output.Send.Error = ERR_CLSD;
460 goto done;
461 }
462
463 msg->Output.Send.Error = tcp_write((PTCP_PCB)msg->Input.Send.Connection->SocketContext,
464 msg->Input.Send.Data,
465 msg->Input.Send.DataLength,
466 TCP_WRITE_FLAG_COPY);
467 if (msg->Output.Send.Error == ERR_MEM)
468 {
469 /* No buffer space so return pending */
470 msg->Output.Send.Error = ERR_INPROGRESS;
471 }
472 else if (msg->Output.Send.Error == ERR_OK)
473 {
474 /* Queued successfully so try to send it */
475 tcp_output((PTCP_PCB)msg->Input.Send.Connection->SocketContext);
476 }
477
478 done:
479 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
480 }
481
482 err_t
483 LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, const int safe)
484 {
485 err_t ret;
486 struct lwip_callback_msg *msg;
487
488 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
489 if (msg)
490 {
491 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
492 msg->Input.Send.Connection = Connection;
493 msg->Input.Send.Data = dataptr;
494 msg->Input.Send.DataLength = len;
495
496 if (safe)
497 LibTCPSendCallback(msg);
498 else
499 tcpip_callback_with_block(LibTCPSendCallback, msg, 1);
500
501 if (WaitForEventSafely(&msg->Event))
502 ret = msg->Output.Send.Error;
503 else
504 ret = ERR_CLSD;
505
506 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
507
508 return ret;
509 }
510
511 return ERR_MEM;
512 }
513
514 static
515 void
516 LibTCPConnectCallback(void *arg)
517 {
518 struct lwip_callback_msg *msg = arg;
519 err_t Error;
520
521 ASSERT(arg);
522
523 if (!msg->Input.Connect.Connection->SocketContext)
524 {
525 msg->Output.Connect.Error = ERR_CLSD;
526 goto done;
527 }
528
529 tcp_recv((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalRecvEventHandler);
530 tcp_sent((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalSendEventHandler);
531
532 Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
533 msg->Input.Connect.IpAddress, ntohs(msg->Input.Connect.Port),
534 InternalConnectEventHandler);
535
536 msg->Output.Connect.Error = Error == ERR_OK ? ERR_INPROGRESS : Error;
537
538 done:
539 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
540 }
541
542 err_t
543 LibTCPConnect(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
544 {
545 struct lwip_callback_msg *msg;
546 err_t ret;
547
548 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
549 if (msg)
550 {
551 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
552 msg->Input.Connect.Connection = Connection;
553 msg->Input.Connect.IpAddress = ipaddr;
554 msg->Input.Connect.Port = port;
555
556 tcpip_callback_with_block(LibTCPConnectCallback, msg, 1);
557
558 if (WaitForEventSafely(&msg->Event))
559 {
560 ret = msg->Output.Connect.Error;
561 }
562 else
563 ret = ERR_CLSD;
564
565 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
566
567 return ret;
568 }
569
570 return ERR_MEM;
571 }
572
573 static
574 void
575 LibTCPShutdownCallback(void *arg)
576 {
577 struct lwip_callback_msg *msg = arg;
578 PTCP_PCB pcb = msg->Input.Shutdown.Connection->SocketContext;
579
580 if (!msg->Input.Shutdown.Connection->SocketContext)
581 {
582 msg->Output.Shutdown.Error = ERR_CLSD;
583 goto done;
584 }
585
586 if (pcb->state == CLOSE_WAIT)
587 {
588 /* This case actually results in a socket closure later (lwIP bug?) */
589 msg->Input.Shutdown.Connection->SocketContext = NULL;
590 }
591
592 msg->Output.Shutdown.Error = tcp_shutdown(pcb, msg->Input.Shutdown.shut_rx, msg->Input.Shutdown.shut_tx);
593 if (msg->Output.Shutdown.Error)
594 {
595 msg->Input.Shutdown.Connection->SocketContext = pcb;
596 }
597 else
598 {
599 if (msg->Input.Shutdown.shut_rx)
600 msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
601
602 if (msg->Input.Shutdown.shut_tx)
603 msg->Input.Shutdown.Connection->SendShutdown = TRUE;
604 }
605
606 done:
607 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
608 }
609
610 err_t
611 LibTCPShutdown(PCONNECTION_ENDPOINT Connection, const int shut_rx, const int shut_tx)
612 {
613 struct lwip_callback_msg *msg;
614 err_t ret;
615
616 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
617 if (msg)
618 {
619 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
620
621 msg->Input.Shutdown.Connection = Connection;
622 msg->Input.Shutdown.shut_rx = shut_rx;
623 msg->Input.Shutdown.shut_tx = shut_tx;
624
625 tcpip_callback_with_block(LibTCPShutdownCallback, msg, 1);
626
627 if (WaitForEventSafely(&msg->Event))
628 ret = msg->Output.Shutdown.Error;
629 else
630 ret = ERR_CLSD;
631
632 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
633
634 return ret;
635 }
636
637 return ERR_MEM;
638 }
639
640 static
641 void
642 LibTCPCloseCallback(void *arg)
643 {
644 struct lwip_callback_msg *msg = arg;
645 PTCP_PCB pcb = msg->Input.Close.Connection->SocketContext;
646
647 /* Empty the queue even if we're already "closed" */
648 LibTCPEmptyQueue(msg->Input.Close.Connection);
649
650 if (!msg->Input.Close.Connection->SocketContext)
651 {
652 msg->Output.Close.Error = ERR_OK;
653 goto done;
654 }
655
656 /* Clear the PCB pointer */
657 msg->Input.Close.Connection->SocketContext = NULL;
658
659 switch (pcb->state)
660 {
661 case CLOSED:
662 case LISTEN:
663 case SYN_SENT:
664 msg->Output.Close.Error = tcp_close(pcb);
665
666 if (!msg->Output.Close.Error && msg->Input.Close.Callback)
667 TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
668 break;
669
670 default:
671 if (msg->Input.Close.Connection->SendShutdown &&
672 msg->Input.Close.Connection->ReceiveShutdown)
673 {
674 /* Abort the connection */
675 tcp_abort(pcb);
676
677 /* Aborts always succeed */
678 msg->Output.Close.Error = ERR_OK;
679 }
680 else
681 {
682 /* Start the graceful close process (or send RST for pending data) */
683 msg->Output.Close.Error = tcp_close(pcb);
684 }
685 break;
686 }
687
688 if (msg->Output.Close.Error)
689 {
690 /* Restore the PCB pointer */
691 msg->Input.Close.Connection->SocketContext = pcb;
692 }
693
694 done:
695 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
696 }
697
698 err_t
699 LibTCPClose(PCONNECTION_ENDPOINT Connection, const int safe, const int callback)
700 {
701 err_t ret;
702 struct lwip_callback_msg *msg;
703
704 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
705 if (msg)
706 {
707 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
708
709 msg->Input.Close.Connection = Connection;
710 msg->Input.Close.Callback = callback;
711
712 if (safe)
713 LibTCPCloseCallback(msg);
714 else
715 tcpip_callback_with_block(LibTCPCloseCallback, msg, 1);
716
717 if (WaitForEventSafely(&msg->Event))
718 ret = msg->Output.Close.Error;
719 else
720 ret = ERR_CLSD;
721
722 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
723
724 return ret;
725 }
726
727 return ERR_MEM;
728 }
729
730 void
731 LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg)
732 {
733 ASSERT(arg);
734
735 tcp_arg(pcb, NULL);
736 tcp_recv(pcb, InternalRecvEventHandler);
737 tcp_sent(pcb, InternalSendEventHandler);
738 tcp_err(pcb, InternalErrorEventHandler);
739 tcp_arg(pcb, arg);
740
741 tcp_accepted(listen_pcb);
742 }
743
744 err_t
745 LibTCPGetHostName(PTCP_PCB pcb, struct ip_addr *const ipaddr, u16_t *const port)
746 {
747 if (!pcb)
748 return ERR_CLSD;
749
750 *ipaddr = pcb->local_ip;
751 *port = pcb->local_port;
752
753 return ERR_OK;
754 }
755
756 err_t
757 LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr * const ipaddr, u16_t * const port)
758 {
759 if (!pcb)
760 return ERR_CLSD;
761
762 *ipaddr = pcb->remote_ip;
763 *port = pcb->remote_port;
764
765 return ERR_OK;
766 }