[TCPIP]
[reactos.git] / reactos / lib / drivers / ip / transport / tcp / accept.c
index 80103aa..7fcddda 100644 (file)
 
 #include "precomp.h"
 
-/* Listener->Lock MUST be acquired */
-NTSTATUS TCPServiceListeningSocket( PCONNECTION_ENDPOINT Listener,
-                    PCONNECTION_ENDPOINT Connection,
-                    PTDI_REQUEST_KERNEL Request ) {
+#include "rosip.h"
+
+NTSTATUS TCPCheckPeerForAccept(PVOID Context,
+                               PTDI_REQUEST_KERNEL Request)
+{
+    struct tcp_pcb *newpcb = (struct tcp_pcb*)Context;
     NTSTATUS Status;
-    SOCKADDR_IN OutAddr;
-    OSK_UINT OutAddrLen;
-    PTA_IP_ADDRESS RequestAddressReturn;
     PTDI_CONNECTION_INFORMATION WhoIsConnecting;
+    PTA_IP_ADDRESS RemoteAddress;
+    struct ip_addr ipaddr;
+    
+    if (Request->RequestFlags & TDI_QUERY_ACCEPT)
+        DbgPrint("TDI_QUERY_ACCEPT NOT SUPPORTED!!!\n");
 
-    /* Unpack TDI info -- We need the return connection information
-     * struct to return the address so it can be filtered if needed
-     * by WSAAccept -- The returned address will be passed on to
-     * userland after we complete this irp */
-    WhoIsConnecting = (PTDI_CONNECTION_INFORMATION)
-    Request->ReturnConnectionInformation;
-
-    Status = TCPTranslateError
-    ( OskitTCPAccept( Listener,
-              &Connection->SocketContext,
-              Connection,
-              &OutAddr,
-              sizeof(OutAddr),
-              &OutAddrLen,
-              Request->RequestFlags & TDI_QUERY_ACCEPT ? 0 : 1 ) );
-
-    TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
-
-    if( NT_SUCCESS(Status) && Status != STATUS_PENDING ) {
-        RequestAddressReturn = WhoIsConnecting->RemoteAddress;
-
-        TI_DbgPrint(DEBUG_TCP,("Copying address to %x (Who %x)\n",
-                               RequestAddressReturn, WhoIsConnecting));
-
-        RequestAddressReturn->TAAddressCount = 1;
-        RequestAddressReturn->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
-        RequestAddressReturn->Address[0].AddressType = TDI_ADDRESS_TYPE_IP;
-        RequestAddressReturn->Address[0].Address[0].sin_port = OutAddr.sin_port;
-        RequestAddressReturn->Address[0].Address[0].in_addr = OutAddr.sin_addr.s_addr;
-        RtlZeroMemory(RequestAddressReturn->Address[0].Address[0].sin_zero, 8);
-
-        TI_DbgPrint(DEBUG_TCP,("Done copying\n"));
-    }
-
-    TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
-
+    WhoIsConnecting = (PTDI_CONNECTION_INFORMATION)Request->ReturnConnectionInformation;
+    RemoteAddress = (PTA_IP_ADDRESS)WhoIsConnecting->RemoteAddress;
+    
+    RemoteAddress->TAAddressCount = 1;
+    RemoteAddress->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
+    RemoteAddress->Address[0].AddressType = TDI_ADDRESS_TYPE_IP;
+    
+    Status = TCPTranslateError(LibTCPGetPeerName(newpcb,
+                                                 &ipaddr,
+                                                 &RemoteAddress->Address[0].Address[0].sin_port));
+    
+    RemoteAddress->Address[0].Address[0].in_addr = ipaddr.addr;
+    
     return Status;
 }
 
 /* This listen is on a socket we keep as internal.  That socket has the same
  * lifetime as the address file */
-NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
+NTSTATUS TCPListen(PCONNECTION_ENDPOINT Connection, UINT Backlog)
+{
     NTSTATUS Status = STATUS_SUCCESS;
-    SOCKADDR_IN AddressToBind;
+    struct ip_addr AddressToBind;
     KIRQL OldIrql;
     TA_IP_ADDRESS LocalAddress;
 
     ASSERT(Connection);
 
     LockObject(Connection, &OldIrql);
-    
-    ASSERT_KM_POINTER(Connection->AddressFile);
 
-    TI_DbgPrint(DEBUG_TCP,("TCPListen started\n"));
+    ASSERT_KM_POINTER(Connection->AddressFile);
 
-    if (Connection->AddressFile->Port)
-    {
-        AddressToBind.sin_family = AF_INET;
-        memcpy( &AddressToBind.sin_addr,
-               &Connection->AddressFile->Address.Address.IPv4Address,
-               sizeof(AddressToBind.sin_addr) );
-        AddressToBind.sin_port = Connection->AddressFile->Port;
-        TI_DbgPrint(DEBUG_TCP,("AddressToBind - %x:%x\n", AddressToBind.sin_addr, AddressToBind.sin_port));
-
-        /* Perform an explicit bind */
-        Status = TCPTranslateError(OskitTCPBind(Connection,
-                                                &AddressToBind,
-                                                sizeof(AddressToBind)));
-    }
-    else
-    {
-        /* An implicit bind will be performed */
-        Status = STATUS_SUCCESS;
-    }
+    TI_DbgPrint(DEBUG_TCP,("[IP, TCPListen] Called\n"));
 
-    if (NT_SUCCESS(Status))
-        Status = TCPTranslateError( OskitTCPListen( Connection, Backlog ) );
+    TI_DbgPrint(DEBUG_TCP, ("Connection->SocketContext %x\n",
+        Connection->SocketContext));
     
+    AddressToBind.addr = Connection->AddressFile->Address.Address.IPv4Address;
+
+    Status = TCPTranslateError(LibTCPBind(Connection,
+                                          &AddressToBind,
+                                          Connection->AddressFile->Port));
+
     if (NT_SUCCESS(Status))
     {
         /* Check if we had an unspecified port */
         if (!Connection->AddressFile->Port)
         {
             /* We did, so we need to copy back the port */
-            if (NT_SUCCESS(TCPGetSockAddress(Connection, (PTRANSPORT_ADDRESS)&LocalAddress, FALSE)))
+            Status = TCPGetSockAddress(Connection, (PTRANSPORT_ADDRESS)&LocalAddress, FALSE);
+            if (NT_SUCCESS(Status))
             {
                 /* Allocate the port in the port bitmap */
                 Connection->AddressFile->Port = TCPAllocatePort(LocalAddress.Address[0].Address[0].sin_port);
@@ -115,15 +84,24 @@ NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
         }
     }
 
+    if (NT_SUCCESS(Status))
+    {
+        Connection->SocketContext = LibTCPListen(Connection, Backlog);
+        if (!Connection->SocketContext)
+            Status = STATUS_UNSUCCESSFUL;
+    }
+
     UnlockObject(Connection, OldIrql);
 
-    TI_DbgPrint(DEBUG_TCP,("TCPListen finished %x\n", Status));
+    TI_DbgPrint(DEBUG_TCP,("[IP, TCPListen] Leaving. Status = %x\n", Status));
 
     return Status;
 }
 
-BOOLEAN TCPAbortListenForSocket( PCONNECTION_ENDPOINT Listener,
-                  PCONNECTION_ENDPOINT Connection ) {
+BOOLEAN TCPAbortListenForSocket
+(   PCONNECTION_ENDPOINT Listener,
+    PCONNECTION_ENDPOINT Connection)
+{
     PLIST_ENTRY ListEntry;
     PTDI_BUCKET Bucket;
     KIRQL OldIrql;
@@ -132,18 +110,20 @@ BOOLEAN TCPAbortListenForSocket( PCONNECTION_ENDPOINT Listener,
     LockObject(Listener, &OldIrql);
 
     ListEntry = Listener->ListenRequest.Flink;
-    while ( ListEntry != &Listener->ListenRequest ) {
-    Bucket = CONTAINING_RECORD(ListEntry, TDI_BUCKET, Entry);
-
-    if( Bucket->AssociatedEndpoint == Connection ) {
-        DereferenceObject(Bucket->AssociatedEndpoint);
-        RemoveEntryList( &Bucket->Entry );
-        ExFreePoolWithTag( Bucket, TDI_BUCKET_TAG );
-        Found = TRUE;
-        break;
-    }
+    while (ListEntry != &Listener->ListenRequest)
+    {
+        Bucket = CONTAINING_RECORD(ListEntry, TDI_BUCKET, Entry);
 
-    ListEntry = ListEntry->Flink;
+        if (Bucket->AssociatedEndpoint == Connection)
+        {
+            DereferenceObject(Bucket->AssociatedEndpoint);
+            RemoveEntryList( &Bucket->Entry );
+            ExFreePoolWithTag( Bucket, TDI_BUCKET_TAG );
+            Found = TRUE;
+            break;
+        }
+
+        ListEntry = ListEntry->Flink;
     }
 
     UnlockObject(Listener, OldIrql);
@@ -161,29 +141,26 @@ NTSTATUS TCPAccept ( PTDI_REQUEST Request,
     PTDI_BUCKET Bucket;
     KIRQL OldIrql;
 
-    TI_DbgPrint(DEBUG_TCP,("TCPAccept started\n"));
-
     LockObject(Listener, &OldIrql);
 
-    Status = TCPServiceListeningSocket( Listener, Connection,
-                       (PTDI_REQUEST_KERNEL)Request );
-
-    if( Status == STATUS_PENDING ) {
-        Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket),
-                                        TDI_BUCKET_TAG );
-
-        if( Bucket ) {
-            ReferenceObject(Connection);
-            Bucket->AssociatedEndpoint = Connection;
-            Bucket->Request.RequestNotifyObject = Complete;
-            Bucket->Request.RequestContext = Context;
-            InsertTailList( &Listener->ListenRequest, &Bucket->Entry );
-        } else
-            Status = STATUS_NO_MEMORY;
+    Bucket = (PTDI_BUCKET)ExAllocatePoolWithTag(NonPagedPool,
+                                    sizeof(*Bucket),
+                                    TDI_BUCKET_TAG );
+    
+    if (Bucket)
+    {
+        Bucket->AssociatedEndpoint = Connection;
+        ReferenceObject(Bucket->AssociatedEndpoint);
+
+        Bucket->Request.RequestNotifyObject = Complete;
+        Bucket->Request.RequestContext = Context;
+        InsertTailList( &Listener->ListenRequest, &Bucket->Entry );
+        Status = STATUS_PENDING;
     }
+    else
+        Status = STATUS_NO_MEMORY;
 
     UnlockObject(Listener, OldIrql);
 
-    TI_DbgPrint(DEBUG_TCP,("TCPAccept finished %x\n", Status));
     return Status;
 }