[TCPIP]
[reactos.git] / reactos / lib / drivers / lwip / src / rostcp.c
1 #include "lwip/sys.h"
2 #include "lwip/tcpip.h"
3
4 #include "rosip.h"
5
6 #include <debug.h>
7
/* Printable names for TCP connection states; entry 0 is "CLOSED" and the
 * last entry is "TIME_WAIT".
 * NOTE(review): presumably indexed by lwIP's tcp_state enum values -- confirm
 * the order against the lwIP headers before relying on it. */
static const char * const tcp_state_str[] = {
    "CLOSED",      "LISTEN",     "SYN_SENT",   "SYN_RCVD",
    "ESTABLISHED", "FIN_WAIT_1", "FIN_WAIT_2", "CLOSE_WAIT",
    "CLOSING",     "LAST_ACK",   "TIME_WAIT"
};
21
/* The way that lwIP does multi-threading is really not ideal for our purposes, but
 * we had best go along with it unless we want another unstable TCP library. lwIP uses
 * a thread called the "tcpip thread", which is the only one allowed to call raw API
 * functions. Since this is the case, for each of our LibTCP* functions we queue a request
 * for a callback to the "tcpip thread", which calls our LibTCP*Callback functions. Yes, this is
 * a lot of unnecessary thread swapping and it could definitely be faster, but I don't want
 * to go messing around in lwIP because I have no desire to create another mess like oskittcp. */
29
30 extern KEVENT TerminationEvent;
31 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
32 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
33
34 static
35 void
36 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
37 {
38 PLIST_ENTRY Entry;
39 PQUEUE_ENTRY qp = NULL;
40
41 ReferenceObject(Connection);
42
43 while (!IsListEmpty(&Connection->PacketQueue))
44 {
45 Entry = RemoveHeadList(&Connection->PacketQueue);
46 qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
47
48 /* We're in the tcpip thread here so this is safe */
49 pbuf_free(qp->p);
50
51 ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
52 }
53
54 DereferenceObject(Connection);
55 }
56
57 void LibTCPEnqueuePacket(PCONNECTION_ENDPOINT Connection, struct pbuf *p)
58 {
59 PQUEUE_ENTRY qp;
60
61 qp = (PQUEUE_ENTRY)ExAllocateFromNPagedLookasideList(&QueueEntryLookasideList);
62 qp->p = p;
63
64 ExInterlockedInsertTailList(&Connection->PacketQueue, &qp->ListEntry, &Connection->Lock);
65 }
66
67 PQUEUE_ENTRY LibTCPDequeuePacket(PCONNECTION_ENDPOINT Connection)
68 {
69 PLIST_ENTRY Entry;
70 PQUEUE_ENTRY qp = NULL;
71
72 if (IsListEmpty(&Connection->PacketQueue)) return NULL;
73
74 Entry = RemoveHeadList(&Connection->PacketQueue);
75
76 qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
77
78 return qp;
79 }
80
81 NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHAR RecvBuffer, UINT RecvLen, UINT *Received)
82 {
83 PQUEUE_ENTRY qp;
84 struct pbuf* p;
85 NTSTATUS Status = STATUS_PENDING;
86 UINT ReadLength, ExistingDataLength, SpaceLeft;
87 KIRQL OldIrql;
88
89 (*Received) = 0;
90 SpaceLeft = RecvLen;
91
92 LockObject(Connection, &OldIrql);
93
94 if (!IsListEmpty(&Connection->PacketQueue))
95 {
96 while ((qp = LibTCPDequeuePacket(Connection)) != NULL)
97 {
98 p = qp->p;
99 ExistingDataLength = (*Received);
100
101 Status = STATUS_SUCCESS;
102
103 ReadLength = MIN(p->tot_len, SpaceLeft);
104 if (ReadLength != p->tot_len)
105 {
106 if (ExistingDataLength)
107 {
108 /* The packet was too big but we used some data already so give it another shot later */
109 InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
110 break;
111 }
112 else
113 {
114 /* The packet is just too big to fit fully in our buffer, even when empty so
115 * return an informative status but still copy all the data we can fit.
116 */
117 Status = STATUS_BUFFER_OVERFLOW;
118 }
119 }
120
121 UnlockObject(Connection, OldIrql);
122
123 /* Return to a lower IRQL because the receive buffer may be pageable memory */
124 for (; (*Received) < ReadLength + ExistingDataLength; (*Received) += p->len, p = p->next)
125 {
126 RtlCopyMemory(RecvBuffer + (*Received), p->payload, p->len);
127 }
128
129 LockObject(Connection, &OldIrql);
130
131 SpaceLeft -= ReadLength;
132
133 /* Use this special pbuf free callback function because we're outside tcpip thread */
134 pbuf_free_callback(qp->p);
135
136 ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
137
138 if (!RecvLen)
139 break;
140
141 if (Status != STATUS_SUCCESS)
142 break;
143 }
144 }
145 else
146 {
147 if (Connection->ReceiveShutdown)
148 Status = STATUS_SUCCESS;
149 else
150 Status = STATUS_PENDING;
151 }
152
153 UnlockObject(Connection, OldIrql);
154
155 return Status;
156 }
157
158 static
159 BOOLEAN
160 WaitForEventSafely(PRKEVENT Event)
161 {
162 PVOID WaitObjects[] = {Event, &TerminationEvent};
163
164 if (KeWaitForMultipleObjects(2,
165 WaitObjects,
166 WaitAny,
167 Executive,
168 KernelMode,
169 FALSE,
170 NULL,
171 NULL) == STATUS_WAIT_0)
172 {
173 /* Signalled by the caller's event */
174 return TRUE;
175 }
176 else /* if KeWaitForMultipleObjects() == STATUS_WAIT_1 */
177 {
178 /* Signalled by our termination event */
179 return FALSE;
180 }
181 }
182
183 static
184 err_t
185 InternalSendEventHandler(void *arg, PTCP_PCB pcb, const u16_t space)
186 {
187 /* Make sure the socket didn't get closed */
188 if (!arg) return ERR_OK;
189
190 TCPSendEventHandler(arg, space);
191
192 return ERR_OK;
193 }
194
195 static
196 err_t
197 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
198 {
199 PCONNECTION_ENDPOINT Connection = arg;
200 u32_t len;
201
202 /* Make sure the socket didn't get closed */
203 if (!arg)
204 {
205 if (p)
206 pbuf_free(p);
207
208 return ERR_OK;
209 }
210
211 if (p)
212 {
213 len = TCPRecvEventHandler(arg, p);
214 if (len == p->tot_len)
215 {
216 tcp_recved(pcb, len);
217
218 pbuf_free(p);
219
220 return ERR_OK;
221 }
222 else if (len != 0)
223 {
224 DbgPrint("UNTESTED CASE: NOT ALL DATA TAKEN! EXTRA DATA MAY BE LOST!\n");
225
226 tcp_recved(pcb, len);
227
228 /* Possible memory leak of pbuf here? */
229
230 return ERR_OK;
231 }
232 else
233 {
234 LibTCPEnqueuePacket(Connection, p);
235
236 tcp_recved(pcb, p->tot_len);
237
238 return ERR_OK;
239 }
240 }
241 else if (err == ERR_OK)
242 {
243 /* Complete pending reads with 0 bytes to indicate a graceful closure,
244 * but note that send is still possible in this state so we don't close the
245 * whole socket here (by calling tcp_close()) as that would violate TCP specs
246 */
247 Connection->ReceiveShutdown = TRUE;
248 TCPFinEventHandler(arg, ERR_OK);
249 }
250
251 return ERR_OK;
252 }
253
254 static
255 err_t
256 InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
257 {
258 /* Make sure the socket didn't get closed */
259 if (!arg)
260 return ERR_ABRT;
261
262 TCPAcceptEventHandler(arg, newpcb);
263
264 /* Set in LibTCPAccept (called from TCPAcceptEventHandler) */
265 if (newpcb->callback_arg)
266 return ERR_OK;
267 else
268 return ERR_ABRT;
269 }
270
271 static
272 err_t
273 InternalConnectEventHandler(void *arg, PTCP_PCB pcb, const err_t err)
274 {
275 /* Make sure the socket didn't get closed */
276 if (!arg)
277 return ERR_OK;
278
279 TCPConnectEventHandler(arg, err);
280
281 return ERR_OK;
282 }
283
284 static
285 void
286 InternalErrorEventHandler(void *arg, const err_t err)
287 {
288 /* Make sure the socket didn't get closed */
289 if (!arg) return;
290
291 TCPFinEventHandler(arg, err);
292 }
293
294 static
295 void
296 LibTCPSocketCallback(void *arg)
297 {
298 struct lwip_callback_msg *msg = arg;
299
300 ASSERT(msg);
301
302 msg->Output.Socket.NewPcb = tcp_new();
303
304 if (msg->Output.Socket.NewPcb)
305 {
306 tcp_arg(msg->Output.Socket.NewPcb, msg->Input.Socket.Arg);
307 tcp_err(msg->Output.Socket.NewPcb, InternalErrorEventHandler);
308 }
309
310 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
311 }
312
313 struct tcp_pcb *
314 LibTCPSocket(void *arg)
315 {
316 struct lwip_callback_msg *msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
317 struct tcp_pcb *ret;
318
319 if (msg)
320 {
321 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
322 msg->Input.Socket.Arg = arg;
323
324 tcpip_callback_with_block(LibTCPSocketCallback, msg, 1);
325
326 if (WaitForEventSafely(&msg->Event))
327 ret = msg->Output.Socket.NewPcb;
328 else
329 ret = NULL;
330
331 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
332
333 return ret;
334 }
335
336 return NULL;
337 }
338
339 static
340 void
341 LibTCPBindCallback(void *arg)
342 {
343 struct lwip_callback_msg *msg = arg;
344
345 ASSERT(msg);
346
347 if (!msg->Input.Bind.Connection->SocketContext)
348 {
349 msg->Output.Bind.Error = ERR_CLSD;
350 goto done;
351 }
352
353 msg->Output.Bind.Error = tcp_bind((PTCP_PCB)msg->Input.Bind.Connection->SocketContext,
354 msg->Input.Bind.IpAddress,
355 ntohs(msg->Input.Bind.Port));
356
357 done:
358 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
359 }
360
361 err_t
362 LibTCPBind(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
363 {
364 struct lwip_callback_msg *msg;
365 err_t ret;
366
367 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
368 if (msg)
369 {
370 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
371 msg->Input.Bind.Connection = Connection;
372 msg->Input.Bind.IpAddress = ipaddr;
373 msg->Input.Bind.Port = port;
374
375 tcpip_callback_with_block(LibTCPBindCallback, msg, 1);
376
377 if (WaitForEventSafely(&msg->Event))
378 ret = msg->Output.Bind.Error;
379 else
380 ret = ERR_CLSD;
381
382 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
383
384 return ret;
385 }
386
387 return ERR_MEM;
388 }
389
390 static
391 void
392 LibTCPListenCallback(void *arg)
393 {
394 struct lwip_callback_msg *msg = arg;
395
396 ASSERT(msg);
397
398 if (!msg->Input.Listen.Connection->SocketContext)
399 {
400 msg->Output.Listen.NewPcb = NULL;
401 goto done;
402 }
403
404 msg->Output.Listen.NewPcb = tcp_listen_with_backlog((PTCP_PCB)msg->Input.Listen.Connection->SocketContext, msg->Input.Listen.Backlog);
405
406 if (msg->Output.Listen.NewPcb)
407 {
408 tcp_accept(msg->Output.Listen.NewPcb, InternalAcceptEventHandler);
409 }
410
411 done:
412 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
413 }
414
415 PTCP_PCB
416 LibTCPListen(PCONNECTION_ENDPOINT Connection, const u8_t backlog)
417 {
418 struct lwip_callback_msg *msg;
419 PTCP_PCB ret;
420
421 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
422 if (msg)
423 {
424 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
425 msg->Input.Listen.Connection = Connection;
426 msg->Input.Listen.Backlog = backlog;
427
428 tcpip_callback_with_block(LibTCPListenCallback, msg, 1);
429
430 if (WaitForEventSafely(&msg->Event))
431 ret = msg->Output.Listen.NewPcb;
432 else
433 ret = NULL;
434
435 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
436
437 return ret;
438 }
439
440 return NULL;
441 }
442
443 static
444 void
445 LibTCPSendCallback(void *arg)
446 {
447 struct lwip_callback_msg *msg = arg;
448
449 ASSERT(msg);
450
451 if (!msg->Input.Send.Connection->SocketContext)
452 {
453 msg->Output.Send.Error = ERR_CLSD;
454 goto done;
455 }
456
457 if (msg->Input.Send.Connection->SendShutdown)
458 {
459 msg->Output.Send.Error = ERR_CLSD;
460 goto done;
461 }
462
463 msg->Output.Send.Error = tcp_write((PTCP_PCB)msg->Input.Send.Connection->SocketContext,
464 msg->Input.Send.Data,
465 msg->Input.Send.DataLength,
466 TCP_WRITE_FLAG_COPY);
467 if (msg->Output.Send.Error == ERR_MEM)
468 {
469 /* No buffer space so return pending */
470 msg->Output.Send.Error = ERR_INPROGRESS;
471 }
472 else if (msg->Output.Send.Error == ERR_OK)
473 {
474 /* Queued successfully so try to send it */
475 tcp_output((PTCP_PCB)msg->Input.Send.Connection->SocketContext);
476 }
477
478 done:
479 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
480 }
481
482 err_t
483 LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, const int safe)
484 {
485 err_t ret;
486 struct lwip_callback_msg *msg;
487
488 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
489 if (msg)
490 {
491 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
492 msg->Input.Send.Connection = Connection;
493 msg->Input.Send.Data = dataptr;
494 msg->Input.Send.DataLength = len;
495
496 if (safe)
497 LibTCPSendCallback(msg);
498 else
499 tcpip_callback_with_block(LibTCPSendCallback, msg, 1);
500
501 if (WaitForEventSafely(&msg->Event))
502 ret = msg->Output.Send.Error;
503 else
504 ret = ERR_CLSD;
505
506 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
507
508 return ret;
509 }
510
511 return ERR_MEM;
512 }
513
514 static
515 void
516 LibTCPConnectCallback(void *arg)
517 {
518 struct lwip_callback_msg *msg = arg;
519
520 ASSERT(arg);
521
522 if (!msg->Input.Connect.Connection->SocketContext)
523 {
524 msg->Output.Connect.Error = ERR_CLSD;
525 goto done;
526 }
527
528 tcp_recv((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalRecvEventHandler);
529 tcp_sent((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalSendEventHandler);
530
531 err_t Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
532 msg->Input.Connect.IpAddress, ntohs(msg->Input.Connect.Port),
533 InternalConnectEventHandler);
534
535 msg->Output.Connect.Error = Error == ERR_OK ? ERR_INPROGRESS : Error;
536
537 done:
538 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
539 }
540
541 err_t
542 LibTCPConnect(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
543 {
544 struct lwip_callback_msg *msg;
545 err_t ret;
546
547 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
548 if (msg)
549 {
550 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
551 msg->Input.Connect.Connection = Connection;
552 msg->Input.Connect.IpAddress = ipaddr;
553 msg->Input.Connect.Port = port;
554
555 tcpip_callback_with_block(LibTCPConnectCallback, msg, 1);
556
557 if (WaitForEventSafely(&msg->Event))
558 {
559 ret = msg->Output.Connect.Error;
560 }
561 else
562 ret = ERR_CLSD;
563
564 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
565
566 return ret;
567 }
568
569 return ERR_MEM;
570 }
571
572 static
573 void
574 LibTCPShutdownCallback(void *arg)
575 {
576 struct lwip_callback_msg *msg = arg;
577 PTCP_PCB pcb = msg->Input.Shutdown.Connection->SocketContext;
578
579 if (!msg->Input.Shutdown.Connection->SocketContext)
580 {
581 msg->Output.Shutdown.Error = ERR_CLSD;
582 goto done;
583 }
584
585 if (pcb->state == CLOSE_WAIT)
586 {
587 /* This case actually results in a socket closure later (lwIP bug?) */
588 msg->Input.Shutdown.Connection->SocketContext = NULL;
589 }
590
591 msg->Output.Shutdown.Error = tcp_shutdown(pcb, msg->Input.Shutdown.shut_rx, msg->Input.Shutdown.shut_tx);
592 if (msg->Output.Shutdown.Error)
593 {
594 msg->Input.Shutdown.Connection->SocketContext = pcb;
595 }
596 else
597 {
598 if (msg->Input.Shutdown.shut_rx)
599 msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
600
601 if (msg->Input.Shutdown.shut_tx)
602 msg->Input.Shutdown.Connection->SendShutdown = TRUE;
603 }
604
605 done:
606 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
607 }
608
609 err_t
610 LibTCPShutdown(PCONNECTION_ENDPOINT Connection, const int shut_rx, const int shut_tx)
611 {
612 struct lwip_callback_msg *msg;
613 err_t ret;
614
615 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
616 if (msg)
617 {
618 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
619
620 msg->Input.Shutdown.Connection = Connection;
621 msg->Input.Shutdown.shut_rx = shut_rx;
622 msg->Input.Shutdown.shut_tx = shut_tx;
623
624 tcpip_callback_with_block(LibTCPShutdownCallback, msg, 1);
625
626 if (WaitForEventSafely(&msg->Event))
627 ret = msg->Output.Shutdown.Error;
628 else
629 ret = ERR_CLSD;
630
631 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
632
633 return ret;
634 }
635
636 return ERR_MEM;
637 }
638
639 static
640 void
641 LibTCPCloseCallback(void *arg)
642 {
643 struct lwip_callback_msg *msg = arg;
644 PTCP_PCB pcb = msg->Input.Close.Connection->SocketContext;
645 int state;
646
647 /* Empty the queue even if we're already "closed" */
648 LibTCPEmptyQueue(msg->Input.Close.Connection);
649
650 if (!msg->Input.Close.Connection->SocketContext)
651 {
652 msg->Output.Close.Error = ERR_OK;
653 goto done;
654 }
655
656 /* Clear the PCB pointer */
657 msg->Input.Close.Connection->SocketContext = NULL;
658
659 /* Save the old PCB state */
660 state = pcb->state;
661
662 msg->Output.Close.Error = tcp_close(pcb);
663 if (!msg->Output.Close.Error)
664 {
665 if (msg->Input.Close.Callback)
666 {
667 /* Call the FIN handler in the cases where it will not be called by lwIP */
668 switch (state)
669 {
670 case CLOSED:
671 case LISTEN:
672 case SYN_SENT:
673 TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
674 break;
675
676 default:
677 break;
678 }
679 }
680 }
681 else
682 {
683 /* Restore the PCB pointer */
684 msg->Input.Close.Connection->SocketContext = pcb;
685 }
686
687 done:
688 KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
689 }
690
691 err_t
692 LibTCPClose(PCONNECTION_ENDPOINT Connection, const int safe, const int callback)
693 {
694 err_t ret;
695 struct lwip_callback_msg *msg;
696
697 msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
698 if (msg)
699 {
700 KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
701
702 msg->Input.Close.Connection = Connection;
703 msg->Input.Close.Callback = callback;
704
705 if (safe)
706 LibTCPCloseCallback(msg);
707 else
708 tcpip_callback_with_block(LibTCPCloseCallback, msg, 1);
709
710 if (WaitForEventSafely(&msg->Event))
711 ret = msg->Output.Close.Error;
712 else
713 ret = ERR_CLSD;
714
715 ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
716
717 return ret;
718 }
719
720 return ERR_MEM;
721 }
722
723 void
724 LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg)
725 {
726 ASSERT(arg);
727
728 tcp_arg(pcb, NULL);
729 tcp_recv(pcb, InternalRecvEventHandler);
730 tcp_sent(pcb, InternalSendEventHandler);
731 tcp_err(pcb, InternalErrorEventHandler);
732 tcp_arg(pcb, arg);
733
734 tcp_accepted(listen_pcb);
735 }
736
737 err_t
738 LibTCPGetHostName(PTCP_PCB pcb, struct ip_addr *const ipaddr, u16_t *const port)
739 {
740 if (!pcb)
741 return ERR_CLSD;
742
743 *ipaddr = pcb->local_ip;
744 *port = pcb->local_port;
745
746 return ERR_OK;
747 }
748
749 err_t
750 LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr * const ipaddr, u16_t * const port)
751 {
752 if (!pcb)
753 return ERR_CLSD;
754
755 *ipaddr = pcb->remote_ip;
756 *port = pcb->remote_port;
757
758 return ERR_OK;
759 }