[WS2_32_APITEST] Test overlapping recv overread. CORE-18328
authorThomas Faber <thomas.faber@reactos.org>
Sun, 19 Mar 2023 16:39:09 +0000 (12:39 -0400)
committerThomas Faber <thomas.faber@reactos.org>
Sun, 19 Mar 2023 16:39:09 +0000 (12:39 -0400)
modules/rostests/apitests/ws2_32/recv.c

index c64fdbb..80ad202 100644 (file)
@@ -132,6 +132,16 @@ static void Test_Overread(void)
     int ret;
     int error;
     int len;
+    struct
+    {
+        char buffer[32];
+        DWORD flags;
+        WSAOVERLAPPED overlapped;
+    } receivers[4] = { 0 };
+    DWORD bytesTransferred;
+    DWORD flags;
+    WSABUF wsaBuf;
+    size_t i;
 
     ret = WSAStartup(MAKEWORD(2, 2), &wsaData);
     if (ret != 0)
@@ -141,10 +151,9 @@ static void Test_Overread(void)
     }
 
     ListeningSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
-    ClientSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
-    if (ListeningSocket == INVALID_SOCKET || ClientSocket == INVALID_SOCKET)
+    if (ListeningSocket == INVALID_SOCKET)
     {
-        skip("Failed to create sockets, error %d\n", WSAGetLastError());
+        skip("Failed to create listening socket, error %d\n", WSAGetLastError());
         goto Exit;
     }
 
@@ -175,6 +184,16 @@ static void Test_Overread(void)
     }
     ok(len == sizeof(address), "getsocketname length %d\n", len);
 
+    /**************************************************************************
+     * Test 1: non-overlapped client socket
+     *************************************************************************/
+    ClientSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    if (ClientSocket == INVALID_SOCKET)
+    {
+        skip("Failed to create client socket, error %d\n", WSAGetLastError());
+        goto Exit;
+    }
+
     ret = connect(ClientSocket, (SOCKADDR *)&address, sizeof(address));
     if (ret == SOCKET_ERROR)
     {
@@ -202,38 +221,160 @@ static void Test_Overread(void)
 
     memset(buffer, 0, sizeof(buffer));
     ret = recv(ClientSocket, buffer, sizeof(buffer), 0);
-    if (ret == SOCKET_ERROR)
-    {
-        skip("Failed to receive from socket (1), error %d\n", WSAGetLastError());
-        goto Exit;
-    }
     error = WSAGetLastError();
     ok(ret == 4, "recv (1) returned %d\n", ret);
-    ok(error == NO_ERROR, "recv (1) returned error %d\n", ret);
+    ok(error == NO_ERROR, "recv (1) returned error %d\n", error);
     buffer[4] = 0;
     ok(!strcmp(buffer, "blah"), "recv (1) returned data: %s\n", buffer);
 
     ret = recv(ClientSocket, buffer, sizeof(buffer), 0);
+    error = WSAGetLastError();
+    ok(ret == 0, "recv (2) returned %d\n", ret);
+    ok(error == NO_ERROR, "recv (2) returned error %d\n", error);
+
+    ret = recv(ClientSocket, buffer, sizeof(buffer), 0);
+    error = WSAGetLastError();
+    ok(ret == 0, "recv (3) returned %d\n", ret);
+    ok(error == NO_ERROR, "recv (3) returned error %d\n", error);
+
+    closesocket(ClientSocket);
+
+    /**************************************************************************
+     * Test 2: overlapped client socket with multiple pending receives
+     *************************************************************************/
+    ClientSocket = WSASocketW(AF_INET,
+                              SOCK_STREAM,
+                              IPPROTO_TCP,
+                              NULL,
+                              0,
+                              WSA_FLAG_OVERLAPPED);
+    if (ClientSocket == INVALID_SOCKET)
+    {
+        skip("Failed to create overlapped client socket, error %d\n", WSAGetLastError());
+        goto Exit;
+    }
+
+    ret = connect(ClientSocket, (SOCKADDR *)&address, sizeof(address));
     if (ret == SOCKET_ERROR)
     {
-        skip("Failed to receive from socket (2), error %d\n", WSAGetLastError());
+        skip("Failed to connect to socket, error %d\n", WSAGetLastError());
         goto Exit;
     }
-    error = WSAGetLastError();
-    ok(ret == 0, "recv (2) returned %d\n", ret);
-    ok(error == NO_ERROR, "recv (2) returned error %d\n", ret);
 
-    ret = recv(ClientSocket, buffer, sizeof(buffer), 0);
+    ServerSocket = accept(ListeningSocket, NULL, NULL);
+    if (ServerSocket == INVALID_SOCKET)
+    {
+        skip("Failed to accept client socket, error %d\n", WSAGetLastError());
+        goto Exit;
+    }
+
+    /* Start overlapping receive calls */
+    for (i = 0; i < RTL_NUMBER_OF(receivers); i++)
+    {
+        wsaBuf.len = sizeof(receivers[i].buffer);
+        wsaBuf.buf = receivers[i].buffer;
+        receivers[i].flags = 0;
+        receivers[i].overlapped.hEvent = WSACreateEvent();
+        ret = WSARecv(ClientSocket,
+                      &wsaBuf,
+                      1,
+                      NULL,
+                      &receivers[i].flags,
+                      &receivers[i].overlapped,
+                      NULL);
+        error = WSAGetLastError();
+        ok(ret == SOCKET_ERROR, "[%Iu] WSARecv returned %d\n", i, ret);
+        ok(error == WSA_IO_PENDING, "[%Iu] WSARecv returned error %d\n", i, error);
+    }
+
+    /* They should all be pending */
+    for (i = 0; i < RTL_NUMBER_OF(receivers); i++)
+    {
+        ret = WSAGetOverlappedResult(ClientSocket,
+                                     &receivers[i].overlapped,
+                                     &bytesTransferred,
+                                     FALSE,
+                                     &flags);
+        error = WSAGetLastError();
+        ok(ret == FALSE, "[%Iu] WSAGetOverlappedResult returned %d\n", i, ret);
+        ok(error == WSA_IO_INCOMPLETE, "[%Iu] WSAGetOverlappedResult returned error %d\n", i, error);
+    }
+
+    /* Sending some data should complete the first receive */
+    ret = send(ServerSocket, "blah", 4, 0);
     if (ret == SOCKET_ERROR)
     {
-        skip("Failed to receive from socket (3), error %d\n", WSAGetLastError());
+        skip("Failed to send to socket, error %d\n", WSAGetLastError());
         goto Exit;
     }
+
+    flags = 0x55555555;
+    bytesTransferred = 0x55555555;
+    ret = WSAGetOverlappedResult(ClientSocket,
+                                 &receivers[0].overlapped,
+                                 &bytesTransferred,
+                                 FALSE,
+                                 &flags);
     error = WSAGetLastError();
-    ok(ret == 0, "recv (3) returned %d\n", ret);
-    ok(error == NO_ERROR, "recv (3) returned error %d\n", ret);
+    ok(ret == TRUE, "WSAGetOverlappedResult returned %d\n", ret);
+    ok(flags == 0, "WSAGetOverlappedResult returned flags 0x%lx\n", flags);
+    ok(bytesTransferred == 4, "WSAGetOverlappedResult returned %lu bytes\n", bytesTransferred);
+    receivers[0].buffer[4] = 0;
+    ok(!strcmp(receivers[0].buffer, "blah"), "WSARecv returned data: %s\n", receivers[0].buffer);
+
+    /* Others should still be in progress */
+    for (i = 1; i < RTL_NUMBER_OF(receivers); i++)
+    {
+        ret = WSAGetOverlappedResult(ClientSocket,
+                                     &receivers[i].overlapped,
+                                     &bytesTransferred,
+                                     FALSE,
+                                     &flags);
+        error = WSAGetLastError();
+        ok(ret == FALSE, "[%Iu] WSAGetOverlappedResult returned %d\n", i, ret);
+        ok(error == WSA_IO_INCOMPLETE, "[%Iu] WSAGetOverlappedResult returned error %d\n", i, error);
+    }
+
+    /* Closing the server end should make all receives complete */
+    ret = closesocket(ServerSocket);
+    ServerSocket = INVALID_SOCKET;
+    ok(ret == 0, "Failed to close socket with %d\n", WSAGetLastError());
+
+    for (i = 1; i < RTL_NUMBER_OF(receivers); i++)
+    {
+        flags = 0x55555555;
+        bytesTransferred = 0x55555555;
+        ret = WSAGetOverlappedResult(ClientSocket,
+                                     &receivers[i].overlapped,
+                                     &bytesTransferred,
+                                     FALSE,
+                                     &flags);
+        error = WSAGetLastError();
+        ok(ret == TRUE, "[%Iu] WSAGetOverlappedResult returned %d\n", i, ret);
+        ok(flags == 0, "[%Iu] WSAGetOverlappedResult returned flags 0x%lx\n", i, flags);
+        ok(bytesTransferred == 0, "[%Iu] WSAGetOverlappedResult returned %lu bytes\n", i, bytesTransferred);
+    }
+
+    /* Start two more receives -- they should immediately return success */
+    ret = recv(ClientSocket, receivers[0].buffer, sizeof(receivers[0].buffer), 0);
+    error = WSAGetLastError();
+    ok(ret == 0, "recv (N+1) returned %d\n", ret);
+    ok(error == NO_ERROR, "recv (N+1) returned error %d\n", error);
+
+    ret = recv(ClientSocket, receivers[0].buffer, sizeof(receivers[0].buffer), 0);
+    error = WSAGetLastError();
+    ok(ret == 0, "recv (N+2) returned %d\n", ret);
+    ok(error == NO_ERROR, "recv (N+2) returned error %d\n", error);
 
 Exit:
+    for (i = 0; i < RTL_NUMBER_OF(receivers); i++)
+    {
+        if (receivers[i].overlapped.hEvent != NULL)
+        {
+            WSACloseEvent(receivers[i].overlapped.hEvent);
+        }
+    }
+
     if (ListeningSocket != INVALID_SOCKET)
     {
         ret = closesocket(ListeningSocket);