- I hate catching these things as the commit is going out
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
index a8ab0dc..5714213 100644 (file)
@@ -22,8 +22,8 @@ DWORD DebugTraceLevel = 0;
 HANDLE GlobalHeap;
 WSPUPCALLTABLE Upcalls;
 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
-ULONG SocketCount = 0;
-PSOCKET_INFORMATION *Sockets = NULL;
+PSOCKET_INFORMATION SocketListHead = NULL;
+CRITICAL_SECTION SocketListLock;
 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
 ULONG SockAsyncThreadRefCount;
 HANDLE SockAsyncHelperAfdHandle;
@@ -61,7 +61,7 @@ WSPSocket(int AddressFamily,
     ULONG                       SizeOfEA;
     PAFD_CREATE_PACKET          AfdPacket;
     HANDLE                      Sock;
-    PSOCKET_INFORMATION         Socket = NULL, PrevSocket = NULL;
+    PSOCKET_INFORMATION         Socket = NULL;
     PFILE_FULL_EA_INFORMATION   EABuffer = NULL;
     PHELPER_DATA                HelperData;
     PVOID                       HelperDLLContext;
@@ -261,17 +261,6 @@ WSPSocket(int AddressFamily,
     /* Save Handle */
     Socket->Handle = (SOCKET)Sock;
 
-    /* XXX See if there's a structure we can reuse -- We need to do this
-    * more properly. */
-    PrevSocket = GetSocketStructure( (SOCKET)Sock );
-
-    if( PrevSocket )
-    {
-        RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
-        RtlFreeHeap( GlobalHeap, 0, Socket );
-        Socket = PrevSocket;
-    }
-
     /* Save Group Info */
     if (g != 0)
     {
@@ -292,8 +281,10 @@ WSPSocket(int AddressFamily,
                           NULL);
 
     /* Save in Process Sockets List */
-    Sockets[SocketCount] = Socket;
-    SocketCount ++;
+    EnterCriticalSection(&SocketListLock);
+    Socket->NextSocket = SocketListHead;
+    SocketListHead = Socket;
+    LeaveCriticalSection(&SocketListLock);
 
     /* Create the Socket Context */
     CreateContext(Socket);
@@ -391,24 +382,6 @@ TranslateNtStatusError(NTSTATUS Status)
     }
 }
 
-DWORD MsafdReturnWithErrno(NTSTATUS Status,
-                           LPINT Errno,
-                           DWORD Received,
-                           LPDWORD ReturnedBytes)
-{
-    *Errno = TranslateNtStatusError(Status);
-
-    if (ReturnedBytes)
-    {
-        if (!*Errno)
-            *ReturnedBytes = Received;
-        else
-            *ReturnedBytes = 0;
-    }
-
-    return *Errno ? SOCKET_ERROR : 0;
-}
-
 /*
  * FUNCTION: Closes an open socket
  * ARGUMENTS:
@@ -423,7 +396,7 @@ WSPCloseSocket(IN SOCKET Handle,
                OUT LPINT lpErrno)
 {
     IO_STATUS_BLOCK IoStatusBlock;
-    PSOCKET_INFORMATION Socket = NULL;
+    PSOCKET_INFORMATION Socket = NULL, CurrentSocket;
     NTSTATUS Status;
     HANDLE SockEvent;
     AFD_DISCONNECT_INFO DisconnectInfo;
@@ -441,6 +414,12 @@ WSPCloseSocket(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
     {
@@ -562,6 +541,29 @@ WSPCloseSocket(IN SOCKET Handle,
     NtClose(Socket->TdiConnectionHandle);
     Socket->TdiConnectionHandle = NULL;
 
+    EnterCriticalSection(&SocketListLock);
+    if (SocketListHead == Socket)
+    {
+        SocketListHead = SocketListHead->NextSocket;
+    }
+    else
+    {
+        CurrentSocket = SocketListHead;
+        while (CurrentSocket->NextSocket)
+        {
+            if (CurrentSocket->NextSocket == Socket)
+            {
+                CurrentSocket->NextSocket = CurrentSocket->NextSocket->NextSocket;
+                break;
+            }
+
+            CurrentSocket = CurrentSocket->NextSocket;
+        }
+    }
+    LeaveCriticalSection(&SocketListLock);
+
+    HeapFree(GlobalHeap, 0, Socket);
+
     /* Close the handle */
     NtClose((HANDLE)Handle);
     NtClose(SockEvent);
@@ -615,6 +617,12 @@ WSPBind(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       HeapFree(GlobalHeap, 0, BindData);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set up Address in TDI Format */
     BindData->Address.TAAddressCount = 1;
@@ -708,6 +716,11 @@ WSPListen(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     if (Socket->SharedData.Listening)
         return 0;
@@ -774,12 +787,12 @@ WSPListen(SOCKET Handle,
 
 int
 WSPAPI
-WSPSelect(int nfds,
-          fd_set *readfds,
-          fd_set *writefds,
-          fd_set *exceptfds,
-          const LPTIMEVAL timeout,
-          LPINT lpErrno)
+WSPSelect(IN int nfds,
+          IN OUT fd_set *readfds OPTIONAL,
+          IN OUT fd_set *writefds OPTIONAL,
+          IN OUT fd_set *exceptfds OPTIONAL,
+          IN const struct timeval *timeout OPTIONAL,
+          OUT LPINT lpErrno)
 {
     IO_STATUS_BLOCK     IOSB;
     PAFD_POLL_INFO      PollInfo;
@@ -1052,6 +1065,12 @@ WSPAccept(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return INVALID_SOCKET;
+    }
 
     /* If this is non-blocking, make sure there's something for us to accept */
     FD_ZERO(&ReadSet);
@@ -1059,12 +1078,17 @@ WSPAccept(SOCKET Handle,
     Timeout.tv_sec=0;
     Timeout.tv_usec=0;
 
-    WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
+    if (WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, lpErrno) == SOCKET_ERROR)
+    {
+        NtClose(SockEvent);
+        return INVALID_SOCKET;
+    }
 
     if (ReadSet.fd_array[0] != Socket->Handle)
     {
         NtClose(SockEvent);
-        return 0;
+        *lpErrno = WSAEWOULDBLOCK;
+        return INVALID_SOCKET;
     }
 
     /* Send IOCTL */
@@ -1302,7 +1326,9 @@ WSPAccept(SOCKET Handle,
                               &ProtocolInfo,
                               GroupID,
                               Socket->SharedData.CreateFlags,
-                              NULL);
+                              lpErrno);
+    if (AcceptSocket == INVALID_SOCKET)
+        return INVALID_SOCKET;
 
     /* Set up the Accept Structure */
     AcceptData.ListenHandle = (HANDLE)AcceptSocket;
@@ -1408,6 +1434,12 @@ WSPConnect(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Bind us First */
     if (Socket->SharedData.State == SocketOpen)
@@ -1624,6 +1656,12 @@ WSPShutdown(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set AFD Disconnect Type */
     switch (HowTo)
@@ -1698,6 +1736,12 @@ WSPGetSockName(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
     TdiAddressSize = 
@@ -1786,6 +1830,12 @@ WSPGetPeerName(IN SOCKET s,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(s);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
     TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
@@ -1860,9 +1910,17 @@ WSPIoctl(IN  SOCKET Handle,
          OUT LPINT lpErrno)
 {
     PSOCKET_INFORMATION Socket = NULL;
+       BOOLEAN NeedsCompletion;
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
+       
+       *lpcbBytesReturned = 0;
 
     switch( dwIoControlCode )
     {
@@ -1873,17 +1931,44 @@ WSPIoctl(IN  SOCKET Handle,
                 return SOCKET_ERROR;
             }
             Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
-            return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+            *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                               return NO_ERROR;
         case FIONREAD:
             if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
             {
                 *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
             }
-            return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
+            *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                       {
+                               *lpcbBytesReturned = sizeof(ULONG);
+                               return NO_ERROR;
+                       }
         default:
-            *lpErrno = WSAEINVAL;
-            return SOCKET_ERROR;
+                       *lpErrno = Socket->HelperData->WSHIoctl(Socket->HelperContext,
+                                                                                                       Handle,
+                                                                                                       Socket->TdiAddressHandle,
+                                                                                                       Socket->TdiConnectionHandle,
+                                                                                                       dwIoControlCode,
+                                                                                                       lpvInBuffer,
+                                                                                                       cbInBuffer,
+                                                                                                       lpvOutBuffer,
+                                                                                                       cbOutBuffer,
+                                                                                                       lpcbBytesReturned,
+                                                                                                       lpOverlapped,
+                                                                                                       lpCompletionRoutine,
+                                                                                                       (LPBOOL)&NeedsCompletion);
+                       
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                               return NO_ERROR;
     }
 }
 
@@ -2015,7 +2100,16 @@ WSPSetSockOpt(
     }
 
 
-    /* FIXME: We should handle some cases here */
+    /* FIXME: We should handle some more cases here */
+    if (level == SOL_SOCKET)
+    {
+        switch (optname)
+        {
+           case SO_BROADCAST:
+              Socket->SharedData.Broadcast = (*optval != 0) ? 1 : 0;
+              return 0;
+        }
+    }
 
 
     *lpErrno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
@@ -2244,16 +2338,25 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
 PSOCKET_INFORMATION
 GetSocketStructure(SOCKET Handle)
 {
-    ULONG i;
+    PSOCKET_INFORMATION CurrentSocket;
+
+    EnterCriticalSection(&SocketListLock);
 
-    for (i=0; i<SocketCount; i++) 
+    CurrentSocket = SocketListHead;
+    while (CurrentSocket)
     {
-        if (Sockets[i]->Handle == Handle)
+        if (CurrentSocket->Handle == Handle)
         {
-            return Sockets[i];
+            LeaveCriticalSection(&SocketListLock);
+            return CurrentSocket;
         }
+
+        CurrentSocket = CurrentSocket->NextSocket;
     }
-    return 0;
+
+    LeaveCriticalSection(&SocketListLock);
+
+    return NULL;
 }
 
 int CreateContext(PSOCKET_INFORMATION Socket)
@@ -2771,9 +2874,8 @@ DllMain(HANDLE hInstDll,
         /* Heap to use when allocating */
         GlobalHeap = GetProcessHeap();
 
-        /* Allocate Heap for 1024 Sockets, can be expanded later */
-        Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
-        if (!Sockets) return FALSE;
+        /* Initialize the lock that protects our socket list */
+        InitializeCriticalSection(&SocketListLock);
 
         AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
 
@@ -2786,6 +2888,10 @@ DllMain(HANDLE hInstDll,
         break;
 
     case DLL_PROCESS_DETACH:
+
+        /* Delete the socket list lock */
+        DeleteCriticalSection(&SocketListLock);
+
         break;
     }