[TCPIP DRIVER]
[reactos.git] / lib / drivers / ip / transport / tcp / accept.c
index 643f59d..94b6eeb 100644 (file)
 
 #include "precomp.h"
 
-/* Listener->Lock MUST be acquired */
-NTSTATUS TCPServiceListeningSocket( PCONNECTION_ENDPOINT Listener,
-                    PCONNECTION_ENDPOINT Connection,
-                    PTDI_REQUEST_KERNEL Request ) {
+#include "rosip.h"
+
+NTSTATUS TCPCheckPeerForAccept(PVOID Context,
+                               PTDI_REQUEST_KERNEL Request) {
+    struct tcp_pcb *newpcb = Context;
     NTSTATUS Status;
-    SOCKADDR_IN OutAddr;
-    OSK_UINT OutAddrLen;
-    PTA_IP_ADDRESS RequestAddressReturn;
     PTDI_CONNECTION_INFORMATION WhoIsConnecting;
-
-    /* Unpack TDI info -- We need the return connection information
-     * struct to return the address so it can be filtered if needed
-     * by WSAAccept -- The returned address will be passed on to
-     * userland after we complete this irp */
-    WhoIsConnecting = (PTDI_CONNECTION_INFORMATION)
-    Request->ReturnConnectionInformation;
-
-    Status = TCPTranslateError
-    ( OskitTCPAccept( Listener->SocketContext,
-              &Connection->SocketContext,
-              Connection,
-              &OutAddr,
-              sizeof(OutAddr),
-              &OutAddrLen,
-              Request->RequestFlags & TDI_QUERY_ACCEPT ? 0 : 1 ) );
-
-    TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
-
-    if( NT_SUCCESS(Status) && Status != STATUS_PENDING ) {
-    RequestAddressReturn = WhoIsConnecting->RemoteAddress;
-
-    TI_DbgPrint(DEBUG_TCP,("Copying address to %x (Who %x)\n",
-                   RequestAddressReturn, WhoIsConnecting));
-
-        RequestAddressReturn->TAAddressCount = 1;
-    RequestAddressReturn->Address[0].AddressLength = OutAddrLen;
-
-        /* BSD uses the first byte of the sockaddr struct as a length.
-         * Since windows doesn't do that we strip it */
-    RequestAddressReturn->Address[0].AddressType =
-        (OutAddr.sin_family >> 8) & 0xff;
-
-    RtlCopyMemory( &RequestAddressReturn->Address[0].Address,
-               ((PCHAR)&OutAddr) + sizeof(USHORT),
-               sizeof(RequestAddressReturn->Address[0].Address[0]) );
-
-    TI_DbgPrint(DEBUG_TCP,("Done copying\n"));
-    }
-
+    PTA_IP_ADDRESS RemoteAddress;
+    struct ip_addr ipaddr;
+    
+    if (Request->RequestFlags & TDI_QUERY_ACCEPT)
+        DbgPrint("TDI_QUERY_ACCEPT NOT SUPPORTED!!!\n");
+
+    WhoIsConnecting = (PTDI_CONNECTION_INFORMATION)Request->ReturnConnectionInformation;
+    RemoteAddress = (PTA_IP_ADDRESS)WhoIsConnecting->RemoteAddress;
+    
+    RemoteAddress->TAAddressCount = 1;
+    RemoteAddress->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
+    RemoteAddress->Address[0].AddressType = TDI_ADDRESS_TYPE_IP;
+    
+    Status = TCPTranslateError(LibTCPGetPeerName(newpcb,
+                                                 &ipaddr,
+                                                 &RemoteAddress->Address[0].Address[0].sin_port));
+    
+    RemoteAddress->Address[0].Address[0].in_addr = ipaddr.addr;
+    
     TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
 
     return Status;
@@ -68,7 +45,7 @@ NTSTATUS TCPServiceListeningSocket( PCONNECTION_ENDPOINT Listener,
  * lifetime as the address file */
 NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
     NTSTATUS Status = STATUS_SUCCESS;
-    SOCKADDR_IN AddressToBind;
+    struct ip_addr AddressToBind;
     KIRQL OldIrql;
 
     ASSERT(Connection);
@@ -80,21 +57,19 @@ NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
 
     TI_DbgPrint(DEBUG_TCP,("Connection->SocketContext %x\n",
     Connection->SocketContext));
+    
+    AddressToBind.addr = Connection->AddressFile->Address.Address.IPv4Address;
 
-    AddressToBind.sin_family = AF_INET;
-    memcpy( &AddressToBind.sin_addr,
-        &Connection->AddressFile->Address.Address.IPv4Address,
-        sizeof(AddressToBind.sin_addr) );
-    AddressToBind.sin_port = Connection->AddressFile->Port;
-
-    TI_DbgPrint(DEBUG_TCP,("AddressToBind - %x:%x\n", AddressToBind.sin_addr, AddressToBind.sin_port));
-
-    Status = TCPTranslateError( OskitTCPBind( Connection->SocketContext,
-                        &AddressToBind,
-                        sizeof(AddressToBind) ) );
+    Status = TCPTranslateError(LibTCPBind(Connection->SocketContext,
+                                          &AddressToBind,
+                                          Connection->AddressFile->Port));
 
     if (NT_SUCCESS(Status))
-        Status = TCPTranslateError( OskitTCPListen( Connection->SocketContext, Backlog ) );
+    {
+        Connection->SocketContext = LibTCPListen(Connection->SocketContext, Backlog);
+        if (!Connection->SocketContext)
+            Status = STATUS_UNSUCCESSFUL;
+    }
 
     UnlockObject(Connection, OldIrql);
 
@@ -117,6 +92,7 @@ BOOLEAN TCPAbortListenForSocket( PCONNECTION_ENDPOINT Listener,
     Bucket = CONTAINING_RECORD(ListEntry, TDI_BUCKET, Entry);
 
     if( Bucket->AssociatedEndpoint == Connection ) {
+        DereferenceObject(Bucket->AssociatedEndpoint);
         RemoveEntryList( &Bucket->Entry );
         ExFreePoolWithTag( Bucket, TDI_BUCKET_TAG );
         Found = TRUE;
@@ -145,22 +121,17 @@ NTSTATUS TCPAccept ( PTDI_REQUEST Request,
 
     LockObject(Listener, &OldIrql);
 
-    Status = TCPServiceListeningSocket( Listener, Connection,
-                       (PTDI_REQUEST_KERNEL)Request );
-
-    if( Status == STATUS_PENDING ) {
-        Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket),
-                                        TDI_BUCKET_TAG );
-
-        if( Bucket ) {
-            ReferenceObject(Connection);
-            Bucket->AssociatedEndpoint = Connection;
-            Bucket->Request.RequestNotifyObject = Complete;
-            Bucket->Request.RequestContext = Context;
-            InsertTailList( &Listener->ListenRequest, &Bucket->Entry );
-        } else
-            Status = STATUS_NO_MEMORY;
-    }
+    Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket),
+                                   TDI_BUCKET_TAG );
+    
+    if( Bucket ) {
+        Bucket->AssociatedEndpoint = Connection;
+        Bucket->Request.RequestNotifyObject = Complete;
+        Bucket->Request.RequestContext = Context;
+        InsertTailList( &Listener->ListenRequest, &Bucket->Entry );
+        Status = STATUS_PENDING;
+    } else
+        Status = STATUS_NO_MEMORY;
 
     UnlockObject(Listener, OldIrql);