[WSHTCPIP]
[reactos.git] / reactos / lib / drivers / lwip / src / rostcp.c
index ce51acb..ea06d75 100755 (executable)
@@ -1,4 +1,5 @@
 #include "lwip/sys.h"
+#include "lwip/netif.h"
 #include "lwip/tcpip.h"
 
 #include "rosip.h"
@@ -6,17 +7,17 @@
 #include <debug.h>
 
 static const char * const tcp_state_str[] = {
-  "CLOSED",      
-  "LISTEN",      
-  "SYN_SENT",    
-  "SYN_RCVD",    
-  "ESTABLISHED", 
-  "FIN_WAIT_1",  
-  "FIN_WAIT_2",  
-  "CLOSE_WAIT",  
-  "CLOSING",     
-  "LAST_ACK",    
-  "TIME_WAIT"   
+  "CLOSED",
+  "LISTEN",
+  "SYN_SENT",
+  "SYN_RCVD",
+  "ESTABLISHED",
+  "FIN_WAIT_1",
+  "FIN_WAIT_2",
+  "CLOSE_WAIT",
+  "CLOSING",
+  "LAST_ACK",
+  "TIME_WAIT"
 };
 
 /* The way that lwIP does multi-threading is really not ideal for our purposes but
@@ -31,6 +32,24 @@ extern KEVENT TerminationEvent;
 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
 
+/* Required for ERR_T to NTSTATUS translation in receive error handling */
+NTSTATUS TCPTranslateError(const err_t err);
+
+void
+LibTCPDumpPcb(PVOID SocketContext)
+{
+    struct tcp_pcb *pcb = (struct tcp_pcb*)SocketContext;
+    unsigned int addr = ntohl(pcb->remote_ip.addr);
+
+    DbgPrint("\tState: %s\n", tcp_state_str[pcb->state]);
+    DbgPrint("\tRemote: (%d.%d.%d.%d, %d)\n",
+    (addr >> 24) & 0xFF,
+    (addr >> 16) & 0xFF,
+    (addr >> 8) & 0xFF,
+    addr & 0xFF,
+    pcb->remote_port);
+}
+
 static
 void
 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
@@ -39,7 +58,7 @@ LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
     PQUEUE_ENTRY qp = NULL;
 
     ReferenceObject(Connection);
-    
+
     while (!IsListEmpty(&Connection->PacketQueue))
     {
         Entry = RemoveHeadList(&Connection->PacketQueue);
@@ -60,6 +79,7 @@ void LibTCPEnqueuePacket(PCONNECTION_ENDPOINT Connection, struct pbuf *p)
 
     qp = (PQUEUE_ENTRY)ExAllocateFromNPagedLookasideList(&QueueEntryLookasideList);
     qp->p = p;
+    qp->Offset = 0;
 
     ExInterlockedInsertTailList(&Connection->PacketQueue, &qp->ListEntry, &Connection->Lock);
 }
@@ -72,7 +92,7 @@ PQUEUE_ENTRY LibTCPDequeuePacket(PCONNECTION_ENDPOINT Connection)
     if (IsListEmpty(&Connection->PacketQueue)) return NULL;
 
     Entry = RemoveHeadList(&Connection->PacketQueue);
-    
+
     qp = CONTAINING_RECORD(Entry, QUEUE_ENTRY, ListEntry);
 
     return qp;
@@ -82,12 +102,11 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
 {
     PQUEUE_ENTRY qp;
     struct pbuf* p;
-    NTSTATUS Status = STATUS_PENDING;
-    UINT ReadLength, ExistingDataLength, SpaceLeft;
+    NTSTATUS Status;
+    UINT ReadLength, PayloadLength, Offset, Copied;
     KIRQL OldIrql;
 
     (*Received) = 0;
-    SpaceLeft = RecvLen;
 
     LockObject(Connection, &OldIrql);
 
@@ -96,56 +115,59 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
         while ((qp = LibTCPDequeuePacket(Connection)) != NULL)
         {
             p = qp->p;
-            ExistingDataLength = (*Received);
 
-            Status = STATUS_SUCCESS;
+            /* Calculate the payload length first */
+            PayloadLength = p->tot_len;
+            PayloadLength -= qp->Offset;
+            Offset = qp->Offset;
 
-            ReadLength = MIN(p->tot_len, SpaceLeft);
-            if (ReadLength != p->tot_len)
+            /* Check if we're reading the whole buffer */
+            ReadLength = MIN(PayloadLength, RecvLen);
+            ASSERT(ReadLength != 0);
+            if (ReadLength != PayloadLength)
             {
-                if (ExistingDataLength)
-                {
-                    /* The packet was too big but we used some data already so give it another shot later */
-                    InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
-                    break;
-                }
-                else
-                {
-                    /* The packet is just too big to fit fully in our buffer, even when empty so
-                     * return an informative status but still copy all the data we can fit.
-                     */
-                    Status = STATUS_BUFFER_OVERFLOW;
-                }
+                /* Save this one for later */
+                qp->Offset += ReadLength;
+                InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
+                qp = NULL;
             }
 
             UnlockObject(Connection, OldIrql);
 
-            /* Return to a lower IRQL because the receive buffer may be pageable memory */
-            for (; (*Received) < ReadLength + ExistingDataLength; (*Received) += p->len, p = p->next)
-            {
-                RtlCopyMemory(RecvBuffer + (*Received), p->payload, p->len);
-            }
+            Copied = pbuf_copy_partial(p, RecvBuffer, ReadLength, Offset);
+            ASSERT(Copied == ReadLength);
 
             LockObject(Connection, &OldIrql);
 
-            SpaceLeft -= ReadLength;
+            /* Update trackers */
+            RecvLen -= ReadLength;
+            RecvBuffer += ReadLength;
+            (*Received) += ReadLength;
 
-            /* Use this special pbuf free callback function because we're outside tcpip thread */
-            pbuf_free_callback(qp->p);
+            if (qp != NULL)
+            {
+                /* Use this special pbuf free callback function because we're outside tcpip thread */
+                pbuf_free_callback(qp->p);
 
-            ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
+                ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
+            }
+            else
+            {
+                /* If we get here, it means we've filled the buffer */
+                ASSERT(RecvLen == 0);
+            }
 
-            if (!RecvLen)
-                break;
+            ASSERT((*Received) != 0);
+            Status = STATUS_SUCCESS;
 
-            if (Status != STATUS_SUCCESS)
+            if (!RecvLen)
                 break;
         }
     }
     else
     {
         if (Connection->ReceiveShutdown)
-            Status = STATUS_SUCCESS;
+            Status = Connection->ReceiveShutdownStatus;
         else
             Status = STATUS_PENDING;
     }
@@ -160,7 +182,7 @@ BOOLEAN
 WaitForEventSafely(PRKEVENT Event)
 {
     PVOID WaitObjects[] = {Event, &TerminationEvent};
-    
+
     if (KeWaitForMultipleObjects(2,
                                  WaitObjects,
                                  WaitAny,
@@ -186,9 +208,9 @@ InternalSendEventHandler(void *arg, PTCP_PCB pcb, const u16_t space)
 {
     /* Make sure the socket didn't get closed */
     if (!arg) return ERR_OK;
-    
+
     TCPSendEventHandler(arg, space);
-    
+
     return ERR_OK;
 }
 
@@ -197,7 +219,6 @@ err_t
 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
 {
     PCONNECTION_ENDPOINT Connection = arg;
-    u32_t len;
 
     /* Make sure the socket didn't get closed */
     if (!arg)
@@ -207,36 +228,14 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
 
         return ERR_OK;
     }
-    
+
     if (p)
     {
-        len = TCPRecvEventHandler(arg, p);
-        if (len == p->tot_len)
-        {
-            tcp_recved(pcb, len);
+        LibTCPEnqueuePacket(Connection, p);
 
-            pbuf_free(p);
+        tcp_recved(pcb, p->tot_len);
 
-            return ERR_OK;
-        }
-        else if (len != 0)
-        {
-            DbgPrint("UNTESTED CASE: NOT ALL DATA TAKEN! EXTRA DATA MAY BE LOST!\n");
-            
-            tcp_recved(pcb, len);
-            
-            /* Possible memory leak of pbuf here? */
-            
-            return ERR_OK;
-        }
-        else
-        {
-            LibTCPEnqueuePacket(Connection, p);
-
-            tcp_recved(pcb, p->tot_len);
-
-            return ERR_OK;
-        }
+        TCPRecvEventHandler(arg);
     }
     else if (err == ERR_OK)
     {
@@ -245,19 +244,37 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
          * whole socket here (by calling tcp_close()) as that would violate TCP specs
          */
         Connection->ReceiveShutdown = TRUE;
-        TCPFinEventHandler(arg, ERR_OK);
+        Connection->ReceiveShutdownStatus = STATUS_SUCCESS;
+
+        /* If we already did a send shutdown, we're in TIME_WAIT so we can't use this PCB anymore */
+        if (Connection->SendShutdown)
+        {
+            Connection->SocketContext = NULL;
+            tcp_arg(pcb, NULL);
+        }
+
+        /* Indicate the graceful close event */
+        TCPRecvEventHandler(arg);
+
+        /* If the PCB is gone, clean up the connection */
+        if (Connection->SendShutdown)
+        {
+            TCPFinEventHandler(Connection, ERR_CLSD);
+        }
     }
 
     return ERR_OK;
 }
 
+/* This function MUST return an error value that is not ERR_ABRT or ERR_OK if the connection
+ * is not accepted to avoid leaking the new PCB */
 static
 err_t
 InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
 {
     /* Make sure the socket didn't get closed */
     if (!arg)
-        return ERR_ABRT;
+        return ERR_CLSD;
 
     TCPAcceptEventHandler(arg, newpcb);
 
@@ -265,7 +282,7 @@ InternalAcceptEventHandler(void *arg, PTCP_PCB newpcb, const err_t err)
     if (newpcb->callback_arg)
         return ERR_OK;
     else
-        return ERR_ABRT;
+        return ERR_CLSD;
 }
 
 static
@@ -285,10 +302,21 @@ static
 void
 InternalErrorEventHandler(void *arg, const err_t err)
 {
+    PCONNECTION_ENDPOINT Connection = arg;
+
     /* Make sure the socket didn't get closed */
-    if (!arg) return;
+    if (!arg || Connection->SocketContext == NULL) return;
 
-    TCPFinEventHandler(arg, err);
+    /* The PCB is dead now */
+    Connection->SocketContext = NULL;
+
+    /* Give them one shot to receive the remaining data */
+    Connection->ReceiveShutdown = TRUE;
+    Connection->ReceiveShutdownStatus = TCPTranslateError(err);
+    TCPRecvEventHandler(Connection);
+
+    /* Terminate the connection */
+    TCPFinEventHandler(Connection, err);
 }
 
 static
@@ -315,7 +343,7 @@ LibTCPSocket(void *arg)
 {
     struct lwip_callback_msg *msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
     struct tcp_pcb *ret;
-    
+
     if (msg)
     {
         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
@@ -327,12 +355,12 @@ LibTCPSocket(void *arg)
             ret = msg->Output.Socket.NewPcb;
         else
             ret = NULL;
-        
+
         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-        
+
         return ret;
     }
-    
+
     return NULL;
 }
 
@@ -341,6 +369,7 @@ void
 LibTCPBindCallback(void *arg)
 {
     struct lwip_callback_msg *msg = arg;
+    PTCP_PCB pcb = msg->Input.Bind.Connection->SocketContext;
 
     ASSERT(msg);
 
@@ -350,10 +379,13 @@ LibTCPBindCallback(void *arg)
         goto done;
     }
 
-    msg->Output.Bind.Error = tcp_bind((PTCP_PCB)msg->Input.Bind.Connection->SocketContext,
+    /* We're guaranteed that the local address is valid to bind at this point */
+    pcb->so_options |= SOF_REUSEADDR;
+
+    msg->Output.Bind.Error = tcp_bind(pcb,
                                       msg->Input.Bind.IpAddress,
                                       ntohs(msg->Input.Bind.Port));
-    
+
 done:
     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
 }
@@ -363,7 +395,7 @@ LibTCPBind(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const
 {
     struct lwip_callback_msg *msg;
     err_t ret;
-    
+
     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
     if (msg)
     {
@@ -371,19 +403,19 @@ LibTCPBind(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const
         msg->Input.Bind.Connection = Connection;
         msg->Input.Bind.IpAddress = ipaddr;
         msg->Input.Bind.Port = port;
-        
+
         tcpip_callback_with_block(LibTCPBindCallback, msg, 1);
-        
+
         if (WaitForEventSafely(&msg->Event))
             ret = msg->Output.Bind.Error;
         else
             ret = ERR_CLSD;
-        
+
         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-        
+
         return ret;
     }
-    
+
     return ERR_MEM;
 }
 
@@ -392,9 +424,9 @@ void
 LibTCPListenCallback(void *arg)
 {
     struct lwip_callback_msg *msg = arg;
-    
+
     ASSERT(msg);
-    
+
     if (!msg->Input.Listen.Connection->SocketContext)
     {
         msg->Output.Listen.NewPcb = NULL;
@@ -402,12 +434,12 @@ LibTCPListenCallback(void *arg)
     }
 
     msg->Output.Listen.NewPcb = tcp_listen_with_backlog((PTCP_PCB)msg->Input.Listen.Connection->SocketContext, msg->Input.Listen.Backlog);
-    
+
     if (msg->Output.Listen.NewPcb)
     {
         tcp_accept(msg->Output.Listen.NewPcb, InternalAcceptEventHandler);
     }
-    
+
 done:
     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
 }
@@ -417,23 +449,23 @@ LibTCPListen(PCONNECTION_ENDPOINT Connection, const u8_t backlog)
 {
     struct lwip_callback_msg *msg;
     PTCP_PCB ret;
-    
+
     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
     if (msg)
     {
         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
         msg->Input.Listen.Connection = Connection;
         msg->Input.Listen.Backlog = backlog;
-        
+
         tcpip_callback_with_block(LibTCPListenCallback, msg, 1);
-        
+
         if (WaitForEventSafely(&msg->Event))
             ret = msg->Output.Listen.NewPcb;
         else
             ret = NULL;
-        
+
         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-        
+
         return ret;
     }
 
@@ -445,46 +477,67 @@ void
 LibTCPSendCallback(void *arg)
 {
     struct lwip_callback_msg *msg = arg;
-    
+    PTCP_PCB pcb = msg->Input.Send.Connection->SocketContext;
+    ULONG SendLength;
+    UCHAR SendFlags;
+
     ASSERT(msg);
-    
+
     if (!msg->Input.Send.Connection->SocketContext)
     {
         msg->Output.Send.Error = ERR_CLSD;
         goto done;
     }
-    
+
     if (msg->Input.Send.Connection->SendShutdown)
     {
         msg->Output.Send.Error = ERR_CLSD;
         goto done;
     }
-    
-    msg->Output.Send.Error = tcp_write((PTCP_PCB)msg->Input.Send.Connection->SocketContext,
-                                       msg->Input.Send.Data,
-                                       msg->Input.Send.DataLength,
-                                       TCP_WRITE_FLAG_COPY);
-    if (msg->Output.Send.Error == ERR_MEM)
+
+    SendFlags = TCP_WRITE_FLAG_COPY;
+    SendLength = msg->Input.Send.DataLength;
+    if (tcp_sndbuf(pcb) == 0)
     {
         /* No buffer space so return pending */
         msg->Output.Send.Error = ERR_INPROGRESS;
+        goto done;
     }
-    else if (msg->Output.Send.Error == ERR_OK)
+    else if (tcp_sndbuf(pcb) < SendLength)
     {
-        /* Queued successfully so try to send it */   
+        /* We've got some room so let's send what we can */
+        SendLength = tcp_sndbuf(pcb);
+
+        /* Don't set the push flag */
+        SendFlags |= TCP_WRITE_FLAG_MORE;
+    }
+
+    msg->Output.Send.Error = tcp_write(pcb,
+                                       msg->Input.Send.Data,
+                                       SendLength,
+                                       SendFlags);
+    if (msg->Output.Send.Error == ERR_OK)
+    {
+        /* Queued successfully so try to send it */
         tcp_output((PTCP_PCB)msg->Input.Send.Connection->SocketContext);
+        msg->Output.Send.Information = SendLength;
     }
-    
+    else if (msg->Output.Send.Error == ERR_MEM)
+    {
+        /* The queue is too long */
+        msg->Output.Send.Error = ERR_INPROGRESS;
+    }
+
 done:
     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
 }
 
 err_t
-LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, const int safe)
+LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, u32_t *sent, const int safe)
 {
     err_t ret;
     struct lwip_callback_msg *msg;
-    
+
     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
     if (msg)
     {
@@ -492,19 +545,24 @@ LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len
         msg->Input.Send.Connection = Connection;
         msg->Input.Send.Data = dataptr;
         msg->Input.Send.DataLength = len;
-        
+
         if (safe)
             LibTCPSendCallback(msg);
         else
             tcpip_callback_with_block(LibTCPSendCallback, msg, 1);
-        
+
         if (WaitForEventSafely(&msg->Event))
             ret = msg->Output.Send.Error;
         else
             ret = ERR_CLSD;
-        
+
+        if (ret == ERR_OK)
+            *sent = msg->Output.Send.Information;
+        else
+            *sent = 0;
+
         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-        
+
         return ret;
     }
 
@@ -516,24 +574,25 @@ void
 LibTCPConnectCallback(void *arg)
 {
     struct lwip_callback_msg *msg = arg;
-    
+    err_t Error;
+
     ASSERT(arg);
-    
+
     if (!msg->Input.Connect.Connection->SocketContext)
     {
         msg->Output.Connect.Error = ERR_CLSD;
         goto done;
     }
-    
+
     tcp_recv((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalRecvEventHandler);
     tcp_sent((PTCP_PCB)msg->Input.Connect.Connection->SocketContext, InternalSendEventHandler);
 
-    err_t Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
-                                msg->Input.Connect.IpAddress, ntohs(msg->Input.Connect.Port),
-                                InternalConnectEventHandler);
+    Error = tcp_connect((PTCP_PCB)msg->Input.Connect.Connection->SocketContext,
+                        msg->Input.Connect.IpAddress, ntohs(msg->Input.Connect.Port),
+                        InternalConnectEventHandler);
 
     msg->Output.Connect.Error = Error == ERR_OK ? ERR_INPROGRESS : Error;
-    
+
 done:
     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
 }
@@ -543,7 +602,7 @@ LibTCPConnect(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, con
 {
     struct lwip_callback_msg *msg;
     err_t ret;
-    
+
     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
     if (msg)
     {
@@ -551,21 +610,21 @@ LibTCPConnect(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, con
         msg->Input.Connect.Connection = Connection;
         msg->Input.Connect.IpAddress = ipaddr;
         msg->Input.Connect.Port = port;
-        
+
         tcpip_callback_with_block(LibTCPConnectCallback, msg, 1);
-        
+
         if (WaitForEventSafely(&msg->Event))
         {
             ret = msg->Output.Connect.Error;
         }
         else
             ret = ERR_CLSD;
-        
+
         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-        
+
         return ret;
     }
-    
+
     return ERR_MEM;
 }
 
@@ -582,26 +641,40 @@ LibTCPShutdownCallback(void *arg)
         goto done;
     }
 
-    if (pcb->state == CLOSE_WAIT)
-    {
-        /* This case actually results in a socket closure later (lwIP bug?) */
-        msg->Input.Shutdown.Connection->SocketContext = NULL;
+    /* LwIP makes the (questionable) assumption that SHUTDOWN_RDWR is equivalent to tcp_close().
+     * This assumption holds even if the shutdown calls are done separately (even through multiple
+     * WinSock shutdown() calls). This assumption means that lwIP has the right to deallocate our
+     * PCB without telling us if we shutdown TX and RX. To avoid these problems, we'll clear the
+     * socket context if we have called shutdown for TX and RX.
+     */
+    if (msg->Input.Shutdown.shut_rx) {
+        msg->Output.Shutdown.Error = tcp_shutdown(pcb, TRUE, FALSE);
     }
-
-    msg->Output.Shutdown.Error = tcp_shutdown(pcb, msg->Input.Shutdown.shut_rx, msg->Input.Shutdown.shut_tx);
-    if (msg->Output.Shutdown.Error)
-    {
-        msg->Input.Shutdown.Connection->SocketContext = pcb;
+    if (msg->Input.Shutdown.shut_tx) {
+        msg->Output.Shutdown.Error = tcp_shutdown(pcb, FALSE, TRUE);
     }
-    else
+
+    if (!msg->Output.Shutdown.Error)
     {
         if (msg->Input.Shutdown.shut_rx)
+        {
             msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
-        
+            msg->Input.Shutdown.Connection->ReceiveShutdownStatus = STATUS_FILE_CLOSED;
+        }
+
         if (msg->Input.Shutdown.shut_tx)
             msg->Input.Shutdown.Connection->SendShutdown = TRUE;
+
+        if (msg->Input.Shutdown.Connection->ReceiveShutdown &&
+            msg->Input.Shutdown.Connection->SendShutdown)
+        {
+            /* The PCB is not ours anymore */
+            msg->Input.Shutdown.Connection->SocketContext = NULL;
+            tcp_arg(pcb, NULL);
+            TCPFinEventHandler(msg->Input.Shutdown.Connection, ERR_CLSD);
+        }
     }
-    
+
 done:
     KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
 }
@@ -616,23 +689,23 @@ LibTCPShutdown(PCONNECTION_ENDPOINT Connection, const int shut_rx, const int shu
     if (msg)
     {
         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
-        
+
         msg->Input.Shutdown.Connection = Connection;
         msg->Input.Shutdown.shut_rx = shut_rx;
         msg->Input.Shutdown.shut_tx = shut_tx;
-                
+
         tcpip_callback_with_block(LibTCPShutdownCallback, msg, 1);
-        
+
         if (WaitForEventSafely(&msg->Event))
             ret = msg->Output.Shutdown.Error;
         else
             ret = ERR_CLSD;
-        
+
         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-        
+
         return ret;
     }
-    
+
     return ERR_MEM;
 }
 
@@ -642,46 +715,45 @@ LibTCPCloseCallback(void *arg)
 {
     struct lwip_callback_msg *msg = arg;
     PTCP_PCB pcb = msg->Input.Close.Connection->SocketContext;
-    int state;
 
     /* Empty the queue even if we're already "closed" */
     LibTCPEmptyQueue(msg->Input.Close.Connection);
 
+    /* Check if we've already been closed */
+    if (msg->Input.Close.Connection->Closing)
+    {
+        msg->Output.Close.Error = ERR_OK;
+        goto done;
+    }
+
+    /* Enter "closing" mode if we're doing a normal close */
+    if (msg->Input.Close.Callback)
+        msg->Input.Close.Connection->Closing = TRUE;
+
+    /* Check if the PCB was already "closed" but the client doesn't know it yet */
     if (!msg->Input.Close.Connection->SocketContext)
     {
         msg->Output.Close.Error = ERR_OK;
         goto done;
     }
 
-    /* Clear the PCB pointer */
+    /* Clear the PCB pointer and stop callbacks */
     msg->Input.Close.Connection->SocketContext = NULL;
+    tcp_arg(pcb, NULL);
 
-    /* Save the old PCB state */
-    state = pcb->state;
-
+    /* This may generate additional callbacks but we don't care,
+     * because they're too inconsistent to rely on */
     msg->Output.Close.Error = tcp_close(pcb);
-    if (!msg->Output.Close.Error)
-    {
-        if (msg->Input.Close.Callback)
-        {
-            /* Call the FIN handler in the cases where it will not be called by lwIP */
-            switch (state)
-            {
-                case CLOSED:
-                case LISTEN:
-                case SYN_SENT:
-                   TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
-                   break;
-
-                default:
-                   break;
-            }
-        }
-    }
-    else
+
+    if (msg->Output.Close.Error)
     {
         /* Restore the PCB pointer */
         msg->Input.Close.Connection->SocketContext = pcb;
+        msg->Input.Close.Connection->Closing = FALSE;
+    }
+    else if (msg->Input.Close.Callback)
+    {
+        TCPFinEventHandler(msg->Input.Close.Connection, ERR_CLSD);
     }
 
 done:
@@ -693,27 +765,27 @@ LibTCPClose(PCONNECTION_ENDPOINT Connection, const int safe, const int callback)
 {
     err_t ret;
     struct lwip_callback_msg *msg;
-    
+
     msg = ExAllocateFromNPagedLookasideList(&MessageLookasideList);
     if (msg)
     {
         KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
-        
+
         msg->Input.Close.Connection = Connection;
         msg->Input.Close.Callback = callback;
-        
+
         if (safe)
             LibTCPCloseCallback(msg);
         else
             tcpip_callback_with_block(LibTCPCloseCallback, msg, 1);
-        
+
         if (WaitForEventSafely(&msg->Event))
             ret = msg->Output.Close.Error;
         else
             ret = ERR_CLSD;
-        
+
         ExFreeToNPagedLookasideList(&MessageLookasideList, msg);
-        
+
         return ret;
     }
 
@@ -724,13 +796,13 @@ void
 LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg)
 {
     ASSERT(arg);
-    
+
     tcp_arg(pcb, NULL);
     tcp_recv(pcb, InternalRecvEventHandler);
     tcp_sent(pcb, InternalSendEventHandler);
     tcp_err(pcb, InternalErrorEventHandler);
     tcp_arg(pcb, arg);
-    
+
     tcp_accepted(listen_pcb);
 }
 
@@ -739,21 +811,32 @@ LibTCPGetHostName(PTCP_PCB pcb, struct ip_addr *const ipaddr, u16_t *const port)
 {
     if (!pcb)
         return ERR_CLSD;
-    
+
     *ipaddr = pcb->local_ip;
     *port = pcb->local_port;
-    
+
     return ERR_OK;
 }
 
 err_t
 LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr * const ipaddr, u16_t * const port)
-{    
+{
     if (!pcb)
         return ERR_CLSD;
-    
+
     *ipaddr = pcb->remote_ip;
     *port = pcb->remote_port;
-    
+
     return ERR_OK;
 }
+
+void
+LibTCPSetNoDelay(
+    PTCP_PCB pcb,
+    BOOLEAN Set)
+{
+    if (Set)
+        pcb->flags |= TF_NODELAY;
+    else
+        pcb->flags &= ~TF_NODELAY;
+}