- Make sure the socket is still open before entering oskittcp
authorCameron Gutman <aicommander@gmail.com>
Sat, 31 Oct 2009 01:05:31 +0000 (01:05 +0000)
committerCameron Gutman <aicommander@gmail.com>
Sat, 31 Oct 2009 01:05:31 +0000 (01:05 +0000)
 - Remove an unused parameter from OskitTCPBind
 - Return a status value from OskitTCPGetAddress
 - Add debug print for unhandled error codes

svn path=/trunk/; revision=43864

reactos/lib/drivers/ip/transport/tcp/accept.c
reactos/lib/drivers/ip/transport/tcp/tcp.c
reactos/lib/drivers/oskittcp/include/oskittcp.h
reactos/lib/drivers/oskittcp/oskittcp/interface.c

index 151f309..ee1f3bf 100644 (file)
@@ -90,7 +90,6 @@ NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
     TI_DbgPrint(DEBUG_TCP,("AddressToBind - %x:%x\n", AddressToBind.sin_addr, AddressToBind.sin_port));
 
     Status = TCPTranslateError( OskitTCPBind( Connection->SocketContext,
-                        Connection,
                         &AddressToBind,
                         sizeof(AddressToBind) ) );
 
index 23bbb1c..5500db0 100644 (file)
@@ -555,7 +555,7 @@ NTSTATUS TCPShutdown(VOID)
 }
 
 NTSTATUS TCPTranslateError( int OskitError ) {
-    NTSTATUS Status = STATUS_UNSUCCESSFUL;
+    NTSTATUS Status;
 
     switch( OskitError ) {
     case 0: Status = STATUS_SUCCESS; break;
@@ -565,7 +565,10 @@ NTSTATUS TCPTranslateError( int OskitError ) {
     case OSK_ECONNRESET: Status = STATUS_REMOTE_NOT_LISTENING; break;
     case OSK_EINPROGRESS:
     case OSK_EAGAIN: Status = STATUS_PENDING; break;
-    default: Status = STATUS_INVALID_CONNECTION; break;
+    default:
+       DbgPrint("OskitTCP returned unhandled error code: %d\n", OskitError);
+       Status = STATUS_INVALID_CONNECTION;
+       break;
     }
 
     TI_DbgPrint(DEBUG_TCP,("Error %d -> %x\n", OskitError, Status));
@@ -621,7 +624,6 @@ NTSTATUS TCPConnect
 
     Status = TCPTranslateError
         ( OskitTCPBind( Connection->SocketContext,
-                        Connection,
                         &AddressToBind,
                         sizeof(AddressToBind) ) );
 
@@ -703,6 +705,8 @@ NTSTATUS TCPClose
     DrainSignals();
 
     Status = TCPTranslateError( OskitTCPClose( Connection->SocketContext ) );
+    if (Status == STATUS_SUCCESS)
+        Connection->SocketContext = NULL;
 
     TI_DbgPrint(DEBUG_TCP,("TCPClose finished %x\n", Status));
 
@@ -867,13 +871,15 @@ NTSTATUS TCPGetSockAddress
     OSK_UINT LocalAddress, RemoteAddress;
     OSK_UI16 LocalPort, RemotePort;
     PTA_IP_ADDRESS AddressIP = (PTA_IP_ADDRESS)Address;
+    NTSTATUS Status;
 
     ASSERT_LOCKED(&TCPLock);
 
-    OskitTCPGetAddress
-        ( Connection->SocketContext,
-          &LocalAddress, &LocalPort,
-          &RemoteAddress, &RemotePort );
+    Status = TCPTranslateError(OskitTCPGetAddress(Connection->SocketContext,
+                                                  &LocalAddress, &LocalPort,
+                                                  &RemoteAddress, &RemotePort));
+    if (!NT_SUCCESS(Status))
+        return Status;
 
     AddressIP->TAAddressCount = 1;
     AddressIP->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
@@ -881,7 +887,7 @@ NTSTATUS TCPGetSockAddress
     AddressIP->Address[0].Address[0].sin_port = GetRemote ? RemotePort : LocalPort;
     AddressIP->Address[0].Address[0].in_addr = GetRemote ? RemoteAddress : LocalAddress;
 
-    return STATUS_SUCCESS;
+    return Status;
 }
 
 VOID TCPRemoveIRP( PCONNECTION_ENDPOINT Endpoint, PIRP Irp ) {
index 647aecf..8f49a22 100644 (file)
@@ -127,7 +127,7 @@ extern int OskitTCPConnect( void *socket, void *connection,
                            void *nam, OSK_UINT namelen );
 extern int OskitTCPClose( void *socket );
 
-extern int OskitTCPBind( void *socket, void *connection,
+extern int OskitTCPBind( void *socket,
                         void *nam, OSK_UINT namelen );
 
 extern int OskitTCPAccept( void *socket, void **new_socket,
@@ -144,7 +144,7 @@ extern int OskitTCPRecv( void *connection,
                         OSK_UINT *OutLen,
                         OSK_UINT Flags );
 
-void OskitTCPGetAddress( void *socket,
+int OskitTCPGetAddress( void *socket,
                         OSK_UINT *LocalAddress,
                         OSK_UI16 *LocalPort,
                         OSK_UINT *RemoteAddress,
index e69859b..bcb6bc4 100644 (file)
@@ -76,9 +76,6 @@ void TimerOskitTCP( int FastTimer, int SlowTimer ) {
 
 void RegisterOskitTCPEventHandlers( POSKITTCP_EVENT_HANDLERS EventHandlers ) {
     memcpy( &OtcpEvent, EventHandlers, sizeof(OtcpEvent) );
-    if( OtcpEvent.PacketSend )
-       OS_DbgPrint(OSK_MID_TRACE,("SendPacket handler registered: %x\n",
-                                  OtcpEvent.PacketSend));
 }
 
 void OskitDumpBuffer( OSK_PCHAR Data, OSK_UINT Len )
@@ -169,7 +166,7 @@ int OskitTCPRecv( void *connection,
     return error;
 }
 
-int OskitTCPBind( void *socket, void *connection,
+int OskitTCPBind( void *socket,
                  void *nam, OSK_UINT namelen ) {
     int error = EFAULT;
     struct socket *so = socket;
@@ -178,6 +175,9 @@ int OskitTCPBind( void *socket, void *connection,
 
     OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     if( nam )
        addr = *((struct sockaddr *)nam);
 
@@ -243,11 +243,18 @@ done:
 }
 
 int OskitTCPShutdown( void *socket, int disconn_type ) {
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     return soshutdown( socket, disconn_type );
 }
 
 int OskitTCPClose( void *socket ) {
     struct socket *so = socket;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     so->so_connection = 0;
     soclose( so );
     return 0;
@@ -259,6 +266,9 @@ int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
     struct uio uio;
     struct iovec iov;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     iov.iov_len = Len;
     iov.iov_base = (char *)Data;
     uio.uio_iov = &iov;
@@ -293,6 +303,12 @@ int OskitTCPAccept( void *socket,
     struct inpcb *inp;
     int namelen = 0, error = 0, s;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    if (!new_socket || !AddrOut)
+        return OSK_EINVAL;
+
     OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: Doing accept (Finish %d)\n",
                               FinishAccepting));
 
@@ -436,6 +452,9 @@ int OskitTCPSetSockOpt(void *socket,
 {
     struct mbuf *m;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     if (size >= MLEN)
         return OSK_EINVAL;
 
@@ -460,6 +479,9 @@ int OskitTCPGetSockOpt(void *socket,
     int error, oldsize = *size;
     struct mbuf *m;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     error = sogetopt(socket, level, optname, &m);
     if (!error)
     {
@@ -482,6 +504,9 @@ int OskitTCPGetSockOpt(void *socket,
 int OskitTCPListen( void *socket, int backlog ) {
     int error;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
     error = solisten( socket, backlog );
     OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
@@ -489,32 +514,44 @@ int OskitTCPListen( void *socket, int backlog ) {
     return error;
 }
 
-void OskitTCPSetAddress( void *socket,
+int OskitTCPSetAddress( void *socket,
                         OSK_UINT LocalAddress,
                         OSK_UI16 LocalPort,
                         OSK_UINT RemoteAddress,
                         OSK_UI16 RemotePort ) {
     struct socket *so = socket;
-    struct inpcb *inp = (struct inpcb *)so->so_pcb;
+    struct inpcb *inp;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    inp = (struct inpcb *)so->so_pcb;
     inp->inp_laddr.s_addr = LocalAddress;
     inp->inp_lport = LocalPort;
     inp->inp_faddr.s_addr = RemoteAddress;
     inp->inp_fport = RemotePort;
+
+    return 0;
 }
 
-void OskitTCPGetAddress( void *socket,
+int OskitTCPGetAddress( void *socket,
                         OSK_UINT *LocalAddress,
                         OSK_UI16 *LocalPort,
                         OSK_UINT *RemoteAddress,
                         OSK_UI16 *RemotePort ) {
     struct socket *so = socket;
-    struct inpcb *inp = so ? (struct inpcb *)so->so_pcb : NULL;
-    if( inp ) {
-       *LocalAddress = inp->inp_laddr.s_addr;
-       *LocalPort = inp->inp_lport;
-       *RemoteAddress = inp->inp_faddr.s_addr;
-       *RemotePort = inp->inp_fport;
-    }
+    struct inpcb *inp;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    inp = (struct inpcb *)so->so_pcb;
+    *LocalAddress = inp->inp_laddr.s_addr;
+    *LocalPort = inp->inp_lport;
+    *RemoteAddress = inp->inp_faddr.s_addr;
+    *RemotePort = inp->inp_fport;
+
+    return 0;
 }
 
 struct ifaddr *ifa_iffind(struct sockaddr *addr, int type)