- Rework our oskittcp signalling
[reactos.git] / reactos / lib / drivers / ip / transport / tcp / event.c
index 39de114..ee55ce9 100644 (file)
 #include "precomp.h"
 
 int TCPSocketState(void *ClientData,
-                  void *WhichSocket,
-                  void *WhichConnection,
-                  OSK_UINT NewState ) {
+           void *WhichSocket,
+           void *WhichConnection,
+           OSK_UINT NewState ) {
     PCONNECTION_ENDPOINT Connection = WhichConnection;
+    ULONG OldState;
+    KIRQL OldIrql;
+
+    ASSERT_LOCKED(&TCPLock);
 
     TI_DbgPrint(MID_TRACE,("Flags: %c%c%c%c\n",
-                          NewState & SEL_CONNECT ? 'C' : 'c',
-                          NewState & SEL_READ    ? 'R' : 'r',
-                          NewState & SEL_FIN     ? 'F' : 'f',
-                          NewState & SEL_ACCEPT  ? 'A' : 'a'));
+               NewState & SEL_CONNECT ? 'C' : 'c',
+               NewState & SEL_READ    ? 'R' : 'r',
+               NewState & SEL_FIN     ? 'F' : 'f',
+               NewState & SEL_ACCEPT  ? 'A' : 'a'));
 
     TI_DbgPrint(DEBUG_TCP,("Called: NewState %x (Conn %x) (Change %x)\n",
-                          NewState, Connection,
-                          Connection ? Connection->State ^ NewState :
-                          NewState));
+               NewState, Connection,
+               Connection ? Connection->SignalState ^ NewState :
+               NewState));
 
     if( !Connection ) {
-       TI_DbgPrint(DEBUG_TCP,("Socket closing.\n"));
-       Connection = FileFindConnectionByContext( WhichSocket );
-       if( !Connection )
-           return 0;
-       else
-           TI_DbgPrint(DEBUG_TCP,("Found socket %x\n", Connection));
+    TI_DbgPrint(DEBUG_TCP,("Socket closing.\n"));
+    Connection = FileFindConnectionByContext( WhichSocket );
+    if( !Connection )
+        return 0;
+    else
+        TI_DbgPrint(DEBUG_TCP,("Found socket %x\n", Connection));
     }
 
-    TI_DbgPrint(MID_TRACE,("Connection signalled: %d\n",
-                          Connection->Signalled));
+    OldState = Connection->SignalState;
 
     Connection->SignalState |= NewState;
-    if( !Connection->Signalled ) {
-       Connection->Signalled = TRUE;
-       InsertTailList( &SignalledConnections, &Connection->SignalList );
+
+    NewState = HandleSignalledConnection(Connection);
+
+    KeAcquireSpinLock(&SignalledConnectionsLock, &OldIrql);
+    if ((NewState == 0 || NewState == SEL_FIN) &&
+        (OldState != 0 && OldState != SEL_FIN))
+    {
+        RemoveEntryList(&Connection->SignalList);
     }
+    else if (NewState != 0 && NewState != SEL_FIN)
+    {
+        InsertTailList(&SignalledConnectionsList, &Connection->SignalList);
+    }
+    KeReleaseSpinLock(&SignalledConnectionsLock, OldIrql);
 
     return 0;
 }
 
 void TCPPacketSendComplete( PVOID Context,
-                           PNDIS_PACKET NdisPacket,
-                           NDIS_STATUS NdisStatus ) {
+                PNDIS_PACKET NdisPacket,
+                NDIS_STATUS NdisStatus ) {
     TI_DbgPrint(DEBUG_TCP,("called %x\n", NdisPacket));
     FreeNdisPacket(NdisPacket);
     TI_DbgPrint(DEBUG_TCP,("done\n"));
@@ -65,35 +78,34 @@ int TCPPacketSend(void *ClientData, OSK_PCHAR data, OSK_UINT len ) {
     IP_ADDRESS RemoteAddress, LocalAddress;
     PIPv4_HEADER Header;
 
+    ASSERT_LOCKED(&TCPLock);
+
     if( *data == 0x45 ) { /* IPv4 */
-       Header = (PIPv4_HEADER)data;
-       LocalAddress.Type = IP_ADDRESS_V4;
-       LocalAddress.Address.IPv4Address = Header->SrcAddr;
-       RemoteAddress.Type = IP_ADDRESS_V4;
-       RemoteAddress.Address.IPv4Address = Header->DstAddr;
+    Header = (PIPv4_HEADER)data;
+    LocalAddress.Type = IP_ADDRESS_V4;
+    LocalAddress.Address.IPv4Address = Header->SrcAddr;
+    RemoteAddress.Type = IP_ADDRESS_V4;
+    RemoteAddress.Address.IPv4Address = Header->DstAddr;
     } else {
-       TI_DbgPrint(MIN_TRACE,("Outgoing packet is not IPv4\n"));
-       OskitDumpBuffer( data, len );
-       return OSK_EINVAL;
+    TI_DbgPrint(MIN_TRACE,("Outgoing packet is not IPv4\n"));
+    OskitDumpBuffer( data, len );
+    return OSK_EINVAL;
     }
 
-    RemoteAddress.Type = LocalAddress.Type = IP_ADDRESS_V4;
-
     if(!(NCE = RouteGetRouteToDestination( &RemoteAddress ))) {
-       TI_DbgPrint(MIN_TRACE,("No route to %s\n", A2S(&RemoteAddress)));
-       return OSK_EADDRNOTAVAIL;
+    TI_DbgPrint(MIN_TRACE,("No route to %s\n", A2S(&RemoteAddress)));
+    return OSK_EADDRNOTAVAIL;
     }
 
-    NdisStatus = AllocatePacketWithBuffer( &Packet.NdisPacket, NULL,
-                                          MaxLLHeaderSize + len );
+    NdisStatus = AllocatePacketWithBuffer( &Packet.NdisPacket, NULL, len );
 
     if (NdisStatus != NDIS_STATUS_SUCCESS) {
-       TI_DbgPrint(DEBUG_TCP, ("Error from NDIS: %08x\n", NdisStatus));
-       return OSK_ENOBUFS;
+    TI_DbgPrint(DEBUG_TCP, ("Error from NDIS: %08x\n", NdisStatus));
+    return OSK_ENOBUFS;
     }
 
-    GetDataPtr( Packet.NdisPacket, MaxLLHeaderSize,
-               (PCHAR *)&Packet.Header, &Packet.ContigSize );
+    GetDataPtr( Packet.NdisPacket, 0,
+        (PCHAR *)&Packet.Header, &Packet.ContigSize );
 
     RtlCopyMemory( Packet.Header, data, len );
 
@@ -102,49 +114,59 @@ int TCPPacketSend(void *ClientData, OSK_PCHAR data, OSK_UINT len ) {
     Packet.SrcAddr = LocalAddress;
     Packet.DstAddr = RemoteAddress;
 
-    IPSendDatagram( &Packet, NCE, TCPPacketSendComplete, NULL );
+    if (!NT_SUCCESS(IPSendDatagram( &Packet, NCE, TCPPacketSendComplete, NULL )))
+    {
+        FreeNdisPacket(Packet.NdisPacket);
+        return OSK_EINVAL;
+    }
 
-    if( !NT_SUCCESS(NdisStatus) ) return OSK_EINVAL;
-    else return 0;
+    return 0;
 }
 
 int TCPSleep( void *ClientData, void *token, int priority, char *msg,
-             int tmio ) {
+          int tmio ) {
     PSLEEPING_THREAD SleepingThread;
+    LARGE_INTEGER Timeout;
+
+    ASSERT_LOCKED(&TCPLock);
 
     TI_DbgPrint(DEBUG_TCP,
-               ("Called TSLEEP: tok = %x, pri = %d, wmesg = %s, tmio = %x\n",
-                token, priority, msg, tmio));
+        ("Called TSLEEP: tok = %x, pri = %d, wmesg = %s, tmio = %x\n",
+         token, priority, msg, tmio));
 
-    SleepingThread = PoolAllocateBuffer( sizeof( *SleepingThread ) );
+    SleepingThread = exAllocatePool( NonPagedPool, sizeof( *SleepingThread ) );
     if( SleepingThread ) {
-       KeInitializeEvent( &SleepingThread->Event, NotificationEvent, FALSE );
-       SleepingThread->SleepToken = token;
+    KeInitializeEvent( &SleepingThread->Event, NotificationEvent, FALSE );
+    SleepingThread->SleepToken = token;
 
-       /* We're going to sleep and need to release the lock, otherwise
+    /* We're going to sleep and need to release the lock, otherwise
            it's impossible to re-enter oskittcp to deliver the event that's
            going to wake us */
-       TcpipRecursiveMutexLeave( &TCPLock );
+    TcpipRecursiveMutexLeave( &TCPLock );
+
+    TcpipAcquireFastMutex( &SleepingThreadsLock );
+    InsertTailList( &SleepingThreadsList, &SleepingThread->Entry );
+    TcpipReleaseFastMutex( &SleepingThreadsLock );
 
-       TcpipAcquireFastMutex( &SleepingThreadsLock );
-       InsertTailList( &SleepingThreadsList, &SleepingThread->Entry );
-       TcpipReleaseFastMutex( &SleepingThreadsLock );
+        Timeout.QuadPart = Int32x32To64(tmio, -10000);
 
-       TI_DbgPrint(DEBUG_TCP,("Waiting on %x\n", token));
-       KeWaitForSingleObject( &SleepingThread->Event,
-                              WrSuspended,
-                              KernelMode,
-                              TRUE,
-                              NULL );
+    TI_DbgPrint(DEBUG_TCP,("Waiting on %x\n", token));
+    KeWaitForSingleObject( &SleepingThread->Event,
+                   Executive,
+                   KernelMode,
+                   TRUE,
+                   (tmio != 0) ? &Timeout : NULL );
 
-       TcpipAcquireFastMutex( &SleepingThreadsLock );
-       RemoveEntryList( &SleepingThread->Entry );
-       TcpipReleaseFastMutex( &SleepingThreadsLock );
+    TcpipAcquireFastMutex( &SleepingThreadsLock );
+    RemoveEntryList( &SleepingThread->Entry );
+    TcpipReleaseFastMutex( &SleepingThreadsLock );
 
-       TcpipRecursiveMutexEnter( &TCPLock, TRUE );
+    TcpipRecursiveMutexEnter( &TCPLock, TRUE );
+
+    exFreePool( SleepingThread );
+    } else
+        return OSK_ENOBUFS;
 
-       PoolFreeBuffer( SleepingThread );
-    }
     TI_DbgPrint(DEBUG_TCP,("Waiting finished: %x\n", token));
     return 0;
 }
@@ -153,16 +175,18 @@ void TCPWakeup( void *ClientData, void *token ) {
     PLIST_ENTRY Entry;
     PSLEEPING_THREAD SleepingThread;
 
+    ASSERT_LOCKED(&TCPLock);
+
     TcpipAcquireFastMutex( &SleepingThreadsLock );
     Entry = SleepingThreadsList.Flink;
     while( Entry != &SleepingThreadsList ) {
-       SleepingThread = CONTAINING_RECORD(Entry, SLEEPING_THREAD, Entry);
-       TI_DbgPrint(DEBUG_TCP,("Sleeper @ %x\n", SleepingThread));
-       if( SleepingThread->SleepToken == token ) {
-           TI_DbgPrint(DEBUG_TCP,("Setting event to wake %x\n", token));
-           KeSetEvent( &SleepingThread->Event, IO_NETWORK_INCREMENT, FALSE );
-       }
-       Entry = Entry->Flink;
+    SleepingThread = CONTAINING_RECORD(Entry, SLEEPING_THREAD, Entry);
+    TI_DbgPrint(DEBUG_TCP,("Sleeper @ %x\n", SleepingThread));
+    if( SleepingThread->SleepToken == token ) {
+        TI_DbgPrint(DEBUG_TCP,("Setting event to wake %x\n", token));
+        KeSetEvent( &SleepingThread->Event, IO_NETWORK_INCREMENT, FALSE );
+    }
+    Entry = Entry->Flink;
     }
     TcpipReleaseFastMutex( &SleepingThreadsLock );
 }
@@ -183,10 +207,10 @@ void TCPWakeup( void *ClientData, void *token ) {
 #define SMALL_SIZE 128
 #define LARGE_SIZE 2048
 
-#define SIGNATURE_LARGE TAG('L','L','L','L')
-#define SIGNATURE_SMALL TAG('S','S','S','S')
-#define SIGNATURE_OTHER TAG('O','O','O','O')
-#define TCP_TAG TAG('T','C','P',' ')
+#define SIGNATURE_LARGE 'LLLL'
+#define SIGNATURE_SMALL 'SSSS'
+#define SIGNATURE_OTHER 'OOOO'
+#define TCP_TAG ' PCT'
 
 static NPAGED_LOOKASIDE_LIST LargeLookasideList;
 static NPAGED_LOOKASIDE_LIST SmallLookasideList;
@@ -213,10 +237,12 @@ TCPMemStartup( void )
 }
 
 void *TCPMalloc( void *ClientData,
-                OSK_UINT Bytes, OSK_PCHAR File, OSK_UINT Line ) {
+         OSK_UINT Bytes, OSK_PCHAR File, OSK_UINT Line ) {
     void *v;
     ULONG Signature;
 
+    ASSERT_LOCKED(&TCPLock);
+
 #if 0 != MEM_PROFILE
     static OSK_UINT *Sizes = NULL, *Counts = NULL, ArrayAllocated = 0;
     static OSK_UINT ArrayUsed = 0, AllocationCount = 0;
@@ -225,84 +251,86 @@ void *TCPMalloc( void *ClientData,
 
     Found = 0;
     for ( i = 0; i < ArrayUsed && ! Found; i++ ) {
-       Found = ( Sizes[i] == Bytes );
-       if ( Found ) {
-           Counts[i]++;
-       }
+    Found = ( Sizes[i] == Bytes );
+    if ( Found ) {
+        Counts[i]++;
+    }
     }
     if ( ! Found ) {
-       if ( ArrayAllocated <= ArrayUsed ) {
-           NewSize = ( 0 == ArrayAllocated ? 16 : 2 * ArrayAllocated );
-           NewArray = PoolAllocateBuffer( 2 * NewSize * sizeof( OSK_UINT ) );
-           if ( NULL != NewArray ) {
-               if ( 0 != ArrayAllocated ) {
-                   memcpy( NewArray, Sizes,
-                           ArrayAllocated * sizeof( OSK_UINT ) );
-                   PoolFreeBuffer( Sizes );
-                   memcpy( NewArray + NewSize, Counts,
-                           ArrayAllocated * sizeof( OSK_UINT ) );
-                   PoolFreeBuffer( Counts );
-               }
-               Sizes = NewArray;
-               Counts = NewArray + NewSize;
-               ArrayAllocated = NewSize;
-           } else if ( 0 != ArrayAllocated ) {
-               PoolFreeBuffer( Sizes );
-               PoolFreeBuffer( Counts );
-               ArrayAllocated = 0;
-           }
-       }
-       if ( ArrayUsed < ArrayAllocated ) {
-           Sizes[ArrayUsed] = Bytes;
-           Counts[ArrayUsed] = 1;
-           ArrayUsed++;
-       }
+    if ( ArrayAllocated <= ArrayUsed ) {
+        NewSize = ( 0 == ArrayAllocated ? 16 : 2 * ArrayAllocated );
+        NewArray = exAllocatePool( NonPagedPool, 2 * NewSize * sizeof( OSK_UINT ) );
+        if ( NULL != NewArray ) {
+        if ( 0 != ArrayAllocated ) {
+            memcpy( NewArray, Sizes,
+                    ArrayAllocated * sizeof( OSK_UINT ) );
+            exFreePool( Sizes );
+            memcpy( NewArray + NewSize, Counts,
+                    ArrayAllocated * sizeof( OSK_UINT ) );
+            exFreePool( Counts );
+        }
+        Sizes = NewArray;
+        Counts = NewArray + NewSize;
+        ArrayAllocated = NewSize;
+        } else if ( 0 != ArrayAllocated ) {
+        exFreePool( Sizes );
+        exFreePool( Counts );
+        ArrayAllocated = 0;
+        }
+    }
+    if ( ArrayUsed < ArrayAllocated ) {
+        Sizes[ArrayUsed] = Bytes;
+        Counts[ArrayUsed] = 1;
+        ArrayUsed++;
+    }
     }
 
     if ( 0 == (++AllocationCount % MEM_PROFILE) ) {
-       TI_DbgPrint(DEBUG_TCP, ("Memory allocation size profile:\n"));
-       for ( i = 0; i < ArrayUsed; i++ ) {
-           TI_DbgPrint(DEBUG_TCP,
-                       ("Size %4u Count %5u\n", Sizes[i], Counts[i]));
-       }
-       TI_DbgPrint(DEBUG_TCP, ("End of memory allocation size profile\n"));
+    TI_DbgPrint(DEBUG_TCP, ("Memory allocation size profile:\n"));
+    for ( i = 0; i < ArrayUsed; i++ ) {
+        TI_DbgPrint(DEBUG_TCP,
+                    ("Size %4u Count %5u\n", Sizes[i], Counts[i]));
+    }
+    TI_DbgPrint(DEBUG_TCP, ("End of memory allocation size profile\n"));
     }
 #endif /* MEM_PROFILE */
 
     if ( SMALL_SIZE == Bytes ) {
-       v = ExAllocateFromNPagedLookasideList( &SmallLookasideList );
-       Signature = SIGNATURE_SMALL;
+    v = ExAllocateFromNPagedLookasideList( &SmallLookasideList );
+    Signature = SIGNATURE_SMALL;
     } else if ( LARGE_SIZE == Bytes ) {
-       v = ExAllocateFromNPagedLookasideList( &LargeLookasideList );
-       Signature = SIGNATURE_LARGE;
+    v = ExAllocateFromNPagedLookasideList( &LargeLookasideList );
+    Signature = SIGNATURE_LARGE;
     } else {
-       v = PoolAllocateBuffer( Bytes + sizeof(ULONG) );
-       Signature = SIGNATURE_OTHER;
+    v = ExAllocatePool( NonPagedPool, Bytes + sizeof(ULONG) );
+    Signature = SIGNATURE_OTHER;
     }
     if( v ) {
-       *((ULONG *) v) = Signature;
-       v = (void *)((char *) v + sizeof(ULONG));
-       TrackWithTag( FOURCC('f','b','s','d'), v, (PCHAR)File, Line );
+    *((ULONG *) v) = Signature;
+    v = (void *)((char *) v + sizeof(ULONG));
+    TrackWithTag( FOURCC('f','b','s','d'), v, (PCHAR)File, Line );
     }
 
     return v;
 }
 
 void TCPFree( void *ClientData,
-             void *data, OSK_PCHAR File, OSK_UINT Line ) {
+          void *data, OSK_PCHAR File, OSK_UINT Line ) {
     ULONG Signature;
 
-    UntrackFL( (PCHAR)File, Line, data );
+    ASSERT_LOCKED(&TCPLock);
+
+    UntrackFL( (PCHAR)File, Line, data, FOURCC('f','b','s','d') );
     data = (void *)((char *) data - sizeof(ULONG));
     Signature = *((ULONG *) data);
     if ( SIGNATURE_SMALL == Signature ) {
-       ExFreeToNPagedLookasideList( &SmallLookasideList, data );
+    ExFreeToNPagedLookasideList( &SmallLookasideList, data );
     } else if ( SIGNATURE_LARGE == Signature ) {
-       ExFreeToNPagedLookasideList( &LargeLookasideList, data );
+    ExFreeToNPagedLookasideList( &LargeLookasideList, data );
     } else if ( SIGNATURE_OTHER == Signature ) {
-       PoolFreeBuffer( data );
+    ExFreePool( data );
     } else {
-       ASSERT( FALSE );
+    ASSERT( FALSE );
     }
 }