[MSAFD]
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
index f501eb0..f06b0a5 100644 (file)
@@ -22,8 +22,7 @@ DWORD DebugTraceLevel = 0;
 HANDLE GlobalHeap;
 WSPUPCALLTABLE Upcalls;
 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
-ULONG SocketCount = 0;
-PSOCKET_INFORMATION *Sockets = NULL;
+PSOCKET_INFORMATION SocketListHead = NULL;
 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
 ULONG SockAsyncThreadRefCount;
 HANDLE SockAsyncHelperAfdHandle;
@@ -292,8 +291,8 @@ WSPSocket(int AddressFamily,
                           NULL);
 
     /* Save in Process Sockets List */
-    Sockets[SocketCount] = Socket;
-    SocketCount ++;
+    Socket->NextSocket = SocketListHead;
+    SocketListHead = Socket;
 
     /* Create the Socket Context */
     CreateContext(Socket);
@@ -373,6 +372,18 @@ TranslateNtStatusError(NTSTATUS Status)
           DbgPrint("MSAFD: STATUS_CANCELLED\n");
           return WSA_OPERATION_ABORTED;
 
+       case STATUS_ADDRESS_ALREADY_EXISTS:
+          DbgPrint("MSAFD: STATUS_ADDRESS_ALREADY_EXISTS\n");
+          return WSAEADDRINUSE;
+
+       case STATUS_LOCAL_DISCONNECT:
+          DbgPrint("MSAFD: STATUS_LOCAL_DISCONNECT\n");
+          return WSAECONNABORTED;
+
+       case STATUS_REMOTE_DISCONNECT:
+          DbgPrint("MSAFD: STATUS_REMOTE_DISCONNECT\n");
+          return WSAECONNRESET;
+
        default:
           DbgPrint("MSAFD: Unhandled NTSTATUS value: 0x%x\n", Status);
           return WSAENETDOWN;
@@ -411,7 +422,7 @@ WSPCloseSocket(IN SOCKET Handle,
                OUT LPINT lpErrno)
 {
     IO_STATUS_BLOCK IoStatusBlock;
-    PSOCKET_INFORMATION Socket = NULL;
+    PSOCKET_INFORMATION Socket = NULL, CurrentSocket;
     NTSTATUS Status;
     HANDLE SockEvent;
     AFD_DISCONNECT_INFO DisconnectInfo;
@@ -429,6 +440,12 @@ WSPCloseSocket(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
     {
@@ -550,11 +567,32 @@ WSPCloseSocket(IN SOCKET Handle,
     NtClose(Socket->TdiConnectionHandle);
     Socket->TdiConnectionHandle = NULL;
 
+    if (SocketListHead == Socket)
+    {
+        SocketListHead = SocketListHead->NextSocket;
+    }
+    else
+    {
+        CurrentSocket = SocketListHead;
+        while (CurrentSocket->NextSocket)
+        {
+            if (CurrentSocket->NextSocket == Socket)
+            {
+                CurrentSocket->NextSocket = CurrentSocket->NextSocket->NextSocket;
+                break;
+            }
+
+            CurrentSocket = CurrentSocket->NextSocket;
+        }
+    }
+
+    HeapFree(GlobalHeap, 0, Socket);
+
     /* Close the handle */
     NtClose((HANDLE)Handle);
     NtClose(SockEvent);
 
-    return NO_ERROR;
+    return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
 }
 
 
@@ -603,6 +641,12 @@ WSPBind(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       HeapFree(GlobalHeap, 0, BindData);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set up Address in TDI Format */
     BindData->Address.TAAddressCount = 1;
@@ -654,13 +698,17 @@ WSPBind(SOCKET Handle,
         Status = IOSB.Status;
     }
 
+    NtClose( SockEvent );
+    HeapFree(GlobalHeap, 0, BindData);
+
+    if (Status != STATUS_SUCCESS)
+        return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
+
     /* Set up Socket Data */
     Socket->SharedData.State = SocketBound;
     Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
 
-    NtClose( SockEvent );
-    HeapFree(GlobalHeap, 0, BindData);
-    if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_BIND))
+    if (Socket->HelperEvents & WSH_NOTIFY_BIND)
     {
         Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
                                                Socket->Handle,
@@ -692,6 +740,11 @@ WSPListen(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     if (Socket->SharedData.Listening)
         return 0;
@@ -727,14 +780,17 @@ WSPListen(SOCKET Handle,
     {
         WaitForSingleObject(SockEvent, INFINITE);
         Status = IOSB.Status;
-    }         
+    }
+
+    NtClose( SockEvent );
+
+    if (Status != STATUS_SUCCESS)
+       return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
 
     /* Set to Listening */
     Socket->SharedData.Listening = TRUE;
 
-    NtClose( SockEvent );
-
-    if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_LISTEN))
+    if (Socket->HelperEvents & WSH_NOTIFY_LISTEN)
     {
         Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
                                                Socket->Handle,
@@ -895,6 +951,7 @@ WSPSelect(int nfds,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     /* Clear the Structures */
@@ -1032,6 +1089,12 @@ WSPAccept(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return INVALID_SOCKET;
+    }
 
     /* If this is non-blocking, make sure there's something for us to accept */
     FD_ZERO(&ReadSet);
@@ -1039,12 +1102,17 @@ WSPAccept(SOCKET Handle,
     Timeout.tv_sec=0;
     Timeout.tv_usec=0;
 
-    WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
+    if (WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, lpErrno) == SOCKET_ERROR)
+    {
+        NtClose(SockEvent);
+        return INVALID_SOCKET;
+    }
 
     if (ReadSet.fd_array[0] != Socket->Handle)
     {
         NtClose(SockEvent);
-        return 0;
+        *lpErrno = WSAEWOULDBLOCK;
+        return INVALID_SOCKET;
     }
 
     /* Send IOCTL */
@@ -1282,7 +1350,9 @@ WSPAccept(SOCKET Handle,
                               &ProtocolInfo,
                               GroupID,
                               Socket->SharedData.CreateFlags,
-                              NULL);
+                              lpErrno);
+    if (AcceptSocket == INVALID_SOCKET)
+        return INVALID_SOCKET;
 
     /* Set up the Accept Structure */
     AcceptData.ListenHandle = (HANDLE)AcceptSocket;
@@ -1388,6 +1458,12 @@ WSPConnect(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Bind us First */
     if (Socket->SharedData.State == SocketOpen)
@@ -1404,7 +1480,8 @@ WSPConnect(SOCKET Handle,
                                                     BindAddress, 
                                                     &BindAddressLength);
         /* Bind it */
-        WSPBind(Handle, BindAddress, BindAddressLength, NULL);
+        if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR)
+            return INVALID_SOCKET;
     }
 
     /* Set the Connect Data */
@@ -1427,6 +1504,9 @@ WSPConnect(SOCKET Handle,
             WaitForSingleObject(SockEvent, INFINITE);
             Status = IOSB.Status;
         }
+
+        if (Status != STATUS_SUCCESS)
+            goto notify;
     }
 
     /* Dynamic Structure...ugh */
@@ -1472,6 +1552,9 @@ WSPConnect(SOCKET Handle,
             WaitForSingleObject(SockEvent, INFINITE);
             Status = IOSB.Status;
         }
+
+        if (Status != STATUS_SUCCESS)
+            goto notify;
     }
 
     /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
@@ -1503,6 +1586,9 @@ WSPConnect(SOCKET Handle,
         Status = IOSB.Status;
     }
 
+    if (Status != STATUS_SUCCESS)
+        goto notify;
+
     Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
 
     /* Get any pending connect data */
@@ -1526,14 +1612,15 @@ WSPConnect(SOCKET Handle,
         }
     }
 
+    AFD_DbgPrint(MID_TRACE,("Ending\n"));
+
+notify:
     /* Re-enable Async Event */
     SockReenableAsyncSelectEvent(Socket, FD_WRITE);
 
     /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
     SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
 
-    AFD_DbgPrint(MID_TRACE,("Ending\n"));
-
     NtClose( SockEvent );
 
     if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT))
@@ -1593,6 +1680,12 @@ WSPShutdown(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set AFD Disconnect Type */
     switch (HowTo)
@@ -1667,6 +1760,12 @@ WSPGetSockName(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
     TdiAddressSize = 
@@ -1755,6 +1854,12 @@ WSPGetPeerName(IN SOCKET s,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(s);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
     TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
@@ -1832,6 +1937,11 @@ WSPIoctl(IN  SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     switch( dwIoControlCode )
     {
@@ -2126,8 +2236,12 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
+    if (Status != STATUS_SUCCESS)
+        return -1;
+
     /* Return Information */
     if (Ulong != NULL)
     {
@@ -2197,27 +2311,37 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     NtClose( SockEvent );
 
-    return 0;
+    return Status == STATUS_SUCCESS ? 0 : -1;
 
 }
 
 PSOCKET_INFORMATION
 GetSocketStructure(SOCKET Handle)
 {
-    ULONG i;
+    PSOCKET_INFORMATION CurrentSocket;
+
+    if (!SocketListHead)
+        return NULL;
+
+    /* This is a special case */
+    if (SocketListHead->Handle == Handle)
+        return SocketListHead;
 
-    for (i=0; i<SocketCount; i++) 
+    CurrentSocket = SocketListHead;
+    while (CurrentSocket->NextSocket)
     {
-        if (Sockets[i]->Handle == Handle)
-        {
-            return Sockets[i];
-        }
+        if (CurrentSocket->Handle == Handle)
+            return CurrentSocket;
+
+        CurrentSocket = CurrentSocket->NextSocket;
     }
-    return 0;
+
+    return NULL;
 }
 
 int CreateContext(PSOCKET_INFORMATION Socket)
@@ -2262,11 +2386,12 @@ int CreateContext(PSOCKET_INFORMATION Socket)
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     NtClose( SockEvent );
 
-    return 0;
+    return Status == STATUS_SUCCESS ? 0 : -1;
 }
 
 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
@@ -2734,10 +2859,6 @@ DllMain(HANDLE hInstDll,
         /* Heap to use when allocating */
         GlobalHeap = GetProcessHeap();
 
-        /* Allocate Heap for 1024 Sockets, can be expanded later */
-        Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
-        if (!Sockets) return FALSE;
-
         AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
 
         break;