[MSAFD] Don't update shared state on close if we still have active references to...
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
index 00190a8..747299e 100644 (file)
@@ -79,16 +79,25 @@ WSPSocket(int AddressFamily,
     {
         /* Duplpicating socket from different process */
         if ((HANDLE)lpProtocolInfo->dwServiceFlags3 == INVALID_HANDLE_VALUE)
-            return WSAEINVAL;
+        {
+            Status = WSAEINVAL;
+            goto error;
+        }
         if ((HANDLE)lpProtocolInfo->dwServiceFlags4 == INVALID_HANDLE_VALUE)
-            return WSAEINVAL;
+        {
+            Status = WSAEINVAL;
+            goto error;
+        }
         SharedData = MapViewOfFile((HANDLE)lpProtocolInfo->dwServiceFlags3,
                                    FILE_MAP_ALL_ACCESS,
                                    0,
                                    0,
                                    sizeof(SOCK_SHARED_INFO));
         if (!SharedData)
-            return WSAEINVAL;
+        {
+            Status = WSAEINVAL;
+            goto error;
+        }
         InterlockedIncrement(&SharedData->RefCount);
         AddressFamily = SharedData->AddressFamily;
         SocketType = SharedData->SocketType;
@@ -96,7 +105,10 @@ WSPSocket(int AddressFamily,
     }
 
     if (AddressFamily == AF_UNSPEC && SocketType == 0 && Protocol == 0)
-        return WSAEINVAL;
+    {
+        Status = WSAEINVAL;
+        goto error;
+    }
 
     /* Set the defaults */
     if (AddressFamily == AF_UNSPEC)
@@ -167,7 +179,7 @@ WSPSocket(int AddressFamily,
     Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
     if (!Socket)
     {
-        Status = STATUS_INSUFFICIENT_RESOURCES;
+        Status = WSAENOBUFS;
         goto error;
     }
     RtlZeroMemory(Socket, sizeof(*Socket));
@@ -184,7 +196,7 @@ WSPSocket(int AddressFamily,
         Socket->SharedData = HeapAlloc(GlobalHeap, 0, sizeof(*Socket->SharedData));
         if (!Socket->SharedData)
         {
-            Status = STATUS_INSUFFICIENT_RESOURCES;
+            Status = WSAENOBUFS;
             goto error;
         }
         RtlZeroMemory(Socket->SharedData, sizeof(*Socket->SharedData));
@@ -236,7 +248,7 @@ WSPSocket(int AddressFamily,
     EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
     if (!EABuffer)
     {
-        Status = STATUS_INSUFFICIENT_RESOURCES;
+        Status = WSAENOBUFS;
         goto error;
     }
 
@@ -263,6 +275,7 @@ WSPSocket(int AddressFamily,
         if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW))
         {
             /* Only RAW or UDP can be Connectionless */
+            Status = WSAEINVAL;
             goto error;
         }
         AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
@@ -275,6 +288,7 @@ WSPSocket(int AddressFamily,
             if ((Socket->SharedData->ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
             {
                 /* The Provider doesn't actually support Message Oriented Streams */
+                Status = WSAEINVAL;
                 goto error;
             }
         }
@@ -291,6 +305,7 @@ WSPSocket(int AddressFamily,
         if ((Socket->SharedData->ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
         {
             /* The Provider doesn't actually support Multipoint */
+            Status = WSAEINVAL;
             goto error;
         }
         AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
@@ -301,6 +316,7 @@ WSPSocket(int AddressFamily,
                 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0))
             {
                 /* The Provider doesn't support Control Planes, or you already gave a leaf */
+                Status = WSAEINVAL;
                 goto error;
             }
             AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
@@ -312,6 +328,7 @@ WSPSocket(int AddressFamily,
                 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0))
             {
                 /* The Provider doesn't support Data Planes, or you already gave a leaf */
+                Status = WSAEINVAL;
                 goto error;
             }
             AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
@@ -346,6 +363,7 @@ WSPSocket(int AddressFamily,
     if (!NT_SUCCESS(Status))
     {
         ERR("Failed to open socket. Status 0x%08x\n", Status);
+        Status = TranslateNtStatusError(Status);
         goto error;
     }
 
@@ -659,15 +677,16 @@ WSPCloseSocket(IN SOCKET Handle,
         if (lpErrno) *lpErrno = WSAENOTSOCK;
         return SOCKET_ERROR;
     }
-    /* Set the state to close */
-    OldState = Socket->SharedData->State;
-    Socket->SharedData->State = SocketClosed;
 
     /* Decrement reference count on SharedData */
     References = InterlockedDecrement(&Socket->SharedData->RefCount);
     if (References)
         goto ok;
 
+    /* Set the state to close */
+    OldState = Socket->SharedData->State;
+    Socket->SharedData->State = SocketClosed;
+
     /* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
     /* FIXME: Should we do this on Datagram Sockets too? */
     if ((OldState == SocketConnected) && (Socket->SharedData->LingerData.l_onoff))
@@ -706,7 +725,7 @@ WSPCloseSocket(IN SOCKET Handle,
             /*
              * We have to execute a sleep, so it's kind of like
              * a block. If the socket is Nonblock, we cannot
-             * go on since asyncronous operation is expected
+             * go on since asynchronous operation is expected
              * and we cannot offer it
              */
             if (Socket->SharedData->NonBlocking)
@@ -842,6 +861,27 @@ WSPBind(SOCKET Handle,
        if (lpErrno) *lpErrno = WSAENOTSOCK;
        return SOCKET_ERROR;
     }
+    if (Socket->SharedData->State != SocketOpen)
+    {
+       if (lpErrno) *lpErrno = WSAEINVAL;
+       return SOCKET_ERROR;
+    }
+    if (!SocketAddress || SocketAddressLength < Socket->SharedData->SizeOfLocalAddress)
+    {
+        if (lpErrno) *lpErrno = WSAEINVAL;
+        return SOCKET_ERROR;
+    }
+
+    /* Get Address Information */
+    Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
+                                            SocketAddressLength,
+                                            &SocketInfo);
+
+    if (SocketInfo.AddressInfo == SockaddrAddressInfoBroadcast && !Socket->SharedData->Broadcast)
+    {
+       if (lpErrno) *lpErrno = WSAEADDRNOTAVAIL;
+       return SOCKET_ERROR;
+    }
 
     Status = NtCreateEvent(&SockEvent,
                            EVENT_ALL_ACCESS,
@@ -869,11 +909,6 @@ WSPBind(SOCKET Handle,
                    SocketAddress->sa_data,
                    SocketAddressLength - sizeof(SocketAddress->sa_family));
 
-    /* Get Address Information */
-    Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
-                                            SocketAddressLength,
-                                            &SocketInfo);
-
     /* Set the Share Type */
     if (Socket->SharedData->ExclusiveAddressUse)
     {
@@ -1045,13 +1080,34 @@ WSPSelect(IN int nfds,
     PSOCKET_INFORMATION Socket;
     SOCKET              Handle;
     ULONG               Events;
+    fd_set              selectfds;
 
     /* Find out how many sockets we have, and how large the buffer needs
      * to be */
+    FD_ZERO(&selectfds);
+    if (readfds != NULL)
+    {
+        for (i = 0; i < readfds->fd_count; i++)
+        {
+            FD_SET(readfds->fd_array[i], &selectfds);
+        }
+    }
+    if (writefds != NULL)
+    {
+        for (i = 0; i < writefds->fd_count; i++)
+        {
+            FD_SET(writefds->fd_array[i], &selectfds);
+        }
+    }
+    if (exceptfds != NULL)
+    {
+        for (i = 0; i < exceptfds->fd_count; i++)
+        {
+            FD_SET(exceptfds->fd_array[i], &selectfds);
+        }
+    }
 
-    HandleCount = ( readfds ? readfds->fd_count : 0 ) +
-                  ( writefds ? writefds->fd_count : 0 ) +
-                  ( exceptfds ? exceptfds->fd_count : 0 );
+    HandleCount = selectfds.fd_count;
 
     if ( HandleCount == 0 )
     {
@@ -1082,7 +1138,7 @@ WSPSelect(IN int nfds,
         if (Timeout.QuadPart > 0)
         {
             if (lpErrno) *lpErrno = WSAEINVAL;
-                return SOCKET_ERROR;
+            return SOCKET_ERROR;
         }
         TRACE("Timeout: Orig %d.%06d kernel %d\n",
                      timeout->tv_sec, timeout->tv_usec,
@@ -1123,9 +1179,26 @@ WSPSelect(IN int nfds,
     PollInfo->Exclusive = FALSE;
     PollInfo->Timeout = Timeout;
 
+    for (i = 0; i < selectfds.fd_count; i++)
+    {
+        PollInfo->Handles[i].Handle = selectfds.fd_array[i];
+    }
     if (readfds != NULL) {
-        for (i = 0; i < readfds->fd_count; i++, j++)
+        for (i = 0; i < readfds->fd_count; i++)
         {
+            for (j = 0; j < HandleCount; j++)
+            {
+                if (PollInfo->Handles[j].Handle == readfds->fd_array[i])
+                    break;
+            }
+            if (j >= HandleCount)
+            {
+                ERR("Error while counting readfds %ld > %ld\n", j, HandleCount);
+                if (lpErrno) *lpErrno = WSAEFAULT;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             Socket = GetSocketStructure(readfds->fd_array[i]);
             if (!Socket)
             {
@@ -1135,20 +1208,32 @@ WSPSelect(IN int nfds,
                 NtClose(SockEvent);
                 return SOCKET_ERROR;
             }
-            PollInfo->Handles[j].Handle = readfds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
-                                          AFD_EVENT_DISCONNECT |
-                                          AFD_EVENT_ABORT |
-                                          AFD_EVENT_CLOSE |
-                                          AFD_EVENT_ACCEPT;
-            if (Socket->SharedData->OobInline != 0)
-                PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
+            PollInfo->Handles[j].Events |= AFD_EVENT_RECEIVE |
+                                           AFD_EVENT_DISCONNECT |
+                                           AFD_EVENT_ABORT |
+                                           AFD_EVENT_CLOSE |
+                                           AFD_EVENT_ACCEPT;
+            //if (Socket->SharedData->OobInline != 0)
+            //    PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
         }
     }
     if (writefds != NULL)
     {
-        for (i = 0; i < writefds->fd_count; i++, j++)
+        for (i = 0; i < writefds->fd_count; i++)
         {
+            for (j = 0; j < HandleCount; j++)
+            {
+                if (PollInfo->Handles[j].Handle == writefds->fd_array[i])
+                    break;
+            }
+            if (j >= HandleCount)
+            {
+                ERR("Error while counting writefds %ld > %ld\n", j, HandleCount);
+                if (lpErrno) *lpErrno = WSAEFAULT;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             Socket = GetSocketStructure(writefds->fd_array[i]);
             if (!Socket)
             {
@@ -1159,15 +1244,28 @@ WSPSelect(IN int nfds,
                 return SOCKET_ERROR;
             }
             PollInfo->Handles[j].Handle = writefds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_SEND;
+            PollInfo->Handles[j].Events |= AFD_EVENT_SEND;
             if (Socket->SharedData->NonBlocking != 0)
                 PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT;
         }
     }
     if (exceptfds != NULL)
     {
-        for (i = 0; i < exceptfds->fd_count; i++, j++)
+        for (i = 0; i < exceptfds->fd_count; i++)
         {
+            for (j = 0; j < HandleCount; j++)
+            {
+                if (PollInfo->Handles[j].Handle == exceptfds->fd_array[i])
+                    break;
+            }
+            if (j > HandleCount)
+            {
+                ERR("Error while counting exceptfds %ld > %ld\n", j, HandleCount);
+                if (lpErrno) *lpErrno = WSAEFAULT;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             Socket = GetSocketStructure(exceptfds->fd_array[i]);
             if (!Socket)
             {
@@ -1178,20 +1276,14 @@ WSPSelect(IN int nfds,
                 return SOCKET_ERROR;
             }
             PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
-            PollInfo->Handles[j].Events = 0;
             if (Socket->SharedData->OobInline == 0)
                 PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
             if (Socket->SharedData->NonBlocking != 0)
                 PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT_FAIL;
-            if (PollInfo->Handles[j].Events == 0)
-            {
-                TRACE("No events can be checked for exceptfds %d. It is nonblocking and OOB line is disabled. Skipping it.", exceptfds->fd_array[i]);
-                j--;
-            }
         }
     }
 
-    PollInfo->HandleCount = j;
+    PollInfo->HandleCount = HandleCount;
     PollBufferSize = FIELD_OFFSET(AFD_POLL_INFO, Handles) + PollInfo->HandleCount * sizeof(AFD_HANDLE);
 
     /* Send IOCTL */
@@ -1208,7 +1300,7 @@ WSPSelect(IN int nfds,
 
     TRACE("DeviceIoControlFile => %x\n", Status);
 
-    /* Wait for Completition */
+    /* Wait for Completion */
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
@@ -1321,18 +1413,18 @@ WSPSelect(IN int nfds,
 }
 
 DWORD
-GetCurrentTimeInSeconds()
+GetCurrentTimeInSeconds(VOID)
 {
-    FILETIME Time;
-    FILETIME Adjustment;
-    ULARGE_INTEGER lTime, lAdj;
-    SYSTEMTIME st = { 1970,1,0,1,0,0,0 };
-    SystemTimeToFileTime(&st, &Adjustment);
-    memcpy(&lAdj, &Adjustment, sizeof(lAdj));
-    GetSystemTimeAsFileTime(&Time);
-    memcpy(&lTime, &Time, sizeof(lTime));
-    lTime.QuadPart -= lAdj.QuadPart;
-    return (DWORD)(lTime.QuadPart / 10000000LLU);
+    SYSTEMTIME st1970 = { 1970, 1, 0, 1, 0, 0, 0, 0 };
+    union
+    {
+        FILETIME ft;
+        ULONGLONG ll;
+    } u1970, Time;
+
+    GetSystemTimeAsFileTime(&Time.ft);
+    SystemTimeToFileTime(&st1970, &u1970.ft);
+    return (DWORD)((Time.ll - u1970.ll) / 10000000ULL);
 }
 
 SOCKET
@@ -2056,7 +2148,7 @@ WSPGetSockName(IN SOCKET Handle,
 {
     IO_STATUS_BLOCK         IOSB;
     ULONG                   TdiAddressSize;
-       PTDI_ADDRESS_INFO       TdiAddress;
+    PTDI_ADDRESS_INFO       TdiAddress;
     PTRANSPORT_ADDRESS      SocketAddress;
     PSOCKET_INFORMATION     Socket = NULL;
     NTSTATUS                Status;
@@ -2087,7 +2179,7 @@ WSPGetSockName(IN SOCKET Handle,
 
     /* Allocate a buffer for the address */
     TdiAddressSize =
-               sizeof(TRANSPORT_ADDRESS) + Socket->SharedData->SizeOfLocalAddress;
+        sizeof(TRANSPORT_ADDRESS) + Socket->SharedData->SizeOfLocalAddress;
     TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
 
     if ( TdiAddress == NULL )
@@ -3369,7 +3461,7 @@ int CreateContext(PSOCKET_INFORMATION Socket)
                                    NULL,
                                    0);
 
-    /* Wait for Completition */
+    /* Wait for Completion */
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
@@ -3482,7 +3574,7 @@ BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
     FILE_COMPLETION_INFORMATION CompletionInfo;
     OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
 
-    /* First, make sure we're not already intialized */
+    /* First, make sure we're not already initialized */
     if (SockAsyncHelperAfdHandle)
     {
         return TRUE;
@@ -3567,7 +3659,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
                 if (0 != (Socket->SharedData->AsyncEvents & FD_READ) &&
                     0 == (Socket->SharedData->AsyncDisabledEvents & FD_READ))
                 {
-                    /* Make the Notifcation */
+                    /* Make the Notification */
                     (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
                                                Socket->SharedData->wMsg,
                                                Socket->Handle,
@@ -3581,7 +3673,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
             if (0 != (Socket->SharedData->AsyncEvents & FD_OOB) &&
                 0 == (Socket->SharedData->AsyncDisabledEvents & FD_OOB))
             {
-                /* Make the Notifcation */
+                /* Make the Notification */
                 (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
                                            Socket->SharedData->wMsg,
                                            Socket->Handle,
@@ -3595,7 +3687,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
                 if (0 != (Socket->SharedData->AsyncEvents & FD_WRITE) &&
                     0 == (Socket->SharedData->AsyncDisabledEvents & FD_WRITE))
                 {
-                    /* Make the Notifcation */
+                    /* Make the Notification */
                     (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
                                                Socket->SharedData->wMsg,
                                                Socket->Handle,
@@ -3611,7 +3703,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
                 if (0 != (Socket->SharedData->AsyncEvents & FD_CONNECT) &&
                     0 == (Socket->SharedData->AsyncDisabledEvents & FD_CONNECT))
                 {
-                    /* Make the Notifcation */
+                    /* Make the Notification */
                     (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
                         Socket->SharedData->wMsg,
                         Socket->Handle,
@@ -3625,7 +3717,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
                 if (0 != (Socket->SharedData->AsyncEvents & FD_ACCEPT) &&
                     0 == (Socket->SharedData->AsyncDisabledEvents & FD_ACCEPT))
                 {
-                    /* Make the Notifcation */
+                    /* Make the Notification */
                     (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
                                                Socket->SharedData->wMsg,
                                                Socket->Handle,
@@ -3641,7 +3733,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
                 if (0 != (Socket->SharedData->AsyncEvents & FD_CLOSE) &&
                     0 == (Socket->SharedData->AsyncDisabledEvents & FD_CLOSE))
                 {
-                    /* Make the Notifcation */
+                    /* Make the Notification */
                     (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
                                                Socket->SharedData->wMsg,
                                                Socket->Handle,
@@ -3763,7 +3855,7 @@ VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
         /* Check if the Sequence Number changed by now, in which case quit */
         if (AsyncData->SequenceNumber == Socket->SharedData->SequenceNumber)
         {
-            /* Do the actuall select, if needed */
+            /* Do the actual select, if needed */
             if ((Socket->SharedData->AsyncEvents & (~Socket->SharedData->AsyncDisabledEvents)))
             {
                 SockProcessAsyncSelect(Socket, AsyncData);
@@ -3877,5 +3969,3 @@ DllMain(HANDLE hInstDll,
 }
 
 /* EOF */
-
-