[TCPIP]
[reactos.git] / reactos / lib / drivers / ip / transport / tcp / accept.c
index 643f59d..7fcddda 100644 (file)
 
 #include "precomp.h"
 
-/* Listener->Lock MUST be acquired */
-NTSTATUS TCPServiceListeningSocket( PCONNECTION_ENDPOINT Listener,
-                    PCONNECTION_ENDPOINT Connection,
-                    PTDI_REQUEST_KERNEL Request ) {
+#include "rosip.h"
+
+NTSTATUS TCPCheckPeerForAccept(PVOID Context,
+                               PTDI_REQUEST_KERNEL Request)
+{
+    struct tcp_pcb *newpcb = (struct tcp_pcb*)Context;
     NTSTATUS Status;
-    SOCKADDR_IN OutAddr;
-    OSK_UINT OutAddrLen;
-    PTA_IP_ADDRESS RequestAddressReturn;
     PTDI_CONNECTION_INFORMATION WhoIsConnecting;
-
-    /* Unpack TDI info -- We need the return connection information
-     * struct to return the address so it can be filtered if needed
-     * by WSAAccept -- The returned address will be passed on to
-     * userland after we complete this irp */
-    WhoIsConnecting = (PTDI_CONNECTION_INFORMATION)
-    Request->ReturnConnectionInformation;
-
-    Status = TCPTranslateError
-    ( OskitTCPAccept( Listener->SocketContext,
-              &Connection->SocketContext,
-              Connection,
-              &OutAddr,
-              sizeof(OutAddr),
-              &OutAddrLen,
-              Request->RequestFlags & TDI_QUERY_ACCEPT ? 0 : 1 ) );
-
-    TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
-
-    if( NT_SUCCESS(Status) && Status != STATUS_PENDING ) {
-    RequestAddressReturn = WhoIsConnecting->RemoteAddress;
-
-    TI_DbgPrint(DEBUG_TCP,("Copying address to %x (Who %x)\n",
-                   RequestAddressReturn, WhoIsConnecting));
-
-        RequestAddressReturn->TAAddressCount = 1;
-    RequestAddressReturn->Address[0].AddressLength = OutAddrLen;
-
-        /* BSD uses the first byte of the sockaddr struct as a length.
-         * Since windows doesn't do that we strip it */
-    RequestAddressReturn->Address[0].AddressType =
-        (OutAddr.sin_family >> 8) & 0xff;
-
-    RtlCopyMemory( &RequestAddressReturn->Address[0].Address,
-               ((PCHAR)&OutAddr) + sizeof(USHORT),
-               sizeof(RequestAddressReturn->Address[0].Address[0]) );
-
-    TI_DbgPrint(DEBUG_TCP,("Done copying\n"));
-    }
-
-    TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
-
+    PTA_IP_ADDRESS RemoteAddress;
+    struct ip_addr ipaddr;
+    
+    if (Request->RequestFlags & TDI_QUERY_ACCEPT)
+        DbgPrint("TDI_QUERY_ACCEPT NOT SUPPORTED!!!\n");
+
+    WhoIsConnecting = (PTDI_CONNECTION_INFORMATION)Request->ReturnConnectionInformation;
+    RemoteAddress = (PTA_IP_ADDRESS)WhoIsConnecting->RemoteAddress;
+    
+    RemoteAddress->TAAddressCount = 1;
+    RemoteAddress->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
+    RemoteAddress->Address[0].AddressType = TDI_ADDRESS_TYPE_IP;
+    
+    Status = TCPTranslateError(LibTCPGetPeerName(newpcb,
+                                                 &ipaddr,
+                                                 &RemoteAddress->Address[0].Address[0].sin_port));
+    
+    RemoteAddress->Address[0].Address[0].in_addr = ipaddr.addr;
+    
     return Status;
 }
 
 /* This listen is on a socket we keep as internal.  That socket has the same
  * lifetime as the address file */
-NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
+NTSTATUS TCPListen(PCONNECTION_ENDPOINT Connection, UINT Backlog)
+{
     NTSTATUS Status = STATUS_SUCCESS;
-    SOCKADDR_IN AddressToBind;
+    struct ip_addr AddressToBind;
     KIRQL OldIrql;
+    TA_IP_ADDRESS LocalAddress;
 
     ASSERT(Connection);
-    ASSERT_KM_POINTER(Connection->AddressFile);
 
     LockObject(Connection, &OldIrql);
 
-    TI_DbgPrint(DEBUG_TCP,("TCPListen started\n"));
+    ASSERT_KM_POINTER(Connection->AddressFile);
 
-    TI_DbgPrint(DEBUG_TCP,("Connection->SocketContext %x\n",
-    Connection->SocketContext));
+    TI_DbgPrint(DEBUG_TCP,("[IP, TCPListen] Called\n"));
 
-    AddressToBind.sin_family = AF_INET;
-    memcpy( &AddressToBind.sin_addr,
-        &Connection->AddressFile->Address.Address.IPv4Address,
-        sizeof(AddressToBind.sin_addr) );
-    AddressToBind.sin_port = Connection->AddressFile->Port;
+    TI_DbgPrint(DEBUG_TCP, ("Connection->SocketContext %x\n",
+        Connection->SocketContext));
+    
+    AddressToBind.addr = Connection->AddressFile->Address.Address.IPv4Address;
 
-    TI_DbgPrint(DEBUG_TCP,("AddressToBind - %x:%x\n", AddressToBind.sin_addr, AddressToBind.sin_port));
+    Status = TCPTranslateError(LibTCPBind(Connection,
+                                          &AddressToBind,
+                                          Connection->AddressFile->Port));
 
-    Status = TCPTranslateError( OskitTCPBind( Connection->SocketContext,
-                        &AddressToBind,
-                        sizeof(AddressToBind) ) );
+    if (NT_SUCCESS(Status))
+    {
+        /* Check if we had an unspecified port */
+        if (!Connection->AddressFile->Port)
+        {
+            /* We did, so we need to copy back the port */
+            Status = TCPGetSockAddress(Connection, (PTRANSPORT_ADDRESS)&LocalAddress, FALSE);
+            if (NT_SUCCESS(Status))
+            {
+                /* Allocate the port in the port bitmap */
+                Connection->AddressFile->Port = TCPAllocatePort(LocalAddress.Address[0].Address[0].sin_port);
+                
+                /* This should never fail */
+                ASSERT(Connection->AddressFile->Port != 0xFFFF);
+            }
+        }
+    }
 
     if (NT_SUCCESS(Status))
-        Status = TCPTranslateError( OskitTCPListen( Connection->SocketContext, Backlog ) );
+    {
+        Connection->SocketContext = LibTCPListen(Connection, Backlog);
+        if (!Connection->SocketContext)
+            Status = STATUS_UNSUCCESSFUL;
+    }
 
     UnlockObject(Connection, OldIrql);
 
-    TI_DbgPrint(DEBUG_TCP,("TCPListen finished %x\n", Status));
+    TI_DbgPrint(DEBUG_TCP,("[IP, TCPListen] Leaving. Status = %x\n", Status));
 
     return Status;
 }
 
-BOOLEAN TCPAbortListenForSocket( PCONNECTION_ENDPOINT Listener,
-                  PCONNECTION_ENDPOINT Connection ) {
+BOOLEAN TCPAbortListenForSocket
+(   PCONNECTION_ENDPOINT Listener,
+    PCONNECTION_ENDPOINT Connection)
+{
     PLIST_ENTRY ListEntry;
     PTDI_BUCKET Bucket;
     KIRQL OldIrql;
@@ -113,17 +110,20 @@ BOOLEAN TCPAbortListenForSocket( PCONNECTION_ENDPOINT Listener,
     LockObject(Listener, &OldIrql);
 
     ListEntry = Listener->ListenRequest.Flink;
-    while ( ListEntry != &Listener->ListenRequest ) {
-    Bucket = CONTAINING_RECORD(ListEntry, TDI_BUCKET, Entry);
-
-    if( Bucket->AssociatedEndpoint == Connection ) {
-        RemoveEntryList( &Bucket->Entry );
-        ExFreePoolWithTag( Bucket, TDI_BUCKET_TAG );
-        Found = TRUE;
-        break;
-    }
-
-    ListEntry = ListEntry->Flink;
+    while (ListEntry != &Listener->ListenRequest)
+    {
+        Bucket = CONTAINING_RECORD(ListEntry, TDI_BUCKET, Entry);
+
+        if (Bucket->AssociatedEndpoint == Connection)
+        {
+            DereferenceObject(Bucket->AssociatedEndpoint);
+            RemoveEntryList( &Bucket->Entry );
+            ExFreePoolWithTag( Bucket, TDI_BUCKET_TAG );
+            Found = TRUE;
+            break;
+        }
+
+        ListEntry = ListEntry->Flink;
     }
 
     UnlockObject(Listener, OldIrql);
@@ -141,29 +141,26 @@ NTSTATUS TCPAccept ( PTDI_REQUEST Request,
     PTDI_BUCKET Bucket;
     KIRQL OldIrql;
 
-    TI_DbgPrint(DEBUG_TCP,("TCPAccept started\n"));
-
     LockObject(Listener, &OldIrql);
 
-    Status = TCPServiceListeningSocket( Listener, Connection,
-                       (PTDI_REQUEST_KERNEL)Request );
-
-    if( Status == STATUS_PENDING ) {
-        Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket),
-                                        TDI_BUCKET_TAG );
-
-        if( Bucket ) {
-            ReferenceObject(Connection);
-            Bucket->AssociatedEndpoint = Connection;
-            Bucket->Request.RequestNotifyObject = Complete;
-            Bucket->Request.RequestContext = Context;
-            InsertTailList( &Listener->ListenRequest, &Bucket->Entry );
-        } else
-            Status = STATUS_NO_MEMORY;
+    Bucket = (PTDI_BUCKET)ExAllocatePoolWithTag(NonPagedPool,
+                                    sizeof(*Bucket),
+                                    TDI_BUCKET_TAG );
+    
+    if (Bucket)
+    {
+        Bucket->AssociatedEndpoint = Connection;
+        ReferenceObject(Bucket->AssociatedEndpoint);
+
+        Bucket->Request.RequestNotifyObject = Complete;
+        Bucket->Request.RequestContext = Context;
+        InsertTailList( &Listener->ListenRequest, &Bucket->Entry );
+        Status = STATUS_PENDING;
     }
+    else
+        Status = STATUS_NO_MEMORY;
 
     UnlockObject(Listener, OldIrql);
 
-    TI_DbgPrint(DEBUG_TCP,("TCPAccept finished %x\n", Status));
     return Status;
 }