[MSAFD] Fix WSPSelect heap corruption and don't repeat sockets if they are in differe...
authorPeter Hater <7element@mail.bg>
Thu, 17 Nov 2016 12:57:32 +0000 (12:57 +0000)
committerPeter Hater <7element@mail.bg>
Thu, 17 Nov 2016 12:57:32 +0000 (12:57 +0000)
CORE-12324

svn path=/trunk/; revision=73241

reactos/dll/win32/msafd/misc/dllmain.c

index 7275492..920ab53 100644 (file)
@@ -1045,13 +1045,34 @@ WSPSelect(IN int nfds,
     PSOCKET_INFORMATION Socket;
     SOCKET              Handle;
     ULONG               Events;
+    fd_set              selectfds;
 
     /* Find out how many sockets we have, and how large the buffer needs
      * to be */
+    FD_ZERO(&selectfds);
+    if (readfds != NULL)
+    {
+        for (i = 0; i < readfds->fd_count; i++)
+        {
+            FD_SET(readfds->fd_array[i], &selectfds);
+        }
+    }
+    if (writefds != NULL)
+    {
+        for (i = 0; i < writefds->fd_count; i++)
+        {
+            FD_SET(writefds->fd_array[i], &selectfds);
+        }
+    }
+    if (exceptfds != NULL)
+    {
+        for (i = 0; i < exceptfds->fd_count; i++)
+        {
+            FD_SET(exceptfds->fd_array[i], &selectfds);
+        }
+    }
 
-    HandleCount = ( readfds ? readfds->fd_count : 0 ) +
-                  ( writefds ? writefds->fd_count : 0 ) +
-                  ( exceptfds ? exceptfds->fd_count : 0 );
+    HandleCount = selectfds.fd_count;
 
     if ( HandleCount == 0 )
     {
@@ -1082,7 +1103,7 @@ WSPSelect(IN int nfds,
         if (Timeout.QuadPart > 0)
         {
             if (lpErrno) *lpErrno = WSAEINVAL;
-                return SOCKET_ERROR;
+            return SOCKET_ERROR;
         }
         TRACE("Timeout: Orig %d.%06d kernel %d\n",
                      timeout->tv_sec, timeout->tv_usec,
@@ -1123,9 +1144,26 @@ WSPSelect(IN int nfds,
     PollInfo->Exclusive = FALSE;
     PollInfo->Timeout = Timeout;
 
+    for (i = 0; i < selectfds.fd_count; i++)
+    {
+        PollInfo->Handles[i].Handle = selectfds.fd_array[i];
+    }
     if (readfds != NULL) {
-        for (i = 0; i < readfds->fd_count; i++, j++)
+        for (i = 0; i < readfds->fd_count; i++)
         {
+            for (j = 0; j < HandleCount; j++)
+            {
+                if (PollInfo->Handles[j].Handle == readfds->fd_array[i])
+                    break;
+            }
+            if (j >= HandleCount)
+            {
+                ERR("Error while counting readfds %ld > %ld\n", j, HandleCount);
+                if (lpErrno) *lpErrno = WSAEFAULT;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             Socket = GetSocketStructure(readfds->fd_array[i]);
             if (!Socket)
             {
@@ -1135,20 +1173,32 @@ WSPSelect(IN int nfds,
                 NtClose(SockEvent);
                 return SOCKET_ERROR;
             }
-            PollInfo->Handles[j].Handle = readfds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
-                                          AFD_EVENT_DISCONNECT |
-                                          AFD_EVENT_ABORT |
-                                          AFD_EVENT_CLOSE |
-                                          AFD_EVENT_ACCEPT;
-            if (Socket->SharedData->OobInline != 0)
-                PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
+            PollInfo->Handles[j].Events |= AFD_EVENT_RECEIVE |
+                                           AFD_EVENT_DISCONNECT |
+                                           AFD_EVENT_ABORT |
+                                           AFD_EVENT_CLOSE |
+                                           AFD_EVENT_ACCEPT;
+            //if (Socket->SharedData->OobInline != 0)
+            //    PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
         }
     }
     if (writefds != NULL)
     {
-        for (i = 0; i < writefds->fd_count; i++, j++)
+        for (i = 0; i < writefds->fd_count; i++)
         {
+            for (j = 0; j < HandleCount; j++)
+            {
+                if (PollInfo->Handles[j].Handle == writefds->fd_array[i])
+                    break;
+            }
+            if (j >= HandleCount)
+            {
+                ERR("Error while counting writefds %ld > %ld\n", j, HandleCount);
+                if (lpErrno) *lpErrno = WSAEFAULT;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             Socket = GetSocketStructure(writefds->fd_array[i]);
             if (!Socket)
             {
@@ -1159,15 +1209,28 @@ WSPSelect(IN int nfds,
                 return SOCKET_ERROR;
             }
             PollInfo->Handles[j].Handle = writefds->fd_array[i];
-            PollInfo->Handles[j].Events = AFD_EVENT_SEND;
+            PollInfo->Handles[j].Events |= AFD_EVENT_SEND;
             if (Socket->SharedData->NonBlocking != 0)
                 PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT;
         }
     }
     if (exceptfds != NULL)
     {
-        for (i = 0; i < exceptfds->fd_count; i++, j++)
+        for (i = 0; i < exceptfds->fd_count; i++)
         {
+            for (j = 0; j < HandleCount; j++)
+            {
+                if (PollInfo->Handles[j].Handle == exceptfds->fd_array[i])
+                    break;
+            }
+            if (j > HandleCount)
+            {
+                ERR("Error while counting exceptfds %ld > %ld\n", j, HandleCount);
+                if (lpErrno) *lpErrno = WSAEFAULT;
+                HeapFree(GlobalHeap, 0, PollBuffer);
+                NtClose(SockEvent);
+                return SOCKET_ERROR;
+            }
             Socket = GetSocketStructure(exceptfds->fd_array[i]);
             if (!Socket)
             {
@@ -1178,20 +1241,14 @@ WSPSelect(IN int nfds,
                 return SOCKET_ERROR;
             }
             PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
-            PollInfo->Handles[j].Events = 0;
             if (Socket->SharedData->OobInline == 0)
                 PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
             if (Socket->SharedData->NonBlocking != 0)
                 PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT_FAIL;
-            if (PollInfo->Handles[j].Events == 0)
-            {
-                TRACE("No events can be checked for exceptfds %d. It is nonblocking and OOB line is disabled. Skipping it.", exceptfds->fd_array[i]);
-                j--;
-            }
         }
     }
 
-    PollInfo->HandleCount = j;
+    PollInfo->HandleCount = HandleCount;
     PollBufferSize = FIELD_OFFSET(AFD_POLL_INFO, Handles) + PollInfo->HandleCount * sizeof(AFD_HANDLE);
 
     /* Send IOCTL */