- Update to r53061
[reactos.git] / dll / win32 / msafd / misc / dllmain.c
index b0b1f31..87e5dcd 100644 (file)
@@ -23,10 +23,11 @@ HANDLE GlobalHeap;
 WSPUPCALLTABLE Upcalls;
 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
 PSOCKET_INFORMATION SocketListHead = NULL;
+CRITICAL_SECTION SocketListLock;
 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
 ULONG SockAsyncThreadRefCount;
 HANDLE SockAsyncHelperAfdHandle;
-HANDLE SockAsyncCompletionPort;
+HANDLE SockAsyncCompletionPort = NULL;
 BOOLEAN SockAsyncSelectCalled;
 
 
@@ -263,7 +264,7 @@ WSPSocket(int AddressFamily,
     /* Save Group Info */
     if (g != 0)
     {
-        GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
+        GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, NULL, NULL, &GroupData);
         Socket->SharedData.GroupID = GroupData.u.LowPart;
         Socket->SharedData.GroupType = GroupData.u.HighPart;
     }
@@ -271,17 +272,21 @@ WSPSocket(int AddressFamily,
     /* Get Window Sizes and Save them */
     GetSocketInformation (Socket,
                           AFD_INFO_SEND_WINDOW_SIZE,
+                          NULL,
                           &Socket->SharedData.SizeOfSendBuffer,
                           NULL);
 
     GetSocketInformation (Socket,
                           AFD_INFO_RECEIVE_WINDOW_SIZE,
+                          NULL,
                           &Socket->SharedData.SizeOfRecvBuffer,
                           NULL);
 
     /* Save in Process Sockets List */
+    EnterCriticalSection(&SocketListLock);
     Socket->NextSocket = SocketListHead;
     SocketListHead = Socket;
+    LeaveCriticalSection(&SocketListLock);
 
     /* Create the Socket Context */
     CreateContext(Socket);
@@ -379,24 +384,6 @@ TranslateNtStatusError(NTSTATUS Status)
     }
 }
 
-DWORD MsafdReturnWithErrno(NTSTATUS Status,
-                           LPINT Errno,
-                           DWORD Received,
-                           LPDWORD ReturnedBytes)
-{
-    *Errno = TranslateNtStatusError(Status);
-
-    if (ReturnedBytes)
-    {
-        if (!*Errno)
-            *ReturnedBytes = Received;
-        else
-            *ReturnedBytes = 0;
-    }
-
-    return *Errno ? SOCKET_ERROR : 0;
-}
-
 /*
  * FUNCTION: Closes an open socket
  * ARGUMENTS:
@@ -416,6 +403,7 @@ WSPCloseSocket(IN SOCKET Handle,
     HANDLE SockEvent;
     AFD_DISCONNECT_INFO DisconnectInfo;
     SOCKET_STATE OldState;
+    LONG LingerWait = -1;
 
     /* Create the Wait Event */
     Status = NtCreateEvent(&SockEvent,
@@ -468,7 +456,6 @@ WSPCloseSocket(IN SOCKET Handle,
     /* FIXME: Should we do this on Datagram Sockets too? */
     if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
     {
-        ULONG LingerWait;
         ULONG SendsInProgress;
         ULONG SleepWait;
 
@@ -482,6 +469,7 @@ WSPCloseSocket(IN SOCKET Handle,
             /* Find out how many Sends are in Progress */
             if (GetSocketInformation(Socket,
                                      AFD_INFO_SENDS_IN_PROGRESS,
+                                     NULL,
                                      &SendsInProgress,
                                      NULL))
             {
@@ -492,7 +480,11 @@ WSPCloseSocket(IN SOCKET Handle,
 
             /* Bail out if no more sends are pending */
             if (!SendsInProgress)
+            {
+                LingerWait = -1;
                 break;
+            }
+
             /* 
              * We have to execute a sleep, so it's kind of like
              * a block. If the socket is Nonblock, we cannot
@@ -517,33 +509,36 @@ WSPCloseSocket(IN SOCKET Handle,
             Sleep(SleepWait);
             LingerWait -= SleepWait;
         }
+    }
 
-        /*
-        * We have reached the timeout or sends are over.
-        * Disconnect if the timeout has been reached. 
-        */
+    if (OldState == SocketConnected)
+    {
         if (LingerWait <= 0)
         {
             DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
-            DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
-
-            /* Send IOCTL */
-            Status = NtDeviceIoControlFile((HANDLE)Handle,
-                                           SockEvent,
-                                           NULL,
-                                           NULL,
-                                           &IoStatusBlock,
-                                           IOCTL_AFD_DISCONNECT,
-                                           &DisconnectInfo,
-                                           sizeof(DisconnectInfo),
-                                           NULL,
-                                           0);
-
-            /* Wait for return */
-            if (Status == STATUS_PENDING)
+            DisconnectInfo.DisconnectType = LingerWait < 0 ? AFD_DISCONNECT_SEND : AFD_DISCONNECT_ABORT;
+            
+            if (((DisconnectInfo.DisconnectType & AFD_DISCONNECT_SEND) && (!Socket->SharedData.SendShutdown)) ||
+                ((DisconnectInfo.DisconnectType & AFD_DISCONNECT_ABORT) && (!Socket->SharedData.ReceiveShutdown)))
             {
-                WaitForSingleObject(SockEvent, INFINITE);
-                Status = IoStatusBlock.Status;
+                /* Send IOCTL */
+                Status = NtDeviceIoControlFile((HANDLE)Handle,
+                                               SockEvent,
+                                               NULL,
+                                               NULL,
+                                               &IoStatusBlock,
+                                               IOCTL_AFD_DISCONNECT,
+                                               &DisconnectInfo,
+                                               sizeof(DisconnectInfo),
+                                               NULL,
+                                               0);
+                
+                /* Wait for return */
+                if (Status == STATUS_PENDING)
+                {
+                    WaitForSingleObject(SockEvent, INFINITE);
+                    Status = IoStatusBlock.Status;
+                }
             }
         }
     }
@@ -556,6 +551,7 @@ WSPCloseSocket(IN SOCKET Handle,
     NtClose(Socket->TdiConnectionHandle);
     Socket->TdiConnectionHandle = NULL;
 
+    EnterCriticalSection(&SocketListLock);
     if (SocketListHead == Socket)
     {
         SocketListHead = SocketListHead->NextSocket;
@@ -574,13 +570,13 @@ WSPCloseSocket(IN SOCKET Handle,
             CurrentSocket = CurrentSocket->NextSocket;
         }
     }
-
-    HeapFree(GlobalHeap, 0, Socket);
+    LeaveCriticalSection(&SocketListLock);
 
     /* Close the handle */
     NtClose((HANDLE)Handle);
     NtClose(SockEvent);
 
+    HeapFree(GlobalHeap, 0, Socket);
     return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
 }
 
@@ -810,7 +806,8 @@ WSPSelect(IN int nfds,
     IO_STATUS_BLOCK     IOSB;
     PAFD_POLL_INFO      PollInfo;
     NTSTATUS            Status;
-    LONG                HandleCount, OutCount = 0;
+    ULONG               HandleCount;
+    LONG                OutCount = 0;
     ULONG               PollBufferSize;
     PVOID               PollBuffer;
     ULONG               i, j = 0, x;
@@ -827,7 +824,7 @@ WSPSelect(IN int nfds,
 
     if ( HandleCount == 0 )
     {
-        AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
+        AFD_DbgPrint(MAX_TRACE,("HandleCount: %u. Return SOCKET_ERROR\n",
                      HandleCount));
         if (lpErrno) *lpErrno = WSAEINVAL;
         return SOCKET_ERROR;
@@ -835,7 +832,7 @@ WSPSelect(IN int nfds,
 
     PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
 
-    AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n", 
+    AFD_DbgPrint(MID_TRACE,("HandleCount: %u BufferSize: %u\n", 
                  HandleCount, PollBufferSize));
 
     /* Convert Timeout to NT Format */
@@ -920,7 +917,7 @@ WSPSelect(IN int nfds,
     }
 
     PollInfo->HandleCount = j;
-    PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
+    PollBufferSize = FIELD_OFFSET(AFD_POLL_INFO, Handles) + PollInfo->HandleCount * sizeof(AFD_HANDLE);
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
@@ -1058,6 +1055,7 @@ WSPAccept(SOCKET Handle,
     ULONG                       CallBack;
     WSAPROTOCOL_INFOW           ProtocolInfo;
     SOCKET                      AcceptSocket;
+    PSOCKET_INFORMATION         AcceptSocketInfo;
     UCHAR                       ReceiveBuffer[0x1A];
     HANDLE                      SockEvent;
 
@@ -1373,6 +1371,17 @@ WSPAccept(SOCKET Handle,
         MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
         return INVALID_SOCKET;
     }
+    
+    AcceptSocketInfo = GetSocketStructure(AcceptSocket);
+    if (!AcceptSocketInfo)
+    {
+        NtClose(SockEvent);
+        WSPCloseSocket( AcceptSocket, lpErrno );
+        MsafdReturnWithErrno( STATUS_INVALID_CONNECTION, lpErrno, 0, NULL );
+        return INVALID_SOCKET;
+    }
+    
+    AcceptSocketInfo->SharedData.State = SocketConnected;
 
     /* Return Address in SOCKADDR FORMAT */
     if( SocketAddress )
@@ -1578,6 +1587,7 @@ WSPConnect(SOCKET Handle,
     if (Status != STATUS_SUCCESS)
         goto notify;
 
+    Socket->SharedData.State = SocketConnected;
     Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
 
     /* Get any pending connect data */
@@ -1694,7 +1704,7 @@ WSPShutdown(SOCKET Handle,
             break;
     }
 
-    DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
+    DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1000000);
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)Handle,
@@ -1756,6 +1766,13 @@ WSPGetSockName(IN SOCKET Handle,
        return SOCKET_ERROR;
     }
 
+    if (!Name || !NameLength)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
+    }
+
     /* Allocate a buffer for the address */
     TdiAddressSize = 
                sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
@@ -1793,13 +1810,13 @@ WSPGetSockName(IN SOCKET Handle,
 
     if (NT_SUCCESS(Status))
     {
-        if (*NameLength >= SocketAddress->Address[0].AddressLength)
+        if (*NameLength >= Socket->SharedData.SizeOfLocalAddress)
         {
             Name->sa_family = SocketAddress->Address[0].AddressType;
             RtlCopyMemory (Name->sa_data,
                            SocketAddress->Address[0].Address, 
                            SocketAddress->Address[0].AddressLength);
-            *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+            *NameLength = Socket->SharedData.SizeOfLocalAddress;
             AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
                           *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
                           ((struct sockaddr_in *)Name)->sin_port));
@@ -1850,8 +1867,22 @@ WSPGetPeerName(IN SOCKET s,
        return SOCKET_ERROR;
     }
 
+    if (Socket->SharedData.State != SocketConnected)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAENOTCONN;
+        return SOCKET_ERROR;
+    }
+
+    if (!Name || !NameLength)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
+    }
+
     /* Allocate a buffer for the address */
-    TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
+    TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfRemoteAddress;
     SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
 
     if ( SocketAddress == NULL )
@@ -1884,13 +1915,13 @@ WSPGetPeerName(IN SOCKET s,
 
     if (NT_SUCCESS(Status))
     {
-        if (*NameLength >= SocketAddress->Address[0].AddressLength)
+        if (*NameLength >= Socket->SharedData.SizeOfRemoteAddress)
         {
             Name->sa_family = SocketAddress->Address[0].AddressType;
             RtlCopyMemory (Name->sa_data,
                            SocketAddress->Address[0].Address, 
                            SocketAddress->Address[0].AddressLength);
-            *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+            *NameLength = Socket->SharedData.SizeOfRemoteAddress;
             AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
                           *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
                           ((struct sockaddr_in *)Name)->sin_port));
@@ -1923,6 +1954,8 @@ WSPIoctl(IN  SOCKET Handle,
          OUT LPINT lpErrno)
 {
     PSOCKET_INFORMATION Socket = NULL;
+       BOOLEAN NeedsCompletion;
+    BOOLEAN NonBlocking;
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
@@ -1931,6 +1964,8 @@ WSPIoctl(IN  SOCKET Handle,
        *lpErrno = WSAENOTSOCK;
        return SOCKET_ERROR;
     }
+       
+       *lpcbBytesReturned = 0;
 
     switch( dwIoControlCode )
     {
@@ -1940,18 +1975,49 @@ WSPIoctl(IN  SOCKET Handle,
                 *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
             }
-            Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
-            return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+            NonBlocking = *((PULONG)lpvInBuffer) ? TRUE : FALSE;
+            Socket->SharedData.NonBlocking = NonBlocking ? 1 : 0;
+            *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, &NonBlocking, NULL, NULL);
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                               return NO_ERROR;
         case FIONREAD:
             if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
             {
                 *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
             }
-            return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
-        default:
+            *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, NULL, (PULONG)lpvOutBuffer, NULL);
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                       {
+                               *lpcbBytesReturned = sizeof(ULONG);
+                               return NO_ERROR;
+                       }
+        case SIO_GET_EXTENSION_FUNCTION_POINTER:
             *lpErrno = WSAEINVAL;
             return SOCKET_ERROR;
+        default:
+                       *lpErrno = Socket->HelperData->WSHIoctl(Socket->HelperContext,
+                                                                                                       Handle,
+                                                                                                       Socket->TdiAddressHandle,
+                                                                                                       Socket->TdiConnectionHandle,
+                                                                                                       dwIoControlCode,
+                                                                                                       lpvInBuffer,
+                                                                                                       cbInBuffer,
+                                                                                                       lpvOutBuffer,
+                                                                                                       cbOutBuffer,
+                                                                                                       lpcbBytesReturned,
+                                                                                                       lpOverlapped,
+                                                                                                       lpCompletionRoutine,
+                                                                                                       (LPBOOL)&NeedsCompletion);
+                       
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                               return NO_ERROR;
     }
 }
 
@@ -2083,7 +2149,16 @@ WSPSetSockOpt(
     }
 
 
-    /* FIXME: We should handle some cases here */
+    /* FIXME: We should handle some more cases here */
+    if (level == SOL_SOCKET)
+    {
+        switch (optname)
+        {
+           case SO_BROADCAST:
+              Socket->SharedData.Broadcast = (*optval != 0) ? 1 : 0;
+              return 0;
+        }
+    }
 
 
     *lpErrno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
@@ -2189,6 +2264,7 @@ WSPCleanup(OUT LPINT lpErrno)
 int 
 GetSocketInformation(PSOCKET_INFORMATION Socket, 
                      ULONG AfdInformationClass, 
+                     PBOOLEAN Boolean OPTIONAL,
                      PULONG Ulong OPTIONAL, 
                      PLARGE_INTEGER LargeInteger OPTIONAL)
 {
@@ -2240,6 +2316,10 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
     {
         *LargeInteger = InfoData.Information.LargeInteger;
     }
+    if (Boolean != NULL)
+    {
+        *Boolean = InfoData.Information.Boolean;
+    }
 
     NtClose( SockEvent );
 
@@ -2250,7 +2330,8 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
 
 int 
 SetSocketInformation(PSOCKET_INFORMATION Socket, 
-                     ULONG AfdInformationClass, 
+                     ULONG AfdInformationClass,
+                     PBOOLEAN Boolean OPTIONAL,
                      PULONG Ulong OPTIONAL, 
                      PLARGE_INTEGER LargeInteger OPTIONAL)
 {
@@ -2280,6 +2361,10 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     {
         InfoData.Information.LargeInteger = *LargeInteger;
     }
+    if (Boolean != NULL)
+    {
+        InfoData.Information.Boolean = *Boolean;
+    }
 
     AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
         AfdInformationClass, *Ulong));
@@ -2314,15 +2399,22 @@ GetSocketStructure(SOCKET Handle)
 {
     PSOCKET_INFORMATION CurrentSocket;
 
+    EnterCriticalSection(&SocketListLock);
+
     CurrentSocket = SocketListHead;
     while (CurrentSocket)
     {
         if (CurrentSocket->Handle == Handle)
+        {
+            LeaveCriticalSection(&SocketListLock);
             return CurrentSocket;
+        }
 
         CurrentSocket = CurrentSocket->NextSocket;
     }
 
+    LeaveCriticalSection(&SocketListLock);
+
     return NULL;
 }
 
@@ -2387,6 +2479,7 @@ BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
     /* Check if the Thread Already Exists */
     if (SockAsyncThreadRefCount)
     {
+        ASSERT(SockAsyncCompletionPort);
         return TRUE;
     }
 
@@ -2397,7 +2490,11 @@ BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
                                       IO_COMPLETION_ALL_ACCESS,
                                       NULL,
                                       2); // Allow 2 threads only
-
+        if (!NT_SUCCESS(Status))
+        {
+             AFD_DbgPrint(MID_TRACE,("Failed to create completion port\n"));
+             return FALSE;
+        }
         /* Protect Handle */   
         HandleFlags.ProtectFromClose = TRUE;
         HandleFlags.Inherit = FALSE;
@@ -2841,6 +2938,9 @@ DllMain(HANDLE hInstDll,
         /* Heap to use when allocating */
         GlobalHeap = GetProcessHeap();
 
+        /* Initialize the lock that protects our socket list */
+        InitializeCriticalSection(&SocketListLock);
+
         AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
 
         break;
@@ -2852,6 +2952,10 @@ DllMain(HANDLE hInstDll,
         break;
 
     case DLL_PROCESS_DETACH:
+
+        /* Delete the socket list lock */
+        DeleteCriticalSection(&SocketListLock);
+
         break;
     }