[TCPIP] Store creator PID in ADDRESS_FILE
[reactos.git] / drivers / network / tcpip / tcpip / fileobjs.c
index 6c475ee..c637150 100644 (file)
@@ -10,6 +10,8 @@
 
 #include "precomp.h"
 
+/* Uncomment for logging of connections and address files every 10 seconds */
+//#define LOG_OBJECTS
 
 /* List of all address file objects managed by this driver */
 LIST_ENTRY AddressFileListHead;
@@ -99,12 +101,128 @@ BOOLEAN AddrReceiveMatch(
    return FALSE;
 }
 
+VOID
+LogActiveObjects(VOID)
+{
+#ifdef LOG_OBJECTS
+    PLIST_ENTRY CurrentEntry;
+    KIRQL OldIrql;
+    PADDRESS_FILE AddrFile;
+    PCONNECTION_ENDPOINT Conn;
+
+    DbgPrint("----------- TCP/IP Active Object Dump -------------\n");
+    
+    TcpipAcquireSpinLock(&AddressFileListLock, &OldIrql);
+
+    CurrentEntry = AddressFileListHead.Flink;
+    while (CurrentEntry != &AddressFileListHead)
+    {
+        AddrFile = CONTAINING_RECORD(CurrentEntry, ADDRESS_FILE, ListEntry);
+
+        DbgPrint("Address File (%s, %d, %d) @ 0x%p | Ref count: %d | Sharers: %d\n",
+                 A2S(&AddrFile->Address), WN2H(AddrFile->Port), AddrFile->Protocol,
+                 AddrFile, AddrFile->RefCount, AddrFile->Sharers);
+        DbgPrint("\tListener: ");
+        if (AddrFile->Listener == NULL)
+            DbgPrint("<None>\n");
+        else
+            DbgPrint("0x%p\n", AddrFile->Listener);
+        DbgPrint("\tAssociated endpoints: ");
+        if (AddrFile->Connection == NULL)
+            DbgPrint("<None>\n");
+        else
+        {
+            Conn = AddrFile->Connection;
+            while (Conn)
+            {
+                DbgPrint("0x%p ", Conn);
+                Conn = Conn->Next;
+            }
+            DbgPrint("\n");
+        }
+        
+        CurrentEntry = CurrentEntry->Flink;
+    }
+    
+    TcpipReleaseSpinLock(&AddressFileListLock, OldIrql);
+    
+    TcpipAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql);
+    
+    CurrentEntry = ConnectionEndpointListHead.Flink;
+    while (CurrentEntry != &ConnectionEndpointListHead)
+    {
+        Conn = CONTAINING_RECORD(CurrentEntry, CONNECTION_ENDPOINT, ListEntry);
+        
+        DbgPrint("Connection @ 0x%p | Ref count: %d\n", Conn, Conn->RefCount);
+        DbgPrint("\tPCB: ");
+        if (Conn->SocketContext == NULL)
+            DbgPrint("<None>\n");
+        else
+        {
+            DbgPrint("0x%p\n", Conn->SocketContext);
+            LibTCPDumpPcb(Conn->SocketContext);
+        }
+        DbgPrint("\tPacket queue status: %s\n", IsListEmpty(&Conn->PacketQueue) ? "Empty" : "Not Empty");
+        DbgPrint("\tRequest lists: Connect: %s | Recv: %s | Send: %s | Shutdown: %s | Listen: %s\n",
+                 IsListEmpty(&Conn->ConnectRequest) ? "Empty" : "Not Empty",
+                 IsListEmpty(&Conn->ReceiveRequest) ? "Empty" : "Not Empty",
+                 IsListEmpty(&Conn->SendRequest) ? "Empty" : "Not Empty",
+                 IsListEmpty(&Conn->ShutdownRequest) ? "Empty" : "Not Empty",
+                 IsListEmpty(&Conn->ListenRequest) ? "Empty" : "Not Empty");
+        DbgPrint("\tSend shutdown: %s\n", Conn->SendShutdown ? "Yes" : "No");
+        DbgPrint("\tReceive shutdown: %s\n", Conn->ReceiveShutdown ? "Yes" : "No");
+        if (Conn->ReceiveShutdown) DbgPrint("\tReceive shutdown status: 0x%x\n", Conn->ReceiveShutdownStatus);
+        DbgPrint("\tClosing: %s\n", Conn->Closing ? "Yes" : "No");
+        
+        CurrentEntry = CurrentEntry->Flink;
+    }
+    
+    TcpipReleaseSpinLock(&ConnectionEndpointListLock, OldIrql);
+
+    DbgPrint("---------------------------------------------------\n");
+#endif
+}
+
+PADDRESS_FILE AddrFindShared(
+    PIP_ADDRESS BindAddress,
+    USHORT Port,
+    USHORT Protocol)
+{
+    PLIST_ENTRY CurrentEntry;
+    KIRQL OldIrql;
+    PADDRESS_FILE Current = NULL;
+
+    TcpipAcquireSpinLock(&AddressFileListLock, &OldIrql);
+
+    CurrentEntry = AddressFileListHead.Flink;
+    while (CurrentEntry != &AddressFileListHead) {
+        Current = CONTAINING_RECORD(CurrentEntry, ADDRESS_FILE, ListEntry);
+
+        /* See if this address matches the search criteria */
+        if ((Current->Port == Port) &&
+            (Current->Protocol == Protocol))
+        {
+            /* Increase the sharer count */
+            ASSERT(Current->Sharers != 0);
+            InterlockedIncrement(&Current->Sharers);
+            break;
+        }
+
+        CurrentEntry = CurrentEntry->Flink;
+        Current = NULL;
+    }
+
+    TcpipReleaseSpinLock(&AddressFileListLock, OldIrql);
+
+    return Current;
+}
+
 /*
  * FUNCTION: Searches through address file entries to find next match
  * ARGUMENTS:
  *     SearchContext = Pointer to search context
  * RETURNS:
- *     Pointer to address file, NULL if none was found
+ *     Pointer to referenced address file, NULL if none was found
  */
 PADDRESS_FILE AddrSearchNext(
     PAF_SEARCH SearchContext)
@@ -114,6 +232,7 @@ PADDRESS_FILE AddrSearchNext(
     KIRQL OldIrql;
     PADDRESS_FILE Current = NULL;
     BOOLEAN Found = FALSE;
+    PADDRESS_FILE StartingAddrFile;
     
     TcpipAcquireSpinLock(&AddressFileListLock, &OldIrql);
 
@@ -123,8 +242,8 @@ PADDRESS_FILE AddrSearchNext(
         return NULL;
     }
 
-    /* Remove the extra reference we added to keep this address file in memory */
-    DereferenceObject(CONTAINING_RECORD(SearchContext->Next, ADDRESS_FILE, ListEntry));
+    /* Save this pointer so we can dereference it later */
+    StartingAddrFile = CONTAINING_RECORD(SearchContext->Next, ADDRESS_FILE, ListEntry);
 
     CurrentEntry = SearchContext->Next;
 
@@ -161,10 +280,16 @@ PADDRESS_FILE AddrSearchNext(
             /* Reference the next address file to prevent the link from disappearing behind our back */
             ReferenceObject(CONTAINING_RECORD(SearchContext->Next, ADDRESS_FILE, ListEntry));
         }
+
+        /* Reference the returned address file before dereferencing the starting
+         * address file because it may be that Current == StartingAddrFile */
+        ReferenceObject(Current);
     }
     else
         Current = NULL;
 
+    DereferenceObject(StartingAddrFile);
+
     TcpipReleaseSpinLock(&AddressFileListLock, OldIrql);
 
     return Current;
@@ -254,6 +379,7 @@ VOID ControlChannelFree(
  *     Request  = Pointer to TDI request structure for this request
  *     Address  = Pointer to address to be opened
  *     Protocol = Protocol on which to open the address
+ *     Shared   = Specifies if the address is opened for shared access
  *     Options  = Pointer to option buffer
  * RETURNS:
  *     Status of operation
@@ -262,12 +388,24 @@ NTSTATUS FileOpenAddress(
   PTDI_REQUEST Request,
   PTA_IP_ADDRESS Address,
   USHORT Protocol,
+  BOOLEAN Shared,
   PVOID Options)
 {
   PADDRESS_FILE AddrFile;
 
   TI_DbgPrint(MID_TRACE, ("Called (Proto %d).\n", Protocol));
 
+  /* If it's shared and has a port specified, look for a match */
+  if ((Shared != FALSE) && (Address->Address[0].Address[0].sin_port != 0))
+  {
+    AddrFile = AddrFindShared(NULL, Address->Address[0].Address[0].sin_port, Protocol);
+    if (AddrFile != NULL)
+    {
+        Request->Handle.AddressHandle = AddrFile;
+        return STATUS_SUCCESS;
+    }
+  }
+
   AddrFile = ExAllocatePoolWithTag(NonPagedPool, sizeof(ADDRESS_FILE),
                                    ADDR_FILE_TAG);
   if (!AddrFile) {
@@ -279,12 +417,14 @@ NTSTATUS FileOpenAddress(
 
   AddrFile->RefCount = 1;
   AddrFile->Free = AddrFileFree;
+  AddrFile->Sharers = 1;
 
   /* Set our default options */
   AddrFile->TTL = 128;
   AddrFile->DF = 0;
   AddrFile->BCast = 1;
   AddrFile->HeaderIncl = 1;
+  AddrFile->ProcessId = PsGetCurrentProcessId();
 
   /* Make sure address is a local unicast address or 0 */
   /* FIXME: IPv4 only */
@@ -294,8 +434,8 @@ NTSTATUS FileOpenAddress(
 
   if (!AddrIsUnspecified(&AddrFile->Address) &&
       !AddrLocateInterface(&AddrFile->Address)) {
-         ExFreePoolWithTag(AddrFile, ADDR_FILE_TAG);
          TI_DbgPrint(MIN_TRACE, ("Non-local address given (0x%X).\n", A2S(&AddrFile->Address)));
+         ExFreePoolWithTag(AddrFile, ADDR_FILE_TAG);
          return STATUS_INVALID_ADDRESS;
   }
 
@@ -431,6 +571,13 @@ NTSTATUS FileCloseAddress(
 
   LockObject(AddrFile, &OldIrql);
 
+  if (InterlockedDecrement(&AddrFile->Sharers) != 0)
+  {
+      /* Still other guys have open handles to this, so keep it around */
+      UnlockObject(AddrFile, OldIrql);
+      return STATUS_SUCCESS;
+  }
+
   /* We have to close this listener because we started it */
   if( AddrFile->Listener )
   {