[AFD]
authorCameron Gutman <aicommander@gmail.com>
Sun, 23 Dec 2012 09:53:36 +0000 (09:53 +0000)
committerCameron Gutman <aicommander@gmail.com>
Sun, 23 Dec 2012 09:53:36 +0000 (09:53 +0000)
- Fix a bug during socket closure that could cause buffered data to be lost
[LWIP]
- Only signal errors when the packet queue is empty to avoid dropping received data
- Use pbuf_copy_partial to copy pbuf chains instead of doing it manually which had a bug
CORE-6655 #resolve CORE-6741 #resolve

svn path=/trunk/; revision=57975

reactos/drivers/network/afd/afd/main.c
reactos/drivers/network/afd/afd/read.c
reactos/drivers/network/afd/include/afd.h
reactos/drivers/network/tcpip/include/titypes.h
reactos/lib/drivers/ip/transport/tcp/event.c
reactos/lib/drivers/lwip/src/rostcp.c

index 41c8b7b..eb050bd 100644 (file)
@@ -610,6 +610,7 @@ DisconnectComplete(PDEVICE_OBJECT DeviceObject,
         /* Signal complete connection closure immediately */
         FCB->PollState |= AFD_EVENT_ABORT;
         FCB->PollStatus[FD_CLOSE_BIT] = Irp->IoStatus.Status;
+        FCB->LastReceiveStatus = STATUS_FILE_CLOSED;
         PollReeval(FCB->DeviceExt, FCB->FileObject);
     }
 
@@ -708,8 +709,8 @@ AfdDisconnect(PDEVICE_OBJECT DeviceObject, PIRP Irp,
         /* Mark us as overread to complete future reads with an error */
         FCB->Overread = TRUE;
 
-        /* Set a successful close status to indicate a shutdown on overread */
-        FCB->PollStatus[FD_CLOSE_BIT] = STATUS_SUCCESS;
+        /* Set a successful receive status to indicate a shutdown on overread */
+        FCB->LastReceiveStatus = STATUS_SUCCESS;
 
         /* Clear the receive event */
         FCB->PollState &= ~AFD_EVENT_RECEIVE;
index 68ae10d..fcaea4b 100644 (file)
@@ -52,11 +52,12 @@ static VOID RefillSocketBuffer( PAFD_FCB FCB )
 
 static VOID HandleReceiveComplete( PAFD_FCB FCB, NTSTATUS Status, ULONG_PTR Information )
 {
+    FCB->LastReceiveStatus = Status;
+
     /* We got closed while the receive was in progress */
     if (FCB->TdiReceiveClosed)
     {
-        FCB->Recv.Content = 0;
-        FCB->Recv.BytesUsed = 0;
+        /* The received data is discarded */
     }
     /* Receive successful */
     else if (Status == STATUS_SUCCESS)
@@ -67,30 +68,20 @@ static VOID HandleReceiveComplete( PAFD_FCB FCB, NTSTATUS Status, ULONG_PTR Info
         /* Check for graceful closure */
         if (Information == 0)
         {
+            /* Receive is closed */
             FCB->TdiReceiveClosed = TRUE;
-
-            /* Signal graceful receive shutdown */
-            FCB->PollState |= AFD_EVENT_DISCONNECT;
-            FCB->PollStatus[FD_CLOSE_BIT] = Status;
-
-            PollReeval( FCB->DeviceExt, FCB->FileObject );
         }
-
-        /* Issue another receive IRP to keep the buffer well stocked */
-        RefillSocketBuffer(FCB);
+        else
+        {
+            /* Issue another receive IRP to keep the buffer well stocked */
+            RefillSocketBuffer(FCB);
+        }
     }
     /* Receive failed with no data (unexpected closure) */
     else
     {
-        FCB->Recv.BytesUsed = 0;
-        FCB->Recv.Content = 0;
+        /* Previously received data remains intact */
         FCB->TdiReceiveClosed = TRUE;
-
-        /* Signal complete connection failure immediately */
-        FCB->PollState |= AFD_EVENT_CLOSE;
-        FCB->PollStatus[FD_CLOSE_BIT] = Status;
-
-        PollReeval( FCB->DeviceExt, FCB->FileObject );
     }
 }
 
@@ -168,9 +159,6 @@ static NTSTATUS ReceiveActivity( PAFD_FCB FCB, PIRP Irp ) {
 
     AFD_DbgPrint(MID_TRACE,("%x %x\n", FCB, Irp));
 
-    /* Kick the user that receive would be possible now */
-    /* XXX Not implemented yet */
-
     AFD_DbgPrint(MID_TRACE,("FCB %x Receive data waiting %d\n",
                             FCB, FCB->Recv.Content));
 
@@ -188,7 +176,7 @@ static NTSTATUS ReceiveActivity( PAFD_FCB FCB, PIRP Irp ) {
                                     TotalBytesCopied));
             UnlockBuffers( RecvReq->BufferArray,
                            RecvReq->BufferCount, FALSE );
-            if (FCB->Overread && FCB->PollStatus[FD_CLOSE_BIT] == STATUS_SUCCESS)
+            if (FCB->Overread && FCB->LastReceiveStatus == STATUS_SUCCESS)
             {
                 /* Overread after a graceful disconnect so complete with an error */
                 Status = STATUS_FILE_CLOSED;
@@ -196,7 +184,7 @@ static NTSTATUS ReceiveActivity( PAFD_FCB FCB, PIRP Irp ) {
             else
             {
                 /* Unexpected disconnect by the remote host or initial read after a graceful disconnnect */
-                Status = FCB->PollStatus[FD_CLOSE_BIT];
+                Status = FCB->LastReceiveStatus;
             }
             NextIrp->IoStatus.Status = Status;
             NextIrp->IoStatus.Information = 0;
@@ -260,6 +248,21 @@ static NTSTATUS ReceiveActivity( PAFD_FCB FCB, PIRP Irp ) {
         FCB->PollState &= ~AFD_EVENT_RECEIVE;
     }
 
+    /* Signal FD_CLOSE if no buffered data remains and the socket can't receive any more */
+    if (CantReadMore(FCB))
+    {
+        if (FCB->LastReceiveStatus == STATUS_SUCCESS)
+        {
+            FCB->PollState |= AFD_EVENT_DISCONNECT;
+        }
+        else
+        {
+            FCB->PollState |= AFD_EVENT_CLOSE;
+        }
+        FCB->PollStatus[FD_CLOSE_BIT] = FCB->LastReceiveStatus;
+        PollReeval(FCB->DeviceExt, FCB->FileObject);
+    }
+
     AFD_DbgPrint(MID_TRACE,("RetStatus for irp %x is %x\n", Irp, RetStatus));
 
     /* Sometimes we're called with a NULL Irp */
index d77fbf2..827e057 100644 (file)
@@ -203,6 +203,7 @@ typedef struct _AFD_FCB {
     PVOID Context;
     DWORD PollState;
     NTSTATUS PollStatus[FD_MAX_EVENTS];
+    NTSTATUS LastReceiveStatus;
     UINT ContextSize;
     PVOID ConnectData;
     UINT FilledConnectData;
index 9f1f1f1..7f5d1df 100644 (file)
@@ -278,6 +278,7 @@ typedef struct _CONNECTION_ENDPOINT {
     /* Socket state */
     BOOLEAN SendShutdown;
     BOOLEAN ReceiveShutdown;
+    NTSTATUS ReceiveShutdownStatus;
 
     struct _CONNECTION_ENDPOINT *Next; /* Next connection in address file list */
 } CONNECTION_ENDPOINT, *PCONNECTION_ENDPOINT;
index a2d632e..c5b8b95 100644 (file)
@@ -260,61 +260,53 @@ FlushAllQueues(PCONNECTION_ENDPOINT Connection, NTSTATUS Status)
 VOID
 TCPFinEventHandler(void *arg, const err_t err)
 {
-    PCONNECTION_ENDPOINT Connection = (PCONNECTION_ENDPOINT)arg, LastConnection;
-    const NTSTATUS Status = TCPTranslateError(err);
-    KIRQL OldIrql;
-
-    ASSERT(Connection->AddressFile);
-
-    /* Check if this was a partial socket closure */
-    if (err == ERR_OK && Connection->SocketContext)
-    {
-        /* Just flush the receive queue and get out of here */
-        FlushReceiveQueue(Connection, STATUS_SUCCESS, TRUE);
-    }
-    else
-    {
-        /* First off all, remove the PCB pointer */
-        Connection->SocketContext = NULL;
-
-        /* Complete all outstanding requests now */
-        FlushAllQueues(Connection, Status);
-
-        LockObject(Connection, &OldIrql);
-
-        LockObjectAtDpcLevel(Connection->AddressFile);
-
-        /* Unlink this connection from the address file */
-        if (Connection->AddressFile->Connection == Connection)
-        {
-            Connection->AddressFile->Connection = Connection->Next;
-            DereferenceObject(Connection);
-        }
-        else if (Connection->AddressFile->Listener == Connection)
-        {
-            Connection->AddressFile->Listener = NULL;
-            DereferenceObject(Connection);
-        }
-        else
-        {
-            LastConnection = Connection->AddressFile->Connection;
-            while (LastConnection->Next != Connection && LastConnection->Next != NULL)
-                LastConnection = LastConnection->Next;
-            if (LastConnection->Next == Connection)
-            {
-                LastConnection->Next = Connection->Next;
-                DereferenceObject(Connection);
-            }
-        }
-
-        UnlockObjectFromDpcLevel(Connection->AddressFile);
-
-        /* Remove the address file from this connection */
-        DereferenceObject(Connection->AddressFile);
-        Connection->AddressFile = NULL;
-
-        UnlockObject(Connection, OldIrql);
-    }
+   PCONNECTION_ENDPOINT Connection = (PCONNECTION_ENDPOINT)arg, LastConnection;
+   const NTSTATUS Status = TCPTranslateError(err);
+   KIRQL OldIrql;
+
+   ASSERT(Connection->AddressFile);
+   ASSERT(err != ERR_OK);
+
+   /* First off all, remove the PCB pointer */
+   Connection->SocketContext = NULL;
+
+   /* Complete all outstanding requests now */
+   FlushAllQueues(Connection, Status);
+
+   LockObject(Connection, &OldIrql);
+
+   LockObjectAtDpcLevel(Connection->AddressFile);
+
+   /* Unlink this connection from the address file */
+   if (Connection->AddressFile->Connection == Connection)
+   {
+       Connection->AddressFile->Connection = Connection->Next;
+       DereferenceObject(Connection);
+   }
+   else if (Connection->AddressFile->Listener == Connection)
+   {
+       Connection->AddressFile->Listener = NULL;
+       DereferenceObject(Connection);
+   }
+   else
+   {
+       LastConnection = Connection->AddressFile->Connection;
+       while (LastConnection->Next != Connection && LastConnection->Next != NULL)
+           LastConnection = LastConnection->Next;
+       if (LastConnection->Next == Connection)
+       {
+           LastConnection->Next = Connection->Next;
+           DereferenceObject(Connection);
+       }
+   }
+
+   UnlockObjectFromDpcLevel(Connection->AddressFile);
+
+   /* Remove the address file from this connection */
+   DereferenceObject(Connection->AddressFile);
+   Connection->AddressFile = NULL;
+
+   UnlockObject(Connection, OldIrql);
 }
     
 VOID
index 005f62e..1eb8920 100755 (executable)
@@ -31,6 +31,9 @@ extern KEVENT TerminationEvent;
 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
 
+/* Required for ERR_T to NTSTATUS translation in receive error handling */
+NTSTATUS TCPTranslateError(const err_t err);
+
 static
 void
 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
@@ -84,9 +87,8 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
     PQUEUE_ENTRY qp;
     struct pbuf* p;
     NTSTATUS Status;
-    UINT ReadLength, PayloadLength;
+    UINT ReadLength, PayloadLength, Offset, Copied;
     KIRQL OldIrql;
-    PUCHAR Payload;
 
     (*Received) = 0;
 
@@ -98,14 +100,14 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
         {
             p = qp->p;
 
-            /* Calculate the payload first */
-            Payload = p->payload;
-            Payload += qp->Offset;
-            PayloadLength = p->len;
+            /* Calculate the payload length first */
+            PayloadLength = p->tot_len;
             PayloadLength -= qp->Offset;
+            Offset = qp->Offset;
 
             /* Check if we're reading the whole buffer */
             ReadLength = MIN(PayloadLength, RecvLen);
+            ASSERT(ReadLength != 0);
             if (ReadLength != PayloadLength)
             {
                 /* Save this one for later */
@@ -116,10 +118,8 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
 
             UnlockObject(Connection, OldIrql);
 
-            /* Return to a lower IRQL because the receive buffer may be pageable memory */
-            RtlCopyMemory(RecvBuffer,
-                          Payload,
-                          ReadLength);
+            Copied = pbuf_copy_partial(p, RecvBuffer, ReadLength, Offset);
+            ASSERT(Copied == ReadLength);
 
             LockObject(Connection, &OldIrql);
 
@@ -141,6 +141,7 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
                 ASSERT(RecvLen == 0);
             }
 
+            ASSERT((*Received) != 0);
             Status = STATUS_SUCCESS;
 
             if (!RecvLen)
@@ -150,7 +151,7 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
     else
     {
         if (Connection->ReceiveShutdown)
-            Status = STATUS_SUCCESS;
+            Status = Connection->ReceiveShutdownStatus;
         else
             Status = STATUS_PENDING;
     }
@@ -202,8 +203,6 @@ err_t
 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
 {
     PCONNECTION_ENDPOINT Connection = arg;
-    struct pbuf *pb;
-    ULONG RecvLen;
 
     /* Make sure the socket didn't get closed */
     if (!arg)
@@ -216,19 +215,9 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
 
     if (p)
     {
-        pb = p;
-        RecvLen = 0;
-        while (pb != NULL)
-        {
-            /* Enqueue this buffer */
-            LibTCPEnqueuePacket(Connection, pb);
-            RecvLen += pb->len;
+        LibTCPEnqueuePacket(Connection, p);
 
-            /* Advance and unchain the buffer */
-            pb = pbuf_dechain(pb);;
-        }
-
-        tcp_recved(pcb, RecvLen);
+        tcp_recved(pcb, p->tot_len);
 
         TCPRecvEventHandler(arg);
     }
@@ -239,7 +228,20 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
          * whole socket here (by calling tcp_close()) as that would violate TCP specs
          */
         Connection->ReceiveShutdown = TRUE;
-        TCPFinEventHandler(arg, ERR_OK);
+        Connection->ReceiveShutdownStatus = STATUS_SUCCESS;
+
+        /* This code path executes for both remotely and locally initiated closures,
+         * and we need to distinguish between them */
+        if (Connection->SocketContext)
+        {
+            /* Remotely initiated close */
+            TCPRecvEventHandler(arg);
+        }
+        else
+        {
+            /* Locally initated close */
+            TCPFinEventHandler(arg, ERR_CLSD);
+        }
     }
 
     return ERR_OK;
@@ -279,10 +281,31 @@ static
 void
 InternalErrorEventHandler(void *arg, const err_t err)
 {
+    PCONNECTION_ENDPOINT Connection = arg;
+    KIRQL OldIrql;
+
     /* Make sure the socket didn't get closed */
     if (!arg) return;
 
-    TCPFinEventHandler(arg, err);
+    /* Check if data is left to be read */
+    LockObject(Connection, &OldIrql);
+    if (IsListEmpty(&Connection->PacketQueue))
+    {
+        UnlockObject(Connection, OldIrql);
+
+        /* Deliver the error now */
+        TCPFinEventHandler(arg, err);
+    }
+    else
+    {
+        UnlockObject(Connection, OldIrql);
+
+        /* Defer the error delivery until all data is gone */
+        Connection->ReceiveShutdown = TRUE;
+        Connection->ReceiveShutdownStatus = TCPTranslateError(err);
+
+        TCPRecvEventHandler(arg);
+    }
 }
 
 static
@@ -621,7 +644,10 @@ LibTCPShutdownCallback(void *arg)
     else
     {
         if (msg->Input.Shutdown.shut_rx)
+        {
             msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
+            msg->Input.Shutdown.Connection->ReceiveShutdownStatus = STATUS_FILE_CLOSED;
+        }
 
         if (msg->Input.Shutdown.shut_tx)
             msg->Input.Shutdown.Connection->SendShutdown = TRUE;
@@ -688,7 +714,7 @@ LibTCPCloseCallback(void *arg)
            msg->Output.Close.Error = tcp_close(pcb);
 
            if (!msg->Output.Close.Error && msg->Input.Close.Callback)
-               TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
+               TCPFinEventHandler(msg->Input.Close.Connection, ERR_CLSD);
            break;
 
         default: