[MSAFD] Implement async connect
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
index 4102fd9..dda602f 100644 (file)
@@ -95,13 +95,33 @@ WSPSocket(int AddressFamily,
         Protocol = SharedData->Protocol;
     }
 
-    if (AddressFamily == AF_UNSPEC && SocketType == 0 && Protocol == 0)
+    if (lpProtocolInfo)
+    {
+        if (lpProtocolInfo->iAddressFamily && AddressFamily <= 0)
+            AddressFamily = lpProtocolInfo->iAddressFamily;
+        if (lpProtocolInfo->iSocketType && SocketType <= 0)
+            SocketType = lpProtocolInfo->iSocketType;
+        if (lpProtocolInfo->iProtocol && Protocol <= 0)
+            Protocol = lpProtocolInfo->iProtocol;
+    }
+
+    /* FIXME: AF_NETDES should be AF_MAX */
+    if (AddressFamily < AF_UNSPEC || AddressFamily > AF_NETDES)
         return WSAEINVAL;
 
-    /* Set the defaults */
-    if (AddressFamily == AF_UNSPEC)
-        AddressFamily = AF_INET;
+    if (SocketType < 0 && SocketType > SOCK_SEQPACKET)
+        return WSAEINVAL;
 
+    if (Protocol < 0 && Protocol > IPPROTO_MAX)
+        return WSAEINVAL;
+
+    /* when no protocol and socket type are specified the first entry
+    * from WSAEnumProtocols that has the flag PFL_MATCHES_PROTOCOL_ZERO
+    * is returned */
+    if (SocketType == 0 && Protocol == 0 && lpProtocolInfo && (lpProtocolInfo->dwProviderFlags & PFL_MATCHES_PROTOCOL_ZERO) == 0)
+        return WSAEINVAL;
+
+    /* Set the defaults */
     if (SocketType == 0)
     {
         switch (Protocol)
@@ -117,8 +137,7 @@ WSPSocket(int AddressFamily,
             break;
         default:
             TRACE("Unknown Protocol (%d). We will try SOCK_STREAM.\n", Protocol);
-            SocketType = SOCK_STREAM;
-            break;
+            return WSAEINVAL;
         }
     }
 
@@ -137,11 +156,13 @@ WSPSocket(int AddressFamily,
             break;
         default:
             TRACE("Unknown SocketType (%d). We will try IPPROTO_TCP.\n", SocketType);
-            Protocol = IPPROTO_TCP;
-            break;
+            return WSAEINVAL;
         }
     }
 
+    if (AddressFamily == AF_UNSPEC)
+        return WSAEINVAL;
+
     /* Get Helper Data and Transport */
     Status = SockGetTdiName (&AddressFamily,
                              &SocketType,
@@ -1031,13 +1052,14 @@ WSPSelect(IN int nfds,
     PAFD_POLL_INFO      PollInfo;
     NTSTATUS            Status;
     ULONG               HandleCount;
-    LONG                OutCount = 0;
     ULONG               PollBufferSize;
     PVOID               PollBuffer;
     ULONG               i, j = 0, x;
     HANDLE              SockEvent;
-    BOOL                HandleCounted;
     LARGE_INTEGER       Timeout;
+    PSOCKET_INFORMATION Socket;
+    SOCKET              Handle;
+    ULONG               Events;
 
     /* Find out how many sockets we have, and how large the buffer needs
      * to be */
@@ -1119,28 +1141,68 @@ WSPSelect(IN int nfds,
     if (readfds != NULL) {
         for (i = 0; i < readfds->fd_count; i++, j++)
         {
+            Socket = GetSocketStructure(readfds->fd_array[i]);
+            if (!Socket)
+            {
+                ERR("Invalid socket handle provided in readfds %d\n", readfds->fd_array[i]);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             PollInfo->Handles[j].Handle = readfds->fd_array[i];
             PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
                                           AFD_EVENT_DISCONNECT |
                                           AFD_EVENT_ABORT |
                                           AFD_EVENT_CLOSE |
                                           AFD_EVENT_ACCEPT;
+            if (Socket->SharedData->OobInline != 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
         }
     }
     if (writefds != NULL)
     {
         for (i = 0; i < writefds->fd_count; i++, j++)
         {
+            Socket = GetSocketStructure(writefds->fd_array[i]);
+            if (!Socket)
+            {
+                ERR("Invalid socket handle provided in writefds %d\n", writefds->fd_array[i]);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             PollInfo->Handles[j].Handle = writefds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
+            PollInfo->Handles[j].Events = AFD_EVENT_SEND;
+            if (Socket->SharedData->NonBlocking != 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT;
         }
     }
     if (exceptfds != NULL)
     {
         for (i = 0; i < exceptfds->fd_count; i++, j++)
         {
+            Socket = GetSocketStructure(exceptfds->fd_array[i]);
+            if (!Socket)
+            {
+                TRACE("Invalid socket handle provided in exceptfds %d\n", exceptfds->fd_array[i]);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
+            PollInfo->Handles[j].Events = 0;
+            if (Socket->SharedData->OobInline == 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
+            if (Socket->SharedData->NonBlocking != 0)
+                PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT_FAIL;
+            if (PollInfo->Handles[j].Events == 0)
+            {
+                TRACE("No events can be checked for exceptfds %d. It is nonblocking and OOB line is disabled. Skipping it.", exceptfds->fd_array[i]);
+                j--;
+            }
         }
     }
 
@@ -1182,10 +1244,20 @@ WSPSelect(IN int nfds,
     /* Return in FDSET Format */
     for (i = 0; i < HandleCount; i++)
     {
-        HandleCounted = FALSE;
+        Events = PollInfo->Handles[i].Events;
+        Handle = PollInfo->Handles[i].Handle;
         for(x = 1; x; x<<=1)
         {
-            switch (PollInfo->Handles[i].Events & x)
+            Socket = GetSocketStructure(Handle);
+            if (!Socket)
+            {
+                TRACE("Invalid socket handle found %d\n", Handle);
+                if (lpErrno) *lpErrno = WSAENOTSOCK;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
+            switch (Events & x)
             {
                 case AFD_EVENT_RECEIVE:
                 case AFD_EVENT_DISCONNECT:
@@ -1193,41 +1265,40 @@ WSPSelect(IN int nfds,
                 case AFD_EVENT_ACCEPT:
                 case AFD_EVENT_CLOSE:
                     TRACE("Event %x on handle %x\n",
-                        PollInfo->Handles[i].Events,
-                        PollInfo->Handles[i].Handle);
-                    if (! HandleCounted)
-                    {
-                        OutCount++;
-                        HandleCounted = TRUE;
-                    }
+                        Events,
+                        Handle);
                     if( readfds )
-                        FD_SET(PollInfo->Handles[i].Handle, readfds);
+                        FD_SET(Handle, readfds);
                     break;
                 case AFD_EVENT_SEND:
+                    TRACE("Event %x on handle %x\n",
+                        Events,
+                        Handle);
+                    if (writefds)
+                        FD_SET(Handle, writefds);
+                    break;
                 case AFD_EVENT_CONNECT:
                     TRACE("Event %x on handle %x\n",
-                        PollInfo->Handles[i].Events,
-                        PollInfo->Handles[i].Handle);
-                    if (! HandleCounted)
-                    {
-                        OutCount++;
-                        HandleCounted = TRUE;
-                    }
-                    if( writefds )
-                        FD_SET(PollInfo->Handles[i].Handle, writefds);
+                        Events,
+                        Handle);
+                    if( writefds && Socket->SharedData->NonBlocking != 0 )
+                        FD_SET(Handle, writefds);
                     break;
                 case AFD_EVENT_OOB_RECEIVE:
+                    TRACE("Event %x on handle %x\n",
+                        Events,
+                        Handle);
+                    if( readfds && Socket->SharedData->OobInline != 0 )
+                        FD_SET(Handle, readfds);
+                    if( exceptfds && Socket->SharedData->OobInline == 0 )
+                        FD_SET(Handle, exceptfds);
+                    break;
                 case AFD_EVENT_CONNECT_FAIL:
                     TRACE("Event %x on handle %x\n",
-                        PollInfo->Handles[i].Events,
-                        PollInfo->Handles[i].Handle);
-                    if (! HandleCounted)
-                    {
-                        OutCount++;
-                        HandleCounted = TRUE;
-                    }
-                    if( exceptfds )
-                        FD_SET(PollInfo->Handles[i].Handle, exceptfds);
+                        Events,
+                        Handle);
+                    if( exceptfds && Socket->SharedData->NonBlocking != 0 )
+                        FD_SET(Handle, exceptfds);
                     break;
             }
         }
@@ -1251,9 +1322,13 @@ WSPSelect(IN int nfds,
         TRACE("*lpErrno = %x\n", *lpErrno);
     }
 
-    TRACE("%d events\n", OutCount);
+    HandleCount = (readfds ? readfds->fd_count : 0) +
+                  (writefds && writefds != readfds ? writefds->fd_count : 0) +
+                  (exceptfds && exceptfds != readfds && exceptfds != writefds ? exceptfds->fd_count : 0);
+
+    TRACE("%d events\n", HandleCount);
 
-    return OutCount;
+    return HandleCount;
 }
 
 SOCKET
@@ -1644,6 +1719,46 @@ WSPAccept(SOCKET Handle,
     return AcceptSocket;
 }
 
+VOID
+NTAPI
+AfdConnectAPC(PVOID ApcContext,
+              PIO_STATUS_BLOCK IoStatusBlock,
+              ULONG Reserved)
+{
+    PAFDCONNECTAPCCONTEXT Context = ApcContext;
+
+    if (IoStatusBlock->Status == STATUS_SUCCESS)
+    {
+        Context->lpSocket->SharedData->State = SocketConnected;
+        Context->lpSocket->TdiConnectionHandle = (HANDLE)IoStatusBlock->Information;
+    }
+
+    if (Context->lpConnectInfo) HeapFree(GetProcessHeap(), 0, Context->lpConnectInfo);
+
+    /* Re-enable Async Event */
+    SockReenableAsyncSelectEvent(Context->lpSocket, FD_WRITE);
+
+    /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
+    SockReenableAsyncSelectEvent(Context->lpSocket, FD_CONNECT);
+
+    if (IoStatusBlock->Status == STATUS_SUCCESS && (Context->lpSocket->HelperEvents & WSH_NOTIFY_CONNECT))
+    {
+        Context->lpSocket->HelperData->WSHNotify(Context->lpSocket->HelperContext,
+                                                 Context->lpSocket->Handle,
+                                                 Context->lpSocket->TdiAddressHandle,
+                                                 Context->lpSocket->TdiConnectionHandle,
+                                                 WSH_NOTIFY_CONNECT);
+    }
+    else if (IoStatusBlock->Status != STATUS_SUCCESS && (Context->lpSocket->HelperEvents & WSH_NOTIFY_CONNECT_ERROR))
+    {
+        Context->lpSocket->HelperData->WSHNotify(Context->lpSocket->HelperContext,
+                                                 Context->lpSocket->Handle,
+                                                 Context->lpSocket->TdiAddressHandle,
+                                                 Context->lpSocket->TdiConnectionHandle,
+                                                 WSH_NOTIFY_CONNECT_ERROR);
+    }
+    HeapFree(GlobalHeap, 0, ApcContext);
+}
 int
 WSPAPI
 WSPConnect(SOCKET Handle,
@@ -1666,15 +1781,8 @@ WSPConnect(SOCKET Handle,
     PSOCKADDR               BindAddress;
     HANDLE                  SockEvent;
     int                     SocketDataLength;
-
-    Status = NtCreateEvent(&SockEvent,
-                           EVENT_ALL_ACCESS,
-                           NULL,
-                           1,
-                           FALSE);
-
-    if (!NT_SUCCESS(Status))
-        return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
+    PVOID                   APCContext = NULL;
+    PVOID                   APCFunction = NULL;
 
     TRACE("Called\n");
 
@@ -1682,11 +1790,19 @@ WSPConnect(SOCKET Handle,
     Socket = GetSocketStructure(Handle);
     if (!Socket)
     {
-        NtClose(SockEvent);
         if (lpErrno) *lpErrno = WSAENOTSOCK;
         return SOCKET_ERROR;
     }
 
+    Status = NtCreateEvent(&SockEvent,
+                           EVENT_ALL_ACCESS,
+                           NULL,
+                           1,
+                           FALSE);
+
+    if (!NT_SUCCESS(Status))
+        return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
+
     /* Bind us First */
     if (Socket->SharedData->State == SocketOpen)
     {
@@ -1797,14 +1913,22 @@ WSPConnect(SOCKET Handle,
     /* FIXME: Handle Async Connect */
     if (Socket->SharedData->NonBlocking)
     {
-        ERR("Async Connect UNIMPLEMENTED!\n");
+        APCFunction = &AfdConnectAPC; // should be a private io completition function inside us
+        APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDCONNECTAPCCONTEXT));
+        if (!APCContext)
+        {
+            ERR("Not enough memory for APC Context\n");
+            return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+        }
+        ((PAFDCONNECTAPCCONTEXT)APCContext)->lpConnectInfo = ConnectInfo;
+        ((PAFDCONNECTAPCCONTEXT)APCContext)->lpSocket = Socket;
     }
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)Handle,
                                    SockEvent,
-                                   NULL,
-                                   NULL,
+                                   APCFunction,
+                                   APCContext,
                                    &IOSB,
                                    IOCTL_AFD_CONNECT,
                                    ConnectInfo,
@@ -1812,12 +1936,20 @@ WSPConnect(SOCKET Handle,
                                    NULL,
                                    0);
     /* Wait for return */
-    if (Status == STATUS_PENDING)
+    if (Status == STATUS_PENDING && !Socket->SharedData->NonBlocking)
     {
         WaitForSingleObject(SockEvent, INFINITE);
         Status = IOSB.Status;
     }
 
+    if (Status == STATUS_PENDING)
+    {
+        TRACE("Leaving (Pending)\n");
+        return MsafdReturnWithErrno(STATUS_CANT_WAIT, lpErrno, 0, NULL);
+    }
+
+    if (APCContext) HeapFree(GetProcessHeap(), 0, APCContext);
+
     if (Status != STATUS_SUCCESS)
         goto notify;