[OSKITTCP]
[reactos.git] / reactos / lib / drivers / oskittcp / oskittcp / interface.c
index 5aea3dd..db0b3a8 100644 (file)
@@ -23,6 +23,8 @@ OSKITTCP_EVENT_HANDLERS OtcpEvent = { 0 };
 //OSK_UINT OskitDebugTraceLevel = OSK_DEBUG_ULTRA;
 OSK_UINT OskitDebugTraceLevel = 0;
 
+KSPIN_LOCK OSKLock;
+
 /* SPL */
 unsigned cpl;
 unsigned net_imask;
@@ -45,8 +47,7 @@ void fbsd_free( void *data, char *file, unsigned line, ... ) {
 
 void InitOskitTCP() {
     OS_DbgPrint(OSK_MID_TRACE,("Init Called\n"));
-    OS_DbgPrint(OSK_MID_TRACE,("MB Init\n"));
-    mbinit();
+    KeInitializeSpinLock(&OSKLock);
     OS_DbgPrint(OSK_MID_TRACE,("Rawip Init\n"));
     rip_init();
     raw_init();
@@ -66,12 +67,19 @@ void DeinitOskitTCP() {
 }
 
 void TimerOskitTCP( int FastTimer, int SlowTimer ) {
+    KIRQL OldIrql;
+
+    /* This function is a special case in which we cannot use OSKLock/OSKUnlock 
+     * because we don't enter with the connection lock held */
+
+    OSKLockAndRaise(&OldIrql);
     if ( SlowTimer ) {
         tcp_slowtimo();
     }
     if ( FastTimer ) {
         tcp_fasttimo();
     }
+    OSKUnlockAndLower(OldIrql);
 }
 
 void RegisterOskitTCPEventHandlers( POSKITTCP_EVENT_HANDLERS EventHandlers ) {
@@ -115,12 +123,17 @@ int OskitTCPSocket( void *context,
                    int proto )
 {
     struct socket *so;
+
+    OSKLock();
     int error = socreate(domain, &so, type, proto);
     if( !error ) {
        so->so_connection = context;
        so->so_state |= SS_NBIO;
+    so->so_options |= SO_DONTROUTE;
        *aso = so;
     }
+    OSKUnlock();
+
     return error;
 }
 
@@ -153,8 +166,11 @@ int OskitTCPRecv( void *connection,
 
     OS_DbgPrint(OSK_MID_TRACE,("Reading %d bytes from TCP:\n", Len));
 
+    OSKLock();
     error = soreceive( connection, NULL, &uio, NULL, NULL /* SCM_RIGHTS */,
                       &tcp_flags );
+    OSKUnlock();
+
     *OutLen = Len - uio.uio_resid;
 
     return error;
@@ -182,14 +198,15 @@ int OskitTCPBind( void *socket,
     addr.sa_family = addr.sa_len;
     addr.sa_len = sizeof(struct sockaddr);
 
+    OSKLock();
     error = sobind(so, &sabuf);
+    OSKUnlock();
 
     OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
     return (error);
 }
 
-int OskitTCPConnect( void *socket, void *connection,
-                    void *nam, OSK_UINT namelen ) {
+int OskitTCPConnect( void *socket, void *nam, OSK_UINT namelen ) {
     struct socket *so = socket;
     int error = EFAULT;
     struct mbuf sabuf;
@@ -197,8 +214,7 @@ int OskitTCPConnect( void *socket, void *connection,
 
     OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
 
-    so->so_connection = connection;
-
+    OSKLock();
     if ((so->so_state & SS_NBIO) && (so->so_state & SS_ISCONNECTING)) {
        error = EALREADY;
        goto done;
@@ -217,7 +233,9 @@ int OskitTCPConnect( void *socket, void *connection,
 
     error = soconnect(so, &sabuf);
 
-    if (error)
+    if (error == EINPROGRESS)
+        goto done;
+    else if (error)
        goto bad;
 
     if ((so->so_state & SS_NBIO) && (so->so_state & SS_ISCONNECTING)) {
@@ -232,34 +250,49 @@ bad:
        error = EINTR;
 
 done:
+    OSKUnlock();
     OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
     return (error);
 }
 
 int OskitTCPDisconnect(void *socket)
 {
+    int error;
+
     if (!socket)
         return OSK_ESHUTDOWN;
 
-    return sodisconnect(socket);
+    OSKLock();
+    error = sodisconnect(socket);
+    OSKUnlock();
+
+    return error;
 }
 
 int OskitTCPShutdown( void *socket, int disconn_type ) {
+    int error;
+
     if (!socket)
         return OSK_ESHUTDOWN;
 
-    return soshutdown( socket, disconn_type );
+    OSKLock();
+    error = soshutdown( socket, disconn_type );
+    OSKUnlock();
+
+    return error;
 }
 
 int OskitTCPClose( void *socket ) {
-    struct socket *so = socket;
+    int error;
 
     if (!socket)
         return OSK_ESHUTDOWN;
 
-    so->so_connection = 0;
-    soclose( so );
-    return 0;
+    OSKLock();
+    error = soclose( socket );
+    OSKUnlock();
+
+    return error;
 }
 
 int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
@@ -281,7 +314,10 @@ int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
     uio.uio_rw = UIO_WRITE;
     uio.uio_procp = NULL;
 
+    OSKLock();
     error = sosend( socket, NULL, &uio, NULL, NULL, 0 );
+    OSKUnlock();
+
     *OutLen = Len - uio.uio_resid;
 
     return error;
@@ -289,6 +325,7 @@ int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
 
 int OskitTCPAccept( void *socket,
                    void **new_socket,
+                   void *context,
                    void *AddrOut,
                    OSK_UINT AddrLen,
                    OSK_UINT *OutAddrLen,
@@ -317,11 +354,12 @@ int OskitTCPAccept( void *socket,
        /* that's a copyin actually */
        namelen = *OutAddrLen;
 
+    OSKLock();
+
     s = splnet();
 
 #if 0
     if ((head->so_options & SO_ACCEPTCONN) == 0) {
-       splx(s);
        OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: head->so_options = %x, wanted bit %x\n",
                                   head->so_options, SO_ACCEPTCONN));
        error = EINVAL;
@@ -333,39 +371,10 @@ int OskitTCPAccept( void *socket,
                               head->so_q, head->so_state));
 
     if ((head->so_state & SS_NBIO) && head->so_q == NULL) {
-       splx(s);
        error = EWOULDBLOCK;
        goto out;
     }
 
-    OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-    while (head->so_q == NULL && head->so_error == 0) {
-       if (head->so_state & SS_CANTRCVMORE) {
-           head->so_error = ECONNABORTED;
-           break;
-       }
-       OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-       error = tsleep((caddr_t)&head->so_timeo, PSOCK | PCATCH,
-                      "accept", 0);
-       if (error) {
-           splx(s);
-           goto out;
-       }
-       OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-    }
-    OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-
-#if 0
-    if (head->so_error) {
-       OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-       error = head->so_error;
-       head->so_error = 0;
-       splx(s);
-       goto out;
-    }
-    OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-#endif
-
     /*
      * At this point we know that there is at least one connection
      * ready to be accepted. Remove it from the queue.
@@ -384,19 +393,20 @@ int OskitTCPAccept( void *socket,
        head->so_q = so->so_q;
        head->so_qlen--;
 
-       *newso = so;
-
-       /*so->so_state &= ~SS_COMP;*/
-
        mnam.m_data = (char *)&sa;
        mnam.m_len = sizeof(sa);
 
-       (void) soaccept(so, &mnam);
+       error = soaccept(so, &mnam);
+        if (error)
+            goto out;
 
-       so->so_state = SS_NBIO | SS_ISCONNECTED;
+       so->so_state |= SS_NBIO | SS_ISCONNECTED;
         so->so_q = so->so_q0 = NULL;
-        so->so_qlen = 0;
+        so->so_qlen = so->so_q0len = 0;
         so->so_head = 0;
+        so->so_connection = context;
+
+       *newso = so;
 
        OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
        if (name) {
@@ -406,9 +416,10 @@ int OskitTCPAccept( void *socket,
            *OutAddrLen = namelen;      /* copyout actually */
        }
        OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-       splx(s);
     }
 out:
+    splx(s);
+    OSKUnlock();
     OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: Returning %d\n", error));
     return (error);
 }
@@ -421,10 +432,16 @@ out:
 
 void OskitTCPReceiveDatagram( OSK_PCHAR Data, OSK_UINT Len,
                              OSK_UINT IpHeaderLen ) {
-    struct mbuf *Ip = m_devget( (char *)Data, Len, 0, NULL, NULL );
+    struct mbuf *Ip;
     struct ip *iph;
 
-    if( !Ip ) return; /* drop the segment */
+    OSKLock();
+    Ip = m_devget( (char *)Data, Len, 0, NULL, NULL );
+    if( !Ip )
+    {
+       OSKUnlock();
+       return; /* drop the segment */
+    }
 
     //memcpy( Ip->m_data, Data, Len );
     Ip->m_pkthdr.len = IpHeaderLen;
@@ -439,6 +456,7 @@ void OskitTCPReceiveDatagram( OSK_PCHAR Data, OSK_UINT Len,
                 IpHeaderLen));
 
     tcp_input(Ip, IpHeaderLen);
+    OSKUnlock();
 
     /* The buffer Ip is freed by tcp_input */
 }
@@ -450,6 +468,7 @@ int OskitTCPSetSockOpt(void *socket,
                        int size)
 {
     struct mbuf *m;
+    int error;
 
     if (!socket)
         return OSK_ESHUTDOWN;
@@ -457,16 +476,23 @@ int OskitTCPSetSockOpt(void *socket,
     if (size >= MLEN)
         return OSK_EINVAL;
 
+    OSKLock();
     m = m_get(M_WAIT, MT_SOOPTS);
     if (!m)
+    {
+        OSKUnlock();
         return OSK_ENOMEM;
+    }
 
     m->m_len = size;
 
     memcpy(m->m_data, buffer, size);
 
     /* m is freed by sosetopt */
-    return sosetopt(socket, level, optname, m);
+    error = sosetopt(socket, level, optname, m);
+    OSKUnlock();
+
+    return error;
 }   
 
 int OskitTCPGetSockOpt(void *socket,
@@ -481,6 +507,7 @@ int OskitTCPGetSockOpt(void *socket,
     if (!socket)
         return OSK_ESHUTDOWN;
 
+    OSKLock();
     error = sogetopt(socket, level, optname, &m);
     if (!error)
     {
@@ -489,6 +516,7 @@ int OskitTCPGetSockOpt(void *socket,
         if (!buffer || oldsize < m->m_len)
         {
             m_freem(m);
+            OSKUnlock();
             return OSK_EINVAL;
         }
 
@@ -496,6 +524,7 @@ int OskitTCPGetSockOpt(void *socket,
 
         m_freem(m);
     }
+    OSKUnlock();
 
     return error;
 }
@@ -507,7 +536,11 @@ int OskitTCPListen( void *socket, int backlog ) {
         return OSK_ESHUTDOWN;
 
     OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
+
+    OSKLock();
     error = solisten( socket, backlog );
+    OSKUnlock();
+
     OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
 
     return error;
@@ -524,11 +557,13 @@ int OskitTCPSetAddress( void *socket,
     if (!socket)
         return OSK_ESHUTDOWN;
 
+    OSKLock();
     inp = (struct inpcb *)so->so_pcb;
     inp->inp_laddr.s_addr = LocalAddress;
     inp->inp_lport = LocalPort;
     inp->inp_faddr.s_addr = RemoteAddress;
     inp->inp_fport = RemotePort;
+    OSKUnlock();
 
     return 0;
 }
@@ -544,15 +579,31 @@ int OskitTCPGetAddress( void *socket,
     if (!socket)
         return OSK_ESHUTDOWN;
 
+    OSKLock();
     inp = (struct inpcb *)so->so_pcb;
     *LocalAddress = inp->inp_laddr.s_addr;
     *LocalPort = inp->inp_lport;
     *RemoteAddress = inp->inp_faddr.s_addr;
     *RemotePort = inp->inp_fport;
+    OSKUnlock();
 
     return 0;
 }
 
+int OskitTCPGetSocketError(void *socket) {
+    struct socket *so = socket;
+    int error;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    error = so->so_error;
+    OSKUnlock();
+
+    return error;
+}
+
 struct ifaddr *ifa_iffind(struct sockaddr *addr, int type)
 {
     if( OtcpEvent.FindInterface )