- Merge aicom-network-branch (still without the NDIS stuff)
[reactos.git] / reactos / lib / drivers / oskittcp / oskittcp / interface.c
index 33b2d93..3c589e7 100644 (file)
@@ -23,6 +23,8 @@ OSKITTCP_EVENT_HANDLERS OtcpEvent = { 0 };
 //OSK_UINT OskitDebugTraceLevel = OSK_DEBUG_ULTRA;
 OSK_UINT OskitDebugTraceLevel = 0;
 
+KSPIN_LOCK OSKLock;
+
 /* SPL */
 unsigned cpl;
 unsigned net_imask;
@@ -45,8 +47,7 @@ void fbsd_free( void *data, char *file, unsigned line, ... ) {
 
 void InitOskitTCP() {
     OS_DbgPrint(OSK_MID_TRACE,("Init Called\n"));
-    OS_DbgPrint(OSK_MID_TRACE,("MB Init\n"));
-    mbinit();
+    KeInitializeSpinLock(&OSKLock);
     OS_DbgPrint(OSK_MID_TRACE,("Rawip Init\n"));
     rip_init();
     raw_init();
@@ -66,19 +67,23 @@ void DeinitOskitTCP() {
 }
 
 void TimerOskitTCP( int FastTimer, int SlowTimer ) {
+    KIRQL OldIrql;
+
+    /* This function is a special case in which we cannot use OSKLock/OSKUnlock 
+     * because we don't enter with the connection lock held */
+
+    OSKLockAndRaise(&OldIrql);
     if ( SlowTimer ) {
         tcp_slowtimo();
     }
     if ( FastTimer ) {
         tcp_fasttimo();
     }
+    OSKUnlockAndLower(OldIrql);
 }
 
 void RegisterOskitTCPEventHandlers( POSKITTCP_EVENT_HANDLERS EventHandlers ) {
     memcpy( &OtcpEvent, EventHandlers, sizeof(OtcpEvent) );
-    if( OtcpEvent.PacketSend )
-       OS_DbgPrint(OSK_MID_TRACE,("SendPacket handler registered: %x\n",
-                                  OtcpEvent.PacketSend));
 }
 
 void OskitDumpBuffer( OSK_PCHAR Data, OSK_UINT Len )
@@ -118,16 +123,16 @@ int OskitTCPSocket( void *context,
                    int proto )
 {
     struct socket *so;
+
+    OSKLock();
     int error = socreate(domain, &so, type, proto);
     if( !error ) {
        so->so_connection = context;
-       so->so_state = SS_NBIO;
-       so->so_error = 0;
-        so->so_q = so->so_q0 = NULL;
-        so->so_qlen = 0;
-        so->so_head = NULL;
+       so->so_state |= SS_NBIO;
        *aso = so;
     }
+    OSKUnlock();
+
     return error;
 }
 
@@ -141,8 +146,6 @@ int OskitTCPRecv( void *connection,
     int error = 0;
     int tcp_flags = 0;
 
-    *OutLen = 0;
-
     OS_DbgPrint(OSK_MID_TRACE,
                 ("so->so_state %x\n", ((struct socket *)connection)->so_state));
 
@@ -150,26 +153,29 @@ int OskitTCPRecv( void *connection,
     if( Flags & OSK_MSG_DONTWAIT ) tcp_flags |= MSG_DONTWAIT;
     if( Flags & OSK_MSG_PEEK )     tcp_flags |= MSG_PEEK;
 
-    uio.uio_resid = Len;
-    uio.uio_iov = &iov;
-    uio.uio_rw = UIO_READ;
-    uio.uio_iovcnt = 1;
     iov.iov_len = Len;
     iov.iov_base = (char *)Data;
+    uio.uio_iov = &iov;
+    uio.uio_iovcnt = 1;
+    uio.uio_offset = 0;
+    uio.uio_resid = Len;
+    uio.uio_segflg = UIO_SYSSPACE;
+    uio.uio_rw = UIO_READ;
+    uio.uio_procp = NULL;
 
     OS_DbgPrint(OSK_MID_TRACE,("Reading %d bytes from TCP:\n", Len));
 
+    OSKLock();
     error = soreceive( connection, NULL, &uio, NULL, NULL /* SCM_RIGHTS */,
                       &tcp_flags );
+    OSKUnlock();
 
-    if( error == 0 ) {
-       *OutLen = Len - uio.uio_resid;
-    }
+    *OutLen = Len - uio.uio_resid;
 
     return error;
 }
 
-int OskitTCPBind( void *socket, void *connection,
+int OskitTCPBind( void *socket,
                  void *nam, OSK_UINT namelen ) {
     int error = EFAULT;
     struct socket *so = socket;
@@ -178,6 +184,9 @@ int OskitTCPBind( void *socket, void *connection,
 
     OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     if( nam )
        addr = *((struct sockaddr *)nam);
 
@@ -188,14 +197,15 @@ int OskitTCPBind( void *socket, void *connection,
     addr.sa_family = addr.sa_len;
     addr.sa_len = sizeof(struct sockaddr);
 
+    OSKLock();
     error = sobind(so, &sabuf);
+    OSKUnlock();
 
     OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
     return (error);
 }
 
-int OskitTCPConnect( void *socket, void *connection,
-                    void *nam, OSK_UINT namelen ) {
+int OskitTCPConnect( void *socket, void *nam, OSK_UINT namelen ) {
     struct socket *so = socket;
     int error = EFAULT;
     struct mbuf sabuf;
@@ -203,8 +213,7 @@ int OskitTCPConnect( void *socket, void *connection,
 
     OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
 
-    so->so_connection = connection;
-
+    OSKLock();
     if ((so->so_state & SS_NBIO) && (so->so_state & SS_ISCONNECTING)) {
        error = EALREADY;
        goto done;
@@ -223,7 +232,9 @@ int OskitTCPConnect( void *socket, void *connection,
 
     error = soconnect(so, &sabuf);
 
-    if (error)
+    if (error == EINPROGRESS)
+        goto done;
+    else if (error)
        goto bad;
 
     if ((so->so_state & SS_NBIO) && (so->so_state & SS_ISCONNECTING)) {
@@ -238,19 +249,49 @@ bad:
        error = EINTR;
 
 done:
+    OSKUnlock();
     OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
     return (error);
 }
 
+int OskitTCPDisconnect(void *socket)
+{
+    int error;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    error = sodisconnect(socket);
+    OSKUnlock();
+
+    return error;
+}
+
 int OskitTCPShutdown( void *socket, int disconn_type ) {
-    return soshutdown( socket, disconn_type );
+    int error;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    error = soshutdown( socket, disconn_type );
+    OSKUnlock();
+
+    return error;
 }
 
 int OskitTCPClose( void *socket ) {
-    struct socket *so = socket;
-    so->so_connection = 0;
-    soclose( so );
-    return 0;
+    int error;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    error = soclose( socket );
+    OSKUnlock();
+
+    return error;
 }
 
 int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
@@ -259,6 +300,9 @@ int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
     struct uio uio;
     struct iovec iov;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     iov.iov_len = Len;
     iov.iov_base = (char *)Data;
     uio.uio_iov = &iov;
@@ -269,17 +313,18 @@ int OskitTCPSend( void *socket, OSK_PCHAR Data, OSK_UINT Len,
     uio.uio_rw = UIO_WRITE;
     uio.uio_procp = NULL;
 
+    OSKLock();
     error = sosend( socket, NULL, &uio, NULL, NULL, 0 );
-    if (OSK_EWOULDBLOCK == error) {
-       ((struct socket *) socket)->so_snd.sb_flags |= SB_WAIT;
-    }
-    *OutLen = uio.uio_offset;
+    OSKUnlock();
+
+    *OutLen = Len - uio.uio_resid;
 
     return error;
 }
 
 int OskitTCPAccept( void *socket,
                    void **new_socket,
+                   void *context,
                    void *AddrOut,
                    OSK_UINT AddrLen,
                    OSK_UINT *OutAddrLen,
@@ -293,6 +338,12 @@ int OskitTCPAccept( void *socket,
     struct inpcb *inp;
     int namelen = 0, error = 0, s;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    if (!new_socket || !AddrOut)
+        return OSK_EINVAL;
+
     OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: Doing accept (Finish %d)\n",
                               FinishAccepting));
 
@@ -302,11 +353,12 @@ int OskitTCPAccept( void *socket,
        /* that's a copyin actually */
        namelen = *OutAddrLen;
 
+    OSKLock();
+
     s = splnet();
 
 #if 0
     if ((head->so_options & SO_ACCEPTCONN) == 0) {
-       splx(s);
        OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: head->so_options = %x, wanted bit %x\n",
                                   head->so_options, SO_ACCEPTCONN));
        error = EINVAL;
@@ -318,39 +370,10 @@ int OskitTCPAccept( void *socket,
                               head->so_q, head->so_state));
 
     if ((head->so_state & SS_NBIO) && head->so_q == NULL) {
-       splx(s);
        error = EWOULDBLOCK;
        goto out;
     }
 
-    OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-    while (head->so_q == NULL && head->so_error == 0) {
-       if (head->so_state & SS_CANTRCVMORE) {
-           head->so_error = ECONNABORTED;
-           break;
-       }
-       OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-       error = tsleep((caddr_t)&head->so_timeo, PSOCK | PCATCH,
-                      "accept", 0);
-       if (error) {
-           splx(s);
-           goto out;
-       }
-       OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-    }
-    OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-
-#if 0
-    if (head->so_error) {
-       OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-       error = head->so_error;
-       head->so_error = 0;
-       splx(s);
-       goto out;
-    }
-    OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-#endif
-
     /*
      * At this point we know that there is at least one connection
      * ready to be accepted. Remove it from the queue.
@@ -358,30 +381,31 @@ int OskitTCPAccept( void *socket,
     so = head->so_q;
 
     inp = so ? (struct inpcb *)so->so_pcb : NULL;
-    if( inp ) {
+    if( inp && name ) {
         ((struct sockaddr_in *)AddrOut)->sin_addr.s_addr =
             inp->inp_faddr.s_addr;
         ((struct sockaddr_in *)AddrOut)->sin_port = inp->inp_fport;
     }
 
     OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-    if( FinishAccepting ) {
+    if( FinishAccepting && so ) {
        head->so_q = so->so_q;
        head->so_qlen--;
 
-       *newso = so;
-
-       /*so->so_state &= ~SS_COMP;*/
-
        mnam.m_data = (char *)&sa;
        mnam.m_len = sizeof(sa);
 
-       (void) soaccept(so, &mnam);
+       error = soaccept(so, &mnam);
+        if (error)
+            goto out;
 
-       so->so_state = SS_NBIO | SS_ISCONNECTED;
+       so->so_state |= SS_NBIO | SS_ISCONNECTED;
         so->so_q = so->so_q0 = NULL;
-        so->so_qlen = 0;
+        so->so_qlen = so->so_q0len = 0;
         so->so_head = 0;
+        so->so_connection = context;
+
+       *newso = so;
 
        OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
        if (name) {
@@ -391,9 +415,10 @@ int OskitTCPAccept( void *socket,
            *OutAddrLen = namelen;      /* copyout actually */
        }
        OS_DbgPrint(OSK_MID_TRACE,("error = %d\n", error));
-       splx(s);
     }
 out:
+    splx(s);
+    OSKUnlock();
     OS_DbgPrint(OSK_MID_TRACE,("OSKITTCP: Returning %d\n", error));
     return (error);
 }
@@ -406,10 +431,16 @@ out:
 
 void OskitTCPReceiveDatagram( OSK_PCHAR Data, OSK_UINT Len,
                              OSK_UINT IpHeaderLen ) {
-    struct mbuf *Ip = m_devget( (char *)Data, Len, 0, NULL, NULL );
+    struct mbuf *Ip;
     struct ip *iph;
 
-    if( !Ip ) return; /* drop the segment */
+    OSKLock();
+    Ip = m_devget( (char *)Data, Len, 0, NULL, NULL );
+    if( !Ip )
+    {
+       OSKUnlock();
+       return; /* drop the segment */
+    }
 
     //memcpy( Ip->m_data, Data, Len );
     Ip->m_pkthdr.len = IpHeaderLen;
@@ -424,46 +455,152 @@ void OskitTCPReceiveDatagram( OSK_PCHAR Data, OSK_UINT Len,
                 IpHeaderLen));
 
     tcp_input(Ip, IpHeaderLen);
+    OSKUnlock();
 
     /* The buffer Ip is freed by tcp_input */
 }
 
+int OskitTCPSetSockOpt(void *socket,
+                       int level,
+                       int optname,
+                       char *buffer,
+                       int size)
+{
+    struct mbuf *m;
+    int error;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    if (size >= MLEN)
+        return OSK_EINVAL;
+
+    OSKLock();
+    m = m_get(M_WAIT, MT_SOOPTS);
+    if (!m)
+    {
+        OSKUnlock();
+        return OSK_ENOMEM;
+    }
+
+    m->m_len = size;
+
+    memcpy(m->m_data, buffer, size);
+
+    /* m is freed by sosetopt */
+    error = sosetopt(socket, level, optname, m);
+    OSKUnlock();
+
+    return error;
+}   
+
+int OskitTCPGetSockOpt(void *socket,
+                       int level,
+                       int optname,
+                       char *buffer,
+                       int *size)
+{
+    int error, oldsize = *size;
+    struct mbuf *m;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    error = sogetopt(socket, level, optname, &m);
+    if (!error)
+    {
+        *size = m->m_len;
+
+        if (!buffer || oldsize < m->m_len)
+        {
+            m_freem(m);
+            OSKUnlock();
+            return OSK_EINVAL;
+        }
+
+        memcpy(buffer, m->m_data, m->m_len);
+
+        m_freem(m);
+    }
+    OSKUnlock();
+
+    return error;
+}
+
 int OskitTCPListen( void *socket, int backlog ) {
     int error;
 
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
     OS_DbgPrint(OSK_MID_TRACE,("Called, socket = %08x\n", socket));
+
+    OSKLock();
     error = solisten( socket, backlog );
+    OSKUnlock();
+
     OS_DbgPrint(OSK_MID_TRACE,("Ending: %08x\n", error));
 
     return error;
 }
 
-void OskitTCPSetAddress( void *socket,
+int OskitTCPSetAddress( void *socket,
                         OSK_UINT LocalAddress,
                         OSK_UI16 LocalPort,
                         OSK_UINT RemoteAddress,
                         OSK_UI16 RemotePort ) {
     struct socket *so = socket;
-    struct inpcb *inp = (struct inpcb *)so->so_pcb;
+    struct inpcb *inp;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    inp = (struct inpcb *)so->so_pcb;
     inp->inp_laddr.s_addr = LocalAddress;
     inp->inp_lport = LocalPort;
     inp->inp_faddr.s_addr = RemoteAddress;
     inp->inp_fport = RemotePort;
+    OSKUnlock();
+
+    return 0;
 }
 
-void OskitTCPGetAddress( void *socket,
+int OskitTCPGetAddress( void *socket,
                         OSK_UINT *LocalAddress,
                         OSK_UI16 *LocalPort,
                         OSK_UINT *RemoteAddress,
                         OSK_UI16 *RemotePort ) {
     struct socket *so = socket;
-    struct inpcb *inp = so ? (struct inpcb *)so->so_pcb : NULL;
-    if( inp ) {
-       *LocalAddress = inp->inp_laddr.s_addr;
-       *LocalPort = inp->inp_lport;
-       *RemoteAddress = inp->inp_faddr.s_addr;
-       *RemotePort = inp->inp_fport;
-    }
+    struct inpcb *inp;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    inp = (struct inpcb *)so->so_pcb;
+    *LocalAddress = inp->inp_laddr.s_addr;
+    *LocalPort = inp->inp_lport;
+    *RemoteAddress = inp->inp_faddr.s_addr;
+    *RemotePort = inp->inp_fport;
+    OSKUnlock();
+
+    return 0;
+}
+
+int OskitTCPGetSocketError(void *socket) {
+    struct socket *so = socket;
+    int error;
+
+    if (!socket)
+        return OSK_ESHUTDOWN;
+
+    OSKLock();
+    error = so->so_error;
+    OSKUnlock();
+
+    return error;
 }
 
 struct ifaddr *ifa_iffind(struct sockaddr *addr, int type)
@@ -479,7 +616,7 @@ struct ifaddr *ifa_iffind(struct sockaddr *addr, int type)
 
 void oskittcp_die( const char *file, int line ) {
     DbgPrint("\n\n*** OSKITTCP: Panic Called at %s:%d ***\n", file, line);
-    *((int *)0) = 0;
+    ASSERT(FALSE);
 }
 
 /* Stuff supporting the BSD network-interface interface */