- Put some code back which was removed in r43270
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
index d5b2bb7..53f03b1 100644 (file)
@@ -14,7 +14,7 @@
 
 #include <debug.h>
 
-#ifdef DBG
+#if DBG
 //DWORD DebugTraceLevel = DEBUG_ULTRA;
 DWORD DebugTraceLevel = 0;
 #endif /* DBG */
@@ -98,6 +98,9 @@ WSPSocket(int AddressFamily,
 
     /* Set Socket Data */
     Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
+    if (!Socket)
+        return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+
     RtlZeroMemory(Socket, sizeof(*Socket));
     Socket->RefCount = 2;
     Socket->Handle = -1;
@@ -140,6 +143,9 @@ WSPSocket(int AddressFamily,
 
     /* Set up EA Buffer */
     EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
+    if (!EABuffer)
+        return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+
     RtlZeroMemory(EABuffer, SizeOfEA);
     EABuffer->NextEntryOffset = 0;
     EABuffer->Flags = 0;
@@ -229,16 +235,28 @@ WSPSocket(int AddressFamily,
     ourselves after every call to NtDeviceIoControlFile. This is
     because the kernel doesn't support overlapping synchronous I/O
     requests (made from multiple threads) at this time (Sep 2005) */
-    ZwCreateFile(&Sock,
-                 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
-                 &Object,
-                 &IOSB,
-                 NULL,
-                 0,
-                 FILE_SHARE_READ | FILE_SHARE_WRITE, FILE_OPEN_IF,
-                 0,
-                 EABuffer,
-                 SizeOfEA);
+    Status = NtCreateFile(&Sock,
+                          GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
+                          &Object,
+                          &IOSB,
+                          NULL,
+                          0,
+                          FILE_SHARE_READ | FILE_SHARE_WRITE,
+                          FILE_OPEN_IF,
+                          0,
+                          EABuffer,
+                          SizeOfEA);
+
+    HeapFree(GlobalHeap, 0, EABuffer);
+
+    if (Status != STATUS_SUCCESS)
+    {
+        AFD_DbgPrint(MIN_TRACE, ("Failed to open socket\n"));
+
+        HeapFree(GlobalHeap, 0, Socket);
+
+        return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
+    }
 
     /* Save Handle */
     Socket->Handle = (SOCKET)Sock;
@@ -291,6 +309,9 @@ WSPSocket(int AddressFamily,
 error:
     AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
 
+    if( Socket )
+        HeapFree(GlobalHeap, 0, Socket);
+
     if( lpErrno )
         *lpErrno = Status;
 
@@ -335,12 +356,16 @@ DWORD MsafdReturnWithErrno(NTSTATUS Status,
         case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
         case STATUS_INSUFFICIENT_RESOURCES:
             DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
-            *Errno = WSA_NOT_ENOUGH_MEMORY;
+            *Errno = WSAENOBUFS;
             break;
         case STATUS_INVALID_CONNECTION:
             DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
             *Errno = WSAEAFNOSUPPORT;
             break;
+        case STATUS_INVALID_ADDRESS:
+            DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
+            *Errno = WSAEADDRNOTAVAIL;
+            break;
         case STATUS_REMOTE_NOT_LISTENING:
             DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
             *Errno = WSAECONNREFUSED;
@@ -679,7 +704,7 @@ WSPSelect(int nfds,
           fd_set *readfds,
           fd_set *writefds,
           fd_set *exceptfds,
-          struct timeval *timeout,
+          const LPTIMEVAL timeout,
           LPINT lpErrno)
 {
     IO_STATUS_BLOCK     IOSB;
@@ -700,10 +725,15 @@ WSPSelect(int nfds,
                   ( writefds ? writefds->fd_count : 0 ) +
                   ( exceptfds ? exceptfds->fd_count : 0 );
 
-    if( HandleCount < 0 || nfds != 0 )
-        HandleCount = nfds * 3;
+    if ( HandleCount == 0 )
+    {
+        AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
+                     HandleCount));
+        if (lpErrno) *lpErrno = WSAEINVAL;
+        return SOCKET_ERROR;
+    }
 
-    PollBufferSize = sizeof(*PollInfo) + (HandleCount * sizeof(AFD_HANDLE));
+    PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
 
     AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n", 
                  HandleCount, PollBufferSize));
@@ -768,6 +798,7 @@ WSPSelect(int nfds,
             PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
                                           AFD_EVENT_DISCONNECT |
                                           AFD_EVENT_ABORT |
+                                          AFD_EVENT_CLOSE |
                                           AFD_EVENT_ACCEPT;
         }
     }
@@ -905,7 +936,7 @@ WSPAccept(SOCKET Handle,
           struct sockaddr *SocketAddress,
           int *SocketAddressLength,
           LPCONDITIONPROC lpfnCondition,
-          DWORD_PTR dwCallbackData,
+          DWORD dwCallbackData,
           LPINT lpErrno)
 {
     IO_STATUS_BLOCK             IOSB;
@@ -1028,6 +1059,11 @@ WSPAccept(SOCKET Handle,
             {
                 /* Allocate needed space */
                 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
+                if (!PendingData)
+                {
+                    MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
+                    return INVALID_SOCKET;
+                }
 
                 /* We want the data now */
                 PendingAcceptData.ReturnSize = FALSE;
@@ -1069,6 +1105,13 @@ WSPAccept(SOCKET Handle,
         CalleeID.buf = (PVOID)Socket->LocalAddress;
         CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
 
+        RemoteAddress = HeapAlloc(GlobalHeap, 0, sizeof(*RemoteAddress));
+        if (!RemoteAddress)
+        {
+            MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+            return INVALID_SOCKET;
+        }
+
         /* Set up Address in SOCKADDR Format */
         RtlCopyMemory (RemoteAddress, 
                        &ListenReceiveData->Address.Address[0].AddressType, 
@@ -1087,6 +1130,10 @@ WSPAccept(SOCKET Handle,
         {
             /* Allocate Buffer for Callee Data */
             CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
+            if (!CalleeDataBuffer) {
+                MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
+                return INVALID_SOCKET;
+            }
             CalleeData.buf = CalleeDataBuffer;
             CalleeData.len = 4096;
         } 
@@ -1278,6 +1325,11 @@ WSPConnect(SOCKET Handle,
         /* Get the Wildcard Address */
         BindAddressLength = Socket->HelperData->MaxWSAddressLength;
         BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
+        if (!BindAddress)
+        {
+            MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
+            return INVALID_SOCKET;
+        }
         Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext, 
                                                     BindAddress, 
                                                     &BindAddressLength);
@@ -1683,12 +1735,19 @@ WSPIoctl(IN  SOCKET Handle,
     switch( dwIoControlCode )
     {
         case FIONBIO:
-            if( cbInBuffer < sizeof(INT) )
+            if( cbInBuffer < sizeof(INT) || IS_INTRESOURCE(lpvInBuffer) )
+            {
+                *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
-            Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
-            AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n", Handle, Socket->SharedData.NonBlocking));
-            return 0;
+            }
+            Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
+            return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
         case FIONREAD:
+            if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
+            {
+                *lpErrno = WSAEFAULT;
+                return SOCKET_ERROR;
+            }
             return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
         default:
             *lpErrno = WSAEINVAL;
@@ -1928,7 +1987,10 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
     }
 
     /* Return Information */
-    *Ulong = InfoData.Information.Ulong;
+    if (Ulong != NULL)
+    {
+        *Ulong = InfoData.Information.Ulong;
+    }
     if (LargeInteger != NULL)
     {
         *LargeInteger = InfoData.Information.LargeInteger;
@@ -1965,7 +2027,10 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     InfoData.InformationClass = AfdInformationClass;
 
     /* Set Information */
-    InfoData.Information.Ulong = *Ulong;
+    if (Ulong != NULL)
+    {
+        InfoData.Information.Ulong = *Ulong;
+    }
     if (LargeInteger != NULL)
     {
         InfoData.Information.LargeInteger = *LargeInteger;
@@ -2284,6 +2349,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
 
                 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
             case AFD_EVENT_CONNECT:
+            case AFD_EVENT_CONNECT_FAIL:
                 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
                     0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
                 {
@@ -2387,7 +2453,7 @@ VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
 
     if (lNetworkEvents & FD_CLOSE)
     {
-        AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
+        AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT | AFD_EVENT_CLOSE;
     }
 
     if (lNetworkEvents & FD_QOS)
@@ -2480,6 +2546,7 @@ SockReenableAsyncSelectEvent (IN PSOCKET_INFORMATION Socket,
 
     /* Wait on new events */
     AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
+    if (!AsyncData) return;
 
     /* Create the Asynch Thread if Needed */  
     SockCreateOrReferenceAsyncThread();
@@ -2527,6 +2594,7 @@ DllMain(HANDLE hInstDll,
 
         /* Allocate Heap for 1024 Sockets, can be expanded later */
         Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
+        if (!Sockets) return FALSE;
 
         AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));