[WSHTCPIP]
[reactos.git] / reactos / lib / drivers / lwip / src / rostcp.c
index 005f62e..ea06d75 100755 (executable)
@@ -1,4 +1,5 @@
 #include "lwip/sys.h"
+#include "lwip/netif.h"
 #include "lwip/tcpip.h"
 
 #include "rosip.h"
@@ -31,6 +32,24 @@ extern KEVENT TerminationEvent;
 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
 
+/* Required for ERR_T to NTSTATUS translation in receive error handling */
+NTSTATUS TCPTranslateError(const err_t err);
+
+void
+LibTCPDumpPcb(PVOID SocketContext)
+{
+    struct tcp_pcb *pcb = (struct tcp_pcb*)SocketContext;
+    unsigned int addr = ntohl(pcb->remote_ip.addr);
+
+    DbgPrint("\tState: %s\n", tcp_state_str[pcb->state]);
+    DbgPrint("\tRemote: (%d.%d.%d.%d, %d)\n",
+    (addr >> 24) & 0xFF,
+    (addr >> 16) & 0xFF,
+    (addr >> 8) & 0xFF,
+    addr & 0xFF,
+    pcb->remote_port);
+}
+
 static
 void
 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
@@ -84,9 +103,8 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
     PQUEUE_ENTRY qp;
     struct pbuf* p;
     NTSTATUS Status;
-    UINT ReadLength, PayloadLength;
+    UINT ReadLength, PayloadLength, Offset, Copied;
     KIRQL OldIrql;
-    PUCHAR Payload;
 
     (*Received) = 0;
 
@@ -98,14 +116,14 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
         {
             p = qp->p;
 
-            /* Calculate the payload first */
-            Payload = p->payload;
-            Payload += qp->Offset;
-            PayloadLength = p->len;
+            /* Calculate the payload length first */
+            PayloadLength = p->tot_len;
             PayloadLength -= qp->Offset;
+            Offset = qp->Offset;
 
             /* Check if we're reading the whole buffer */
             ReadLength = MIN(PayloadLength, RecvLen);
+            ASSERT(ReadLength != 0);
             if (ReadLength != PayloadLength)
             {
                 /* Save this one for later */
@@ -116,10 +134,8 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
 
             UnlockObject(Connection, OldIrql);
 
-            /* Return to a lower IRQL because the receive buffer may be pageable memory */
-            RtlCopyMemory(RecvBuffer,
-                          Payload,
-                          ReadLength);
+            Copied = pbuf_copy_partial(p, RecvBuffer, ReadLength, Offset);
+            ASSERT(Copied == ReadLength);
 
             LockObject(Connection, &OldIrql);
 
@@ -141,6 +157,7 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
                 ASSERT(RecvLen == 0);
             }
 
+            ASSERT((*Received) != 0);
             Status = STATUS_SUCCESS;
 
             if (!RecvLen)
@@ -150,7 +167,7 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
     else
     {
         if (Connection->ReceiveShutdown)
-            Status = STATUS_SUCCESS;
+            Status = Connection->ReceiveShutdownStatus;
         else
             Status = STATUS_PENDING;
     }
@@ -202,8 +219,6 @@ err_t
 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
 {
     PCONNECTION_ENDPOINT Connection = arg;
-    struct pbuf *pb;
-    ULONG RecvLen;
 
     /* Make sure the socket didn't get closed */
     if (!arg)
@@ -216,19 +231,9 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
 
     if (p)
     {
-        pb = p;
-        RecvLen = 0;
-        while (pb != NULL)
-        {
-            /* Enqueue this buffer */
-            LibTCPEnqueuePacket(Connection, pb);
-            RecvLen += pb->len;
+        LibTCPEnqueuePacket(Connection, p);
 
-            /* Advance and unchain the buffer */
-            pb = pbuf_dechain(pb);;
-        }
-
-        tcp_recved(pcb, RecvLen);
+        tcp_recved(pcb, p->tot_len);
 
         TCPRecvEventHandler(arg);
     }
@@ -239,19 +244,37 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
          * whole socket here (by calling tcp_close()) as that would violate TCP specs
          */
         Connection->ReceiveShutdown = TRUE;
-        TCPFinEventHandler(arg, ERR_OK);
+        Connection->ReceiveShutdownStatus = STATUS_SUCCESS;
+
+        /* If we already did a send shutdown, we're in TIME_WAIT so we can't use this PCB anymore */
+        if (Connection->SendShutdown)
+        {
+            Connection->SocketContext = NULL;
+            tcp_arg(pcb, NULL);
+        }
+
+        /* Indicate the graceful close event */
+        TCPRecvEventHandler(arg);
+
+        /* If the PCB is gone, clean up the connection */
+        if (Connection->SendShutdown)
+        {
+            TCPFinEventHandler(Connection, ERR_CLSD);
+        }
     }
 
     return ERR_OK;
 }
 
+/* This function MUST return an error value that is not ERR_ABRT or ERR_OK if the connection
+ * is not accepted to avoid leaking the new PCB */
 static
 err_t
 InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
 {
     /* Make sure the socket didn't get closed */
     if (!arg)
-        return ERR_ABRT;
+        return ERR_CLSD;
 
     TCPAcceptEventHandler(arg, newpcb);
 
@@ -259,7 +282,7 @@ InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
     if (newpcb->callback_arg)
         return ERR_OK;
     else
-        return ERR_ABRT;
+        return ERR_CLSD;
 }
 
 static
@@ -279,10 +302,21 @@ static
 void
 InternalErrorEventHandler(void *arg, const err_t err)
 {
+    PCONNECTION_ENDPOINT Connection = arg;
+
     /* Make sure the socket didn't get closed */
-    if (!arg) return;
+    if (!arg || Connection->SocketContext == NULL) return;
+
+    /* The PCB is dead now */
+    Connection->SocketContext = NULL;
 
-    TCPFinEventHandler(arg, err);
+    /* Give them one shot to receive the remaining data */
+    Connection->ReceiveShutdown = TRUE;
+    Connection->ReceiveShutdownStatus = TCPTranslateError(err);
+    TCPRecvEventHandler(Connection);
+
+    /* Terminate the connection */
+    TCPFinEventHandler(Connection, err);
 }
 
 static
@@ -607,24 +641,38 @@ LibTCPShutdownCallback(void *arg)
         goto done;
     }
 
-    if (pcb->state == CLOSE_WAIT)
-    {
-        /* This case actually results in a socket closure later (lwIP bug?) */
-        msg->Input.Shutdown.Connection->SocketContext = NULL;
+    /* LwIP makes the (questionable) assumption that SHUTDOWN_RDWR is equivalent to tcp_close().
+     * This assumption holds even if the shutdown calls are done separately (even through multiple
+     * WinSock shutdown() calls). This assumption means that lwIP has the right to deallocate our
+     * PCB without telling us if we shutdown TX and RX. To avoid these problems, we'll clear the
+     * socket context if we have called shutdown for TX and RX.
+     */
+    if (msg->Input.Shutdown.shut_rx) {
+        msg->Output.Shutdown.Error = tcp_shutdown(pcb, TRUE, FALSE);
     }
-
-    msg->Output.Shutdown.Error = tcp_shutdown(pcb, msg->Input.Shutdown.shut_rx, msg->Input.Shutdown.shut_tx);
-    if (msg->Output.Shutdown.Error)
-    {
-        msg->Input.Shutdown.Connection->SocketContext = pcb;
+    if (msg->Input.Shutdown.shut_tx) {
+        msg->Output.Shutdown.Error = tcp_shutdown(pcb, FALSE, TRUE);
     }
-    else
+
+    if (!msg->Output.Shutdown.Error)
     {
         if (msg->Input.Shutdown.shut_rx)
+        {
             msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
+            msg->Input.Shutdown.Connection->ReceiveShutdownStatus = STATUS_FILE_CLOSED;
+        }
 
         if (msg->Input.Shutdown.shut_tx)
             msg->Input.Shutdown.Connection->SendShutdown = TRUE;
+
+        if (msg->Input.Shutdown.Connection->ReceiveShutdown &&
+            msg->Input.Shutdown.Connection->SendShutdown)
+        {
+            /* The PCB is not ours anymore */
+            msg->Input.Shutdown.Connection->SocketContext = NULL;
+            tcp_arg(pcb, NULL);
+            TCPFinEventHandler(msg->Input.Shutdown.Connection, ERR_CLSD);
+        }
     }
 
 done:
@@ -671,48 +719,41 @@ LibTCPCloseCallback(void *arg)
     /* Empty the queue even if we're already "closed" */
     LibTCPEmptyQueue(msg->Input.Close.Connection);
 
-    if (!msg->Input.Close.Connection->SocketContext)
+    /* Check if we've already been closed */
+    if (msg->Input.Close.Connection->Closing)
     {
         msg->Output.Close.Error = ERR_OK;
         goto done;
     }
 
-    /* Clear the PCB pointer */
-    msg->Input.Close.Connection->SocketContext = NULL;
+    /* Enter "closing" mode if we're doing a normal close */
+    if (msg->Input.Close.Callback)
+        msg->Input.Close.Connection->Closing = TRUE;
 
-    switch (pcb->state)
+    /* Check if the PCB was already "closed" but the client doesn't know it yet */
+    if (!msg->Input.Close.Connection->SocketContext)
     {
-        case CLOSED:
-        case LISTEN:
-        case SYN_SENT:
-           msg->Output.Close.Error = tcp_close(pcb);
-
-           if (!msg->Output.Close.Error && msg->Input.Close.Callback)
-               TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
-           break;
+        msg->Output.Close.Error = ERR_OK;
+        goto done;
+    }
 
-        default:
-           if (msg->Input.Close.Connection->SendShutdown &&
-               msg->Input.Close.Connection->ReceiveShutdown)
-           {
-               /* Abort the connection */
-               tcp_abort(pcb);
+    /* Clear the PCB pointer and stop callbacks */
+    msg->Input.Close.Connection->SocketContext = NULL;
+    tcp_arg(pcb, NULL);
 
-               /* Aborts always succeed */
-               msg->Output.Close.Error = ERR_OK;
-           }
-           else
-           {
-               /* Start the graceful close process (or send RST for pending data) */
-               msg->Output.Close.Error = tcp_close(pcb);
-           }
-           break;
-    }
+    /* This may generate additional callbacks but we don't care,
+     * because they're too inconsistent to rely on */
+    msg->Output.Close.Error = tcp_close(pcb);
 
     if (msg->Output.Close.Error)
     {
         /* Restore the PCB pointer */
         msg->Input.Close.Connection->SocketContext = pcb;
+        msg->Input.Close.Connection->Closing = FALSE;
+    }
+    else if (msg->Input.Close.Callback)
+    {
+        TCPFinEventHandler(msg->Input.Close.Connection, ERR_CLSD);
     }
 
 done:
@@ -788,3 +829,14 @@ LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr * const ipaddr, u16_t * const por
 
     return ERR_OK;
 }
+
+void
+LibTCPSetNoDelay(
+    PTCP_PCB pcb,
+    BOOLEAN Set)
+{
+    if (Set)
+        pcb->flags |= TF_NODELAY;
+    else
+        pcb->flags &= ~TF_NODELAY;
+}