[AFD]
[reactos.git] / reactos / lib / drivers / lwip / src / rostcp.c
index 005f62e..1eb8920 100755 (executable)
@@ -31,6 +31,9 @@ extern KEVENT TerminationEvent;
 extern NPAGED_LOOKASIDE_LIST MessageLookasideList;
 extern NPAGED_LOOKASIDE_LIST QueueEntryLookasideList;
 
+/* Required for ERR_T to NTSTATUS translation in receive error handling */
+NTSTATUS TCPTranslateError(const err_t err);
+
 static
 void
 LibTCPEmptyQueue(PCONNECTION_ENDPOINT Connection)
@@ -84,9 +87,8 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
     PQUEUE_ENTRY qp;
     struct pbuf* p;
     NTSTATUS Status;
-    UINT ReadLength, PayloadLength;
+    UINT ReadLength, PayloadLength, Offset, Copied;
     KIRQL OldIrql;
-    PUCHAR Payload;
 
     (*Received) = 0;
 
@@ -98,14 +100,14 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
         {
             p = qp->p;
 
-            /* Calculate the payload first */
-            Payload = p->payload;
-            Payload += qp->Offset;
-            PayloadLength = p->len;
+            /* Calculate the payload length first */
+            PayloadLength = p->tot_len;
             PayloadLength -= qp->Offset;
+            Offset = qp->Offset;
 
             /* Check if we're reading the whole buffer */
             ReadLength = MIN(PayloadLength, RecvLen);
+            ASSERT(ReadLength != 0);
             if (ReadLength != PayloadLength)
             {
                 /* Save this one for later */
@@ -116,10 +118,8 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
 
             UnlockObject(Connection, OldIrql);
 
-            /* Return to a lower IRQL because the receive buffer may be pageable memory */
-            RtlCopyMemory(RecvBuffer,
-                          Payload,
-                          ReadLength);
+            Copied = pbuf_copy_partial(p, RecvBuffer, ReadLength, Offset);
+            ASSERT(Copied == ReadLength);
 
             LockObject(Connection, &OldIrql);
 
@@ -141,6 +141,7 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
                 ASSERT(RecvLen == 0);
             }
 
+            ASSERT((*Received) != 0);
             Status = STATUS_SUCCESS;
 
             if (!RecvLen)
@@ -150,7 +151,7 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
     else
     {
         if (Connection->ReceiveShutdown)
-            Status = STATUS_SUCCESS;
+            Status = Connection->ReceiveShutdownStatus;
         else
             Status = STATUS_PENDING;
     }
@@ -202,8 +203,6 @@ err_t
 InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
 {
     PCONNECTION_ENDPOINT Connection = arg;
-    struct pbuf *pb;
-    ULONG RecvLen;
 
     /* Make sure the socket didn't get closed */
     if (!arg)
@@ -216,19 +215,9 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
 
     if (p)
     {
-        pb = p;
-        RecvLen = 0;
-        while (pb != NULL)
-        {
-            /* Enqueue this buffer */
-            LibTCPEnqueuePacket(Connection, pb);
-            RecvLen += pb->len;
+        LibTCPEnqueuePacket(Connection, p);
 
-            /* Advance and unchain the buffer */
-            pb = pbuf_dechain(pb);;
-        }
-
-        tcp_recved(pcb, RecvLen);
+        tcp_recved(pcb, p->tot_len);
 
         TCPRecvEventHandler(arg);
     }
@@ -239,7 +228,20 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
          * whole socket here (by calling tcp_close()) as that would violate TCP specs
          */
         Connection->ReceiveShutdown = TRUE;
-        TCPFinEventHandler(arg, ERR_OK);
+        Connection->ReceiveShutdownStatus = STATUS_SUCCESS;
+
+        /* This code path executes for both remotely and locally initiated closures,
+         * and we need to distinguish between them */
+        if (Connection->SocketContext)
+        {
+            /* Remotely initiated close */
+            TCPRecvEventHandler(arg);
+        }
+        else
+        {
+            /* Locally initated close */
+            TCPFinEventHandler(arg, ERR_CLSD);
+        }
     }
 
     return ERR_OK;
@@ -279,10 +281,31 @@ static
 void
 InternalErrorEventHandler(void *arg, const err_t err)
 {
+    PCONNECTION_ENDPOINT Connection = arg;
+    KIRQL OldIrql;
+
     /* Make sure the socket didn't get closed */
     if (!arg) return;
 
-    TCPFinEventHandler(arg, err);
+    /* Check if data is left to be read */
+    LockObject(Connection, &OldIrql);
+    if (IsListEmpty(&Connection->PacketQueue))
+    {
+        UnlockObject(Connection, OldIrql);
+
+        /* Deliver the error now */
+        TCPFinEventHandler(arg, err);
+    }
+    else
+    {
+        UnlockObject(Connection, OldIrql);
+
+        /* Defer the error delivery until all data is gone */
+        Connection->ReceiveShutdown = TRUE;
+        Connection->ReceiveShutdownStatus = TCPTranslateError(err);
+
+        TCPRecvEventHandler(arg);
+    }
 }
 
 static
@@ -621,7 +644,10 @@ LibTCPShutdownCallback(void *arg)
     else
     {
         if (msg->Input.Shutdown.shut_rx)
+        {
             msg->Input.Shutdown.Connection->ReceiveShutdown = TRUE;
+            msg->Input.Shutdown.Connection->ReceiveShutdownStatus = STATUS_FILE_CLOSED;
+        }
 
         if (msg->Input.Shutdown.shut_tx)
             msg->Input.Shutdown.Connection->SendShutdown = TRUE;
@@ -688,7 +714,7 @@ LibTCPCloseCallback(void *arg)
            msg->Output.Close.Error = tcp_close(pcb);
 
            if (!msg->Output.Close.Error && msg->Input.Close.Callback)
-               TCPFinEventHandler(msg->Input.Close.Connection, ERR_OK);
+               TCPFinEventHandler(msg->Input.Close.Connection, ERR_CLSD);
            break;
 
         default: