[MSAFD] Add some parameter checks on send/recv based on wine tests. CORE-12104
[reactos.git] / reactos / dll / win32 / msafd / misc / sndrcv.c
index 1a791d3..e94ab2b 100644 (file)
@@ -94,6 +94,92 @@ WSPAsyncSelect(IN  SOCKET Handle,
 }
 
 
+BOOL
+WSPAPI
+WSPGetOverlappedResult(
+    IN  SOCKET Handle,
+    IN  LPWSAOVERLAPPED lpOverlapped,
+    OUT LPDWORD lpdwBytes,
+    IN  BOOL fWait,
+    OUT LPDWORD lpdwFlags,
+    OUT LPINT lpErrno)
+{
+    PIO_STATUS_BLOCK        IOSB;
+    NTSTATUS                Status;
+    PSOCKET_INFORMATION     Socket;
+
+    TRACE("Called (%x)\n", Handle);
+
+    /* Get the Socket Structure associate to this Socket*/
+    Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+        if(lpErrno)
+            *lpErrno = WSAENOTSOCK;
+        return FALSE;
+    }
+    if (!lpOverlapped || !lpdwBytes || !lpdwFlags)
+    {
+        if (lpErrno)
+            *lpErrno = WSAEFAULT;
+        return FALSE;
+    }
+    IOSB = (PIO_STATUS_BLOCK)&lpOverlapped->Internal;
+    if (!IOSB)
+    {
+        if (lpErrno)
+            *lpErrno = WSAEFAULT;
+        return FALSE;
+    }
+    Status = IOSB->Status;
+
+    /* Wait for completition of overlapped */
+    if (Status == STATUS_PENDING)
+    {
+        /* It's up to the protocol to time out recv.  We must wait
+        * until the protocol decides it's had enough.
+        */
+        if (fWait)
+        {
+            WaitForSingleObject(lpOverlapped->hEvent, INFINITE);
+            Status = IOSB->Status;
+        }
+    }
+
+    TRACE("Status %x Information %d\n", Status, IOSB->Information);
+
+    if (Status != STATUS_PENDING)
+    {
+        *lpdwFlags = 0;
+
+        *lpdwBytes = IOSB->Information;
+
+        /* Re-enable Async Event */
+        SockReenableAsyncSelectEvent(Socket, FD_OOB);
+        SockReenableAsyncSelectEvent(Socket, FD_WRITE);
+        SockReenableAsyncSelectEvent(Socket, FD_READ);
+    }
+
+    return Status == STATUS_SUCCESS;
+}
+
+VOID
+NTAPI
+AfdAPC(PVOID ApcContext,
+       PIO_STATUS_BLOCK IoStatusBlock,
+       ULONG Reserved)
+{
+    PAFDAPCCONTEXT Context = ApcContext;
+
+    /* Re-enable Async Event */
+    SockReenableAsyncSelectEvent(Context->lpSocket, FD_OOB);
+    SockReenableAsyncSelectEvent(Context->lpSocket, FD_READ);
+    SockReenableAsyncSelectEvent(Context->lpSocket, FD_WRITE);
+
+    Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0);
+    HeapFree(GlobalHeap, 0, ApcContext);
+}
+
 int
 WSPAPI
 WSPRecv(SOCKET Handle,
@@ -111,7 +197,7 @@ WSPRecv(SOCKET Handle,
     AFD_RECV_INFO           RecvInfo;
     NTSTATUS                Status;
     PVOID                   APCContext;
-    PVOID                   APCFunction;
+    PIO_APC_ROUTINE         APCFunction;
     HANDLE                  Event = NULL;
     HANDLE                  SockEvent;
     PSOCKET_INFORMATION     Socket;
@@ -122,8 +208,15 @@ WSPRecv(SOCKET Handle,
     Socket = GetSocketStructure(Handle);
     if (!Socket)
     {
-       *lpErrno = WSAENOTSOCK;
-       return SOCKET_ERROR;
+        if (lpErrno)
+            *lpErrno = WSAENOTSOCK;
+        return SOCKET_ERROR;
+    }
+    if (!lpNumberOfBytesRead && !lpOverlapped)
+    {
+        if (lpErrno)
+            *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
     }
 
     Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS,
@@ -173,6 +266,12 @@ WSPRecv(SOCKET Handle,
     }
     else
     {
+        /* Overlapped request for non overlapped opened socket */
+        if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
+        {
+            TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
+            return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesRead);
+        }
         if (lpCompletionRoutine == NULL)
         {
             /* Using Overlapped Structure, but no Completition Routine, so no need for APC */
@@ -183,8 +282,16 @@ WSPRecv(SOCKET Handle,
         else
         {
             /* Using Overlapped Structure and a Completition Routine, so use an APC */
-            APCFunction = NULL; // should be a private io completition function inside us
-            APCContext = lpCompletionRoutine;
+            APCFunction = &AfdAPC; // should be a private io completition function inside us
+            APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
+            if (!APCContext)
+            {
+                ERR("Not enough memory for APC Context\n");
+                return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesRead);
+            }
+            ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
+            ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
+            ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
             RecvInfo.AfdFlags |= AFD_SKIP_FIO;
         }
 
@@ -196,15 +303,15 @@ WSPRecv(SOCKET Handle,
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)Handle,
-        Event,
-        APCFunction,
-        APCContext,
-        IOSB,
-        IOCTL_AFD_RECV,
-        &RecvInfo,
-        sizeof(RecvInfo),
-        NULL,
-        0);
+                                   Event,
+                                   APCFunction,
+                                   APCContext,
+                                   IOSB,
+                                   IOCTL_AFD_RECV,
+                                   &RecvInfo,
+                                   sizeof(RecvInfo),
+                                   NULL,
+                                   0);
 
     /* Wait for completition of not overlapped */
     if (Status == STATUS_PENDING && lpOverlapped == NULL)
@@ -220,6 +327,12 @@ WSPRecv(SOCKET Handle,
 
     TRACE("Status %x Information %d\n", Status, IOSB->Information);
 
+    if (Status == STATUS_PENDING)
+    {
+        TRACE("Leaving (Pending)\n");
+        return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesRead);
+    }
+
     /* Return the Flags */
     *ReceiveFlags = 0;
 
@@ -246,6 +359,12 @@ WSPRecv(SOCKET Handle,
         SockReenableAsyncSelectEvent(Socket, FD_READ);
     }
 
+    if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
+    {
+        lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, *ReceiveFlags);
+        HeapFree(GlobalHeap, 0, (PVOID)APCContext);
+    }
+
     return MsafdReturnWithErrno ( Status, lpErrno, IOSB->Information, lpNumberOfBytesRead );
 }
 
@@ -277,8 +396,15 @@ WSPRecvFrom(SOCKET Handle,
     Socket = GetSocketStructure(Handle);
     if (!Socket)
     {
-       *lpErrno = WSAENOTSOCK;
-       return SOCKET_ERROR;
+        if (lpErrno)
+            *lpErrno = WSAENOTSOCK;
+        return SOCKET_ERROR;
+    }
+    if (!lpNumberOfBytesRead && !lpOverlapped)
+    {
+        if (lpErrno)
+            *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
     }
 
     if (!(Socket->SharedData->ServiceFlags1 & XP1_CONNECTIONLESS))
@@ -355,6 +481,12 @@ WSPRecvFrom(SOCKET Handle,
     }
     else
     {
+        /* Overlapped request for non overlapped opened socket */
+        if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
+        {
+            TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
+            return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesRead);
+        }
         if (lpCompletionRoutine == NULL)
         {
             /* Using Overlapped Structure, but no Completition Routine, so no need for APC */
@@ -365,8 +497,16 @@ WSPRecvFrom(SOCKET Handle,
         else
         {
             /* Using Overlapped Structure and a Completition Routine, so use an APC */
-            APCFunction = NULL; // should be a private io completition function inside us
-            APCContext = lpCompletionRoutine;
+            APCFunction = &AfdAPC; // should be a private io completition function inside us
+            APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
+            if (!APCContext)
+            {
+                ERR("Not enough memory for APC Context\n");
+                return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesRead);
+            }
+            ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
+            ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
+            ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
             RecvInfo.AfdFlags |= AFD_SKIP_FIO;
         }
 
@@ -397,12 +537,19 @@ WSPRecvFrom(SOCKET Handle,
 
     NtClose( SockEvent );
 
+    if (Status == STATUS_PENDING)
+    {
+        TRACE("Leaving (Pending)\n");
+        return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesRead);
+    }
+
     /* Return the Flags */
     *ReceiveFlags = 0;
 
     switch (Status)
     {
-        case STATUS_RECEIVE_EXPEDITED: *ReceiveFlags = MSG_OOB;
+        case STATUS_RECEIVE_EXPEDITED:
+            *ReceiveFlags = MSG_OOB;
             break;
         case STATUS_RECEIVE_PARTIAL_EXPEDITED:
             *ReceiveFlags = MSG_PARTIAL | MSG_OOB;
@@ -422,6 +569,12 @@ WSPRecvFrom(SOCKET Handle,
         SockReenableAsyncSelectEvent(Socket, FD_READ);
     }
 
+    if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
+    {
+        lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, *ReceiveFlags);
+        HeapFree(GlobalHeap, 0, (PVOID)APCContext);
+    }
+
     return MsafdReturnWithErrno ( Status, lpErrno, IOSB->Information, lpNumberOfBytesRead );
 }
 
@@ -452,8 +605,15 @@ WSPSend(SOCKET Handle,
     Socket = GetSocketStructure(Handle);
     if (!Socket)
     {
-       *lpErrno = WSAENOTSOCK;
-       return SOCKET_ERROR;
+        if (lpErrno)
+            *lpErrno = WSAENOTSOCK;
+        return SOCKET_ERROR;
+    }
+    if (!lpNumberOfBytesSent && !lpOverlapped)
+    {
+        if (lpErrno)
+            *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
     }
 
     Status = NtCreateEvent( &SockEvent, EVENT_ALL_ACCESS,
@@ -494,6 +654,12 @@ WSPSend(SOCKET Handle,
     }
     else
     {
+        /* Overlapped request for non overlapped opened socket */
+        if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
+        {
+            TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
+            return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesSent);
+        }
         if (lpCompletionRoutine == NULL)
         {
             /* Using Overlapped Structure, but no Completition Routine, so no need for APC */
@@ -504,8 +670,16 @@ WSPSend(SOCKET Handle,
         else
         {
             /* Using Overlapped Structure and a Completition Routine, so use an APC */
-            APCFunction = NULL; // should be a private io completition function inside us
-            APCContext = lpCompletionRoutine;
+            APCFunction = &AfdAPC; // should be a private io completition function inside us
+            APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
+            if (!APCContext)
+            {
+                ERR("Not enough memory for APC Context\n");
+                return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesSent);
+            }
+            ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
+            ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
+            ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
             SendInfo.AfdFlags |= AFD_SKIP_FIO;
         }
 
@@ -547,6 +721,12 @@ WSPSend(SOCKET Handle,
 
     TRACE("Leaving (Success, %d)\n", IOSB->Information);
 
+    if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
+    {
+        lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, 0);
+        HeapFree(GlobalHeap, 0, (PVOID)APCContext);
+    }
+
     return MsafdReturnWithErrno( Status, lpErrno, IOSB->Information, lpNumberOfBytesSent );
 }
 
@@ -581,8 +761,15 @@ WSPSendTo(SOCKET Handle,
     Socket = GetSocketStructure(Handle);
     if (!Socket)
     {
-       *lpErrno = WSAENOTSOCK;
-       return SOCKET_ERROR;
+        if (lpErrno)
+            *lpErrno = WSAENOTSOCK;
+        return SOCKET_ERROR;
+    }
+    if (!lpNumberOfBytesSent && !lpOverlapped)
+    {
+        if (lpErrno)
+            *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
     }
 
     if (!(Socket->SharedData->ServiceFlags1 & XP1_CONNECTIONLESS))
@@ -666,6 +853,12 @@ WSPSendTo(SOCKET Handle,
     }
     else
     {
+        /* Overlapped request for non overlapped opened socket */
+        if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
+        {
+            TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
+            return MsafdReturnWithErrno(0, lpErrno, 0, lpNumberOfBytesSent);
+        }
         if (lpCompletionRoutine == NULL)
         {
             /* Using Overlapped Structure, but no Completition Routine, so no need for APC */
@@ -676,9 +869,16 @@ WSPSendTo(SOCKET Handle,
         else
         {
             /* Using Overlapped Structure and a Completition Routine, so use an APC */
-            /* Should be a private io completition function inside us */
-            APCFunction = NULL;
-            APCContext = lpCompletionRoutine;
+            APCFunction = &AfdAPC; // should be a private io completition function inside us
+            APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
+            if (!APCContext)
+            {
+                ERR("Not enough memory for APC Context\n");
+                return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, lpNumberOfBytesSent);
+            }
+            ((PAFDAPCCONTEXT)APCContext)->lpCompletionRoutine = lpCompletionRoutine;
+            ((PAFDAPCCONTEXT)APCContext)->lpOverlapped = lpOverlapped;
+            ((PAFDAPCCONTEXT)APCContext)->lpSocket = Socket;
             SendInfo.AfdFlags |= AFD_SKIP_FIO;
         }
 
@@ -713,8 +913,20 @@ WSPSendTo(SOCKET Handle,
         HeapFree(GlobalHeap, 0, BindAddress);
     }
 
+    if (Status == STATUS_PENDING)
+    {
+        TRACE("Leaving (Pending)\n");
+        return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent);
+    }
+
     SockReenableAsyncSelectEvent(Socket, FD_WRITE);
 
+    if (Status == STATUS_SUCCESS && lpOverlapped && lpCompletionRoutine)
+    {
+        lpCompletionRoutine(Status, IOSB->Information, lpOverlapped, 0);
+        HeapFree(GlobalHeap, 0, (PVOID)APCContext);
+    }
+
     return MsafdReturnWithErrno(Status, lpErrno, IOSB->Information, lpNumberOfBytesSent);
 }