Sync to trunk head (r42241)
[reactos.git] / reactos / drivers / network / tcpip / tcpip / dispatch.c
index 04c5de6..304657b 100644 (file)
@@ -86,15 +86,6 @@ VOID DispDataRequestComplete(
 
     (void)IoSetCancelRoutine(Irp, NULL);
 
-    if (Irp->Cancel || TranContext->CancelIrps) {
-        /* The IRP has been cancelled */
-
-        TI_DbgPrint(DEBUG_IRP, ("IRP is cancelled.\n"));
-
-        Status = STATUS_CANCELLED;
-        Count  = 0;
-    }
-
     IoReleaseCancelSpinLock(OldIrql);
 
     Irp->IoStatus.Status      = Status;
@@ -106,22 +97,16 @@ VOID DispDataRequestComplete(
                            Irp->IoStatus.Information));
     TI_DbgPrint(DEBUG_IRP, ("Completing IRP at (0x%X).\n", Irp));
 
-    IRPFinish(Irp, Irp->IoStatus.Status);
+    IRPFinish(Irp, Status);
 
     TI_DbgPrint(DEBUG_IRP, ("Done Completing IRP\n"));
 }
 
-typedef struct _DISCONNECT_TYPE {
-    UINT Type;
-    PVOID Context;
-    PIRP Irp;
-    PFILE_OBJECT FileObject;
-} DISCONNECT_TYPE, *PDISCONNECT_TYPE;
-
 VOID DispDoDisconnect( PVOID Data ) {
     PDISCONNECT_TYPE DisType = (PDISCONNECT_TYPE)Data;
 
     TI_DbgPrint(DEBUG_IRP, ("PostCancel: DoDisconnect\n"));
+    TcpipRecursiveMutexEnter(&TCPLock, TRUE);
     TCPDisconnect
        ( DisType->Context,
          DisType->Type,
@@ -129,6 +114,7 @@ VOID DispDoDisconnect( PVOID Data ) {
          NULL,
          DispDataRequestComplete,
          DisType->Irp );
+    TcpipRecursiveMutexLeave(&TCPLock);
     TI_DbgPrint(DEBUG_IRP, ("PostCancel: DoDisconnect done\n"));
 
     DispDataRequestComplete(DisType->Irp, STATUS_CANCELLED, 0);
@@ -164,7 +150,7 @@ VOID NTAPI DispCancelRequest(
     Irp->IoStatus.Status = STATUS_CANCELLED;
     Irp->IoStatus.Information = 0;
 
-#ifdef DBG
+#if DBG
     if (!Irp->Cancel)
         TI_DbgPrint(MIN_TRACE, ("Irp->Cancel is FALSE, should be TRUE.\n"));
 #endif
@@ -181,11 +167,11 @@ VOID NTAPI DispCancelRequest(
 
        TCPRemoveIRP( TranContext->Handle.ConnectionContext, Irp );
 
+       IoReleaseCancelSpinLock(Irp->CancelIrql);
+
        if( !ChewCreate( &WorkItem, sizeof(DISCONNECT_TYPE),
                         DispDoDisconnect, &DisType ) )
            ASSERT(0);
-
-       IoReleaseCancelSpinLock(Irp->CancelIrql);
         return;
 
     case TDI_SEND_DATAGRAM:
@@ -212,7 +198,7 @@ VOID NTAPI DispCancelRequest(
     }
 
     IoReleaseCancelSpinLock(Irp->CancelIrql);
-    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+    IRPFinish(Irp, STATUS_CANCELLED);
 
     TI_DbgPrint(MAX_TRACE, ("Leaving.\n"));
 }
@@ -243,20 +229,24 @@ VOID NTAPI DispCancelListenRequest(
 
     TI_DbgPrint(DEBUG_IRP, ("IRP at (0x%X).\n", Irp));
 
-#ifdef DBG
+#if DBG
     if (!Irp->Cancel)
         TI_DbgPrint(MIN_TRACE, ("Irp->Cancel is FALSE, should be TRUE.\n"));
 #endif
 
     /* Try canceling the request */
     Connection = (PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
+
+    TCPRemoveIRP(Connection, Irp);
+
     TCPAbortListenForSocket(
            Connection->AddressFile->Listener,
            Connection );
 
     IoReleaseCancelSpinLock(Irp->CancelIrql);
 
-    DispDataRequestComplete(Irp, STATUS_CANCELLED, 0);
+    Irp->IoStatus.Information = 0;
+    IRPFinish(Irp, STATUS_CANCELLED);
 
     TI_DbgPrint(MAX_TRACE, ("Leaving.\n"));
 }
@@ -396,14 +386,14 @@ NTSTATUS DispTdiConnect(
   TranContext = IrpSp->FileObject->FsContext;
   if (!TranContext) {
     TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-    Status = STATUS_INVALID_CONNECTION;
+    Status = STATUS_INVALID_PARAMETER;
     goto done;
   }
 
   Connection = (PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
   if (!Connection) {
     TI_DbgPrint(MID_TRACE, ("No connection endpoint file object.\n"));
-    Status = STATUS_INVALID_CONNECTION;
+    Status = STATUS_INVALID_PARAMETER;
     goto done;
   }
 
@@ -417,13 +407,13 @@ NTSTATUS DispTdiConnect(
       Irp );
 
 done:
+  TcpipRecursiveMutexLeave( &TCPLock );
+
   if (Status != STATUS_PENDING) {
       DispDataRequestComplete(Irp, Status, 0);
   } else
       IoMarkIrpPending(Irp);
 
-  TcpipRecursiveMutexLeave( &TCPLock );
-
   TI_DbgPrint(MAX_TRACE, ("TCP Connect returned %08x\n", Status));
 
   return Status;
@@ -467,6 +457,12 @@ NTSTATUS DispTdiDisassociateAddress(
     return STATUS_INVALID_PARAMETER;
   }
 
+  /* Remove this connection from the address file */
+  Connection->AddressFile->Connection = NULL;
+
+  /* Remove the address file from this connection */
+  Connection->AddressFile = NULL;
+
   return STATUS_SUCCESS;
 }
 
@@ -499,14 +495,14 @@ NTSTATUS DispTdiDisconnect(
   TranContext = IrpSp->FileObject->FsContext;
   if (!TranContext) {
     TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-    Status = STATUS_INVALID_CONNECTION;
+    Status = STATUS_INVALID_PARAMETER;
     goto done;
   }
 
   Connection = (PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
   if (!Connection) {
     TI_DbgPrint(MID_TRACE, ("No connection endpoint file object.\n"));
-    Status = STATUS_INVALID_CONNECTION;
+    Status = STATUS_INVALID_PARAMETER;
     goto done;
   }
 
@@ -519,13 +515,13 @@ NTSTATUS DispTdiDisconnect(
       Irp );
 
 done:
+   TcpipRecursiveMutexLeave( &TCPLock );
+
    if (Status != STATUS_PENDING) {
        DispDataRequestComplete(Irp, Status, 0);
    } else
        IoMarkIrpPending(Irp);
 
-  TcpipRecursiveMutexLeave( &TCPLock );
-
   TI_DbgPrint(MAX_TRACE, ("TCP Disconnect returned %08x\n", Status));
 
   return Status;
@@ -560,7 +556,7 @@ NTSTATUS DispTdiListen(
   if (TranContext == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
@@ -568,7 +564,7 @@ NTSTATUS DispTdiListen(
   if (Connection == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("No connection endpoint file object.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
@@ -624,13 +620,13 @@ NTSTATUS DispTdiListen(
   }
 
 done:
+  TcpipRecursiveMutexLeave( &TCPLock );
+
   if (Status != STATUS_PENDING) {
       DispDataRequestComplete(Irp, Status, 0);
   } else
       IoMarkIrpPending(Irp);
 
-  TcpipRecursiveMutexLeave( &TCPLock );
-
   TI_DbgPrint(MID_TRACE,("Leaving %x\n", Status));
 
   return Status;
@@ -652,16 +648,20 @@ NTSTATUS DispTdiQueryInformation(
   PTDI_REQUEST_KERNEL_QUERY_INFORMATION Parameters;
   PTRANSPORT_CONTEXT TranContext;
   PIO_STACK_LOCATION IrpSp;
+  NTSTATUS Status;
 
   TI_DbgPrint(DEBUG_IRP, ("Called.\n"));
 
   IrpSp = IoGetCurrentIrpStackLocation(Irp);
   Parameters = (PTDI_REQUEST_KERNEL_QUERY_INFORMATION)&IrpSp->Parameters;
 
+  TcpipRecursiveMutexEnter( &TCPLock, TRUE );
+
   TranContext = IrpSp->FileObject->FsContext;
   if (!TranContext) {
     TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-    return STATUS_INVALID_CONNECTION;
+    TcpipRecursiveMutexLeave(&TCPLock);
+    return STATUS_INVALID_PARAMETER;
   }
 
   switch (Parameters->QueryType)
@@ -671,48 +671,59 @@ NTSTATUS DispTdiQueryInformation(
         PTDI_ADDRESS_INFO AddressInfo;
         PADDRESS_FILE AddrFile;
         PTA_IP_ADDRESS Address;
+        PCONNECTION_ENDPOINT Endpoint = NULL;
+
+
+        if (MmGetMdlByteCount(Irp->MdlAddress) <
+            (FIELD_OFFSET(TDI_ADDRESS_INFO, Address.Address[0].Address) +
+             sizeof(TDI_ADDRESS_IP))) {
+          TI_DbgPrint(MID_TRACE, ("MDL buffer too small.\n"));
+          TcpipRecursiveMutexLeave(&TCPLock);
+          return STATUS_BUFFER_TOO_SMALL;
+        }
 
         AddressInfo = (PTDI_ADDRESS_INFO)MmGetSystemAddressForMdl(Irp->MdlAddress);
+               Address = (PTA_IP_ADDRESS)&AddressInfo->Address;
 
         switch ((ULONG_PTR)IrpSp->FileObject->FsContext2) {
           case TDI_TRANSPORT_ADDRESS_FILE:
             AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle;
-            break;
+
+                       Address->TAAddressCount = 1;
+                       Address->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
+                       Address->Address[0].AddressType = TDI_ADDRESS_TYPE_IP;
+                       Address->Address[0].Address[0].sin_port = AddrFile->Port;
+                       Address->Address[0].Address[0].in_addr = AddrFile->Address.Address.IPv4Address;
+                       RtlZeroMemory(
+                               &Address->Address[0].Address[0].sin_zero,
+                               sizeof(Address->Address[0].Address[0].sin_zero));
+                       TcpipRecursiveMutexLeave(&TCPLock);
+                       return STATUS_SUCCESS;
 
           case TDI_CONNECTION_FILE:
-            AddrFile =
-              ((PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext)->
-              AddressFile;
-            break;
+            Endpoint =
+                               (PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
+                       TCPGetSockAddress( Endpoint, (PTRANSPORT_ADDRESS)Address, FALSE );
+                       DbgPrint("Returning socket address %x\n", Address->Address[0].Address[0].in_addr);
+                       RtlZeroMemory(
+                               &Address->Address[0].Address[0].sin_zero,
+                               sizeof(Address->Address[0].Address[0].sin_zero));
+                       TcpipRecursiveMutexLeave(&TCPLock);
+                       return STATUS_SUCCESS;
 
           default:
             TI_DbgPrint(MIN_TRACE, ("Invalid transport context\n"));
+            TcpipRecursiveMutexLeave(&TCPLock);
             return STATUS_INVALID_PARAMETER;
         }
 
         if (!AddrFile) {
           TI_DbgPrint(MID_TRACE, ("No address file object.\n"));
+          TcpipRecursiveMutexLeave(&TCPLock);
           return STATUS_INVALID_PARAMETER;
         }
 
-        if (MmGetMdlByteCount(Irp->MdlAddress) <
-            (FIELD_OFFSET(TDI_ADDRESS_INFO, Address.Address[0].Address) +
-             sizeof(TDI_ADDRESS_IP))) {
-          TI_DbgPrint(MID_TRACE, ("MDL buffer too small.\n"));
-          return STATUS_BUFFER_OVERFLOW;
-        }
-
-        Address = (PTA_IP_ADDRESS)&AddressInfo->Address;
-        Address->TAAddressCount = 1;
-        Address->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
-        Address->Address[0].AddressType = TDI_ADDRESS_TYPE_IP;
-        Address->Address[0].Address[0].sin_port = AddrFile->Port;
-        Address->Address[0].Address[0].in_addr =
-          AddrFile->Address.Address.IPv4Address;
-        RtlZeroMemory(
-          &Address->Address[0].Address[0].sin_zero,
-          sizeof(Address->Address[0].Address[0].sin_zero));
-
+        TcpipRecursiveMutexLeave(&TCPLock);
         return STATUS_SUCCESS;
       }
 
@@ -722,6 +733,14 @@ NTSTATUS DispTdiQueryInformation(
         PADDRESS_FILE AddrFile;
         PCONNECTION_ENDPOINT Endpoint = NULL;
 
+        if (MmGetMdlByteCount(Irp->MdlAddress) <
+            (FIELD_OFFSET(TDI_CONNECTION_INFORMATION, RemoteAddress) +
+             sizeof(PVOID))) {
+          TI_DbgPrint(MID_TRACE, ("MDL buffer too small (ptr).\n"));
+          TcpipRecursiveMutexLeave(&TCPLock);
+          return STATUS_BUFFER_TOO_SMALL;
+        }
+
         AddressInfo = (PTDI_CONNECTION_INFORMATION)
           MmGetSystemAddressForMdl(Irp->MdlAddress);
 
@@ -737,25 +756,24 @@ NTSTATUS DispTdiQueryInformation(
 
           default:
             TI_DbgPrint(MIN_TRACE, ("Invalid transport context\n"));
+            TcpipRecursiveMutexLeave(&TCPLock);
             return STATUS_INVALID_PARAMETER;
         }
 
         if (!Endpoint) {
           TI_DbgPrint(MID_TRACE, ("No connection object.\n"));
+          TcpipRecursiveMutexLeave(&TCPLock);
           return STATUS_INVALID_PARAMETER;
         }
 
-        if (MmGetMdlByteCount(Irp->MdlAddress) <
-            (FIELD_OFFSET(TDI_CONNECTION_INFORMATION, RemoteAddress) +
-             sizeof(PVOID))) {
-          TI_DbgPrint(MID_TRACE, ("MDL buffer too small (ptr).\n"));
-          return STATUS_BUFFER_OVERFLOW;
-        }
+        Status = TCPGetSockAddress( Endpoint, AddressInfo->RemoteAddress, TRUE );
 
-        return TCPGetPeerAddress( Endpoint, AddressInfo->RemoteAddress );
+        TcpipRecursiveMutexLeave(&TCPLock);
+        return Status;
       }
   }
 
+  TcpipRecursiveMutexLeave(&TCPLock);
   return STATUS_NOT_IMPLEMENTED;
 }
 
@@ -787,14 +805,14 @@ NTSTATUS DispTdiReceive(
   if (TranContext == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
   if (TranContext->Handle.ConnectionContext == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("No connection endpoint file object.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
@@ -818,13 +836,13 @@ NTSTATUS DispTdiReceive(
     }
 
 done:
+  TcpipRecursiveMutexLeave( &TCPLock );
+
   if (Status != STATUS_PENDING) {
       DispDataRequestComplete(Irp, Status, BytesReceived);
   } else
       IoMarkIrpPending(Irp);
 
-  TcpipRecursiveMutexLeave( &TCPLock );
-
   TI_DbgPrint(DEBUG_IRP, ("Leaving. Status is (0x%X)\n", Status));
 
   return Status;
@@ -859,7 +877,7 @@ NTSTATUS DispTdiReceiveDatagram(
   if (TranContext == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
@@ -875,7 +893,7 @@ NTSTATUS DispTdiReceiveDatagram(
 
   if (NT_SUCCESS(Status))
     {
-       PCHAR DataBuffer;
+       PVOID DataBuffer;
        UINT BufferSize;
 
        NdisQueryBuffer( (PNDIS_BUFFER)Irp->MdlAddress,
@@ -896,13 +914,13 @@ NTSTATUS DispTdiReceiveDatagram(
     }
 
 done:
+   TcpipRecursiveMutexLeave( &TCPLock );
+
    if (Status != STATUS_PENDING) {
        DispDataRequestComplete(Irp, Status, BytesReceived);
    } else
        IoMarkIrpPending(Irp);
 
-  TcpipRecursiveMutexLeave( &TCPLock );
-
   TI_DbgPrint(DEBUG_IRP, ("Leaving. Status is (0x%X)\n", Status));
 
   return Status;
@@ -936,14 +954,14 @@ NTSTATUS DispTdiSend(
   if (TranContext == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
   if (TranContext->Handle.ConnectionContext == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("No connection endpoint file object.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
@@ -955,7 +973,7 @@ NTSTATUS DispTdiSend(
   TI_DbgPrint(MID_TRACE,("TCPIP<<< Got an MDL: %x\n", Irp->MdlAddress));
   if (NT_SUCCESS(Status))
     {
-       PCHAR Data;
+       PVOID Data;
        UINT Len;
 
        NdisQueryBuffer( Irp->MdlAddress, &Data, &Len );
@@ -972,13 +990,13 @@ NTSTATUS DispTdiSend(
     }
 
 done:
+   TcpipRecursiveMutexLeave( &TCPLock );
+
    if (Status != STATUS_PENDING) {
        DispDataRequestComplete(Irp, Status, BytesSent);
    } else
        IoMarkIrpPending(Irp);
 
-  TcpipRecursiveMutexLeave( &TCPLock );
-
   TI_DbgPrint(DEBUG_IRP, ("Leaving. Status is (0x%X)\n", Status));
 
   return Status;
@@ -1012,7 +1030,7 @@ NTSTATUS DispTdiSendDatagram(
     if (TranContext == NULL)
     {
       TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
-      Status = STATUS_INVALID_CONNECTION;
+      Status = STATUS_INVALID_PARAMETER;
       goto done;
     }
 
@@ -1027,7 +1045,7 @@ NTSTATUS DispTdiSendDatagram(
         (PDRIVER_CANCEL)DispCancelRequest);
 
     if (NT_SUCCESS(Status)) {
-       PCHAR DataBuffer;
+       PVOID DataBuffer;
        UINT BufferSize;
 
        TI_DbgPrint(MID_TRACE,("About to query buffer %x\n", Irp->MdlAddress));
@@ -1058,13 +1076,13 @@ NTSTATUS DispTdiSendDatagram(
     }
 
 done:
+    TcpipRecursiveMutexLeave( &TCPLock );
+
     if (Status != STATUS_PENDING) {
         DispDataRequestComplete(Irp, Status, Irp->IoStatus.Information);
     } else
         IoMarkIrpPending(Irp);
 
-    TcpipRecursiveMutexLeave( &TCPLock );
-
     TI_DbgPrint(DEBUG_IRP, ("Leaving.\n"));
 
     return Status;
@@ -1285,7 +1303,7 @@ VOID DispTdiQueryInformationExComplete(
     QueryContext->Irp->IoStatus.Information = ByteCount;
     QueryContext->Irp->IoStatus.Status      = Status;
 
-    ExFreePool(QueryContext);
+    exFreePool(QueryContext);
 }
 
 
@@ -1348,7 +1366,7 @@ NTSTATUS DispTdiQueryInformationEx(
             IrpSp->Parameters.DeviceIoControl.Type3InputBuffer;
         OutputBuffer = Irp->UserBuffer;
 
-        QueryContext = ExAllocatePool(NonPagedPool, sizeof(TI_QUERY_CONTEXT));
+        QueryContext = exAllocatePool(NonPagedPool, sizeof(TI_QUERY_CONTEXT));
         if (QueryContext) {
            _SEH2_TRY {
                 InputMdl = IoAllocateMdl(InputBuffer,
@@ -1411,7 +1429,7 @@ NTSTATUS DispTdiQueryInformationEx(
                 IoFreeMdl(OutputMdl);
             }
 
-            ExFreePool(QueryContext);
+            exFreePool(QueryContext);
         } else
             Status = STATUS_INSUFFICIENT_RESOURCES;
     } else if( InputBufferLength ==
@@ -1424,7 +1442,7 @@ NTSTATUS DispTdiQueryInformationEx(
 
        Size = 0;
 
-        QueryContext = ExAllocatePool(NonPagedPool, sizeof(TI_QUERY_CONTEXT));
+        QueryContext = exAllocatePool(NonPagedPool, sizeof(TI_QUERY_CONTEXT));
         if (!QueryContext) return STATUS_INSUFFICIENT_RESOURCES;
 
        _SEH2_TRY {
@@ -1444,7 +1462,7 @@ NTSTATUS DispTdiQueryInformationEx(
 
        if( !NT_SUCCESS(Status) || !InputMdl ) {
            if( InputMdl ) IoFreeMdl( InputMdl );
-           ExFreePool(QueryContext);
+           exFreePool(QueryContext);
            return Status;
        }
 
@@ -1557,6 +1575,7 @@ NTSTATUS DispTdiSetIPAddress( PIRP Irp, PIO_STACK_LOCATION IrpSp ) {
             IF->Unicast.Address.IPv4Address = IpAddrChange->Address;
             IF->Netmask.Type = IP_ADDRESS_V4;
             IF->Netmask.Address.IPv4Address = IpAddrChange->Netmask;
+            IF->Broadcast.Type = IP_ADDRESS_V4;
            IF->Broadcast.Address.IPv4Address =
                IF->Unicast.Address.IPv4Address |
                ~IF->Netmask.Address.IPv4Address;
@@ -1591,6 +1610,8 @@ NTSTATUS DispTdiDeleteIPAddress( PIRP Irp, PIO_STACK_LOCATION IrpSp ) {
             IF->Unicast.Address.IPv4Address = 0;
             IF->Netmask.Type = IP_ADDRESS_V4;
             IF->Netmask.Address.IPv4Address = 0;
+            IF->Broadcast.Type = IP_ADDRESS_V4;
+            IF->Broadcast.Address.IPv4Address = 0;
             Status = STATUS_SUCCESS;
         }
     } EndFor(IF);