[MSAFD]
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
index 2cd1ec2..f06b0a5 100644 (file)
@@ -14,7 +14,7 @@
 
 #include <debug.h>
 
-#ifdef DBG
+#if DBG
 //DWORD DebugTraceLevel = DEBUG_ULTRA;
 DWORD DebugTraceLevel = 0;
 #endif /* DBG */
@@ -22,8 +22,7 @@ DWORD DebugTraceLevel = 0;
 HANDLE GlobalHeap;
 WSPUPCALLTABLE Upcalls;
 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
-ULONG SocketCount = 0;
-PSOCKET_INFORMATION *Sockets = NULL;
+PSOCKET_INFORMATION SocketListHead = NULL;
 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
 ULONG SockAsyncThreadRefCount;
 HANDLE SockAsyncHelperAfdHandle;
@@ -98,9 +97,13 @@ WSPSocket(int AddressFamily,
 
     /* Set Socket Data */
     Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
+    if (!Socket)
+        return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+
     RtlZeroMemory(Socket, sizeof(*Socket));
     Socket->RefCount = 2;
     Socket->Handle = -1;
+    Socket->SharedData.Listening = FALSE;
     Socket->SharedData.State = SocketOpen;
     Socket->SharedData.AddressFamily = AddressFamily;
     Socket->SharedData.SocketType = SocketType;
@@ -139,6 +142,9 @@ WSPSocket(int AddressFamily,
 
     /* Set up EA Buffer */
     EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
+    if (!EABuffer)
+        return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+
     RtlZeroMemory(EABuffer, SizeOfEA);
     EABuffer->NextEntryOffset = 0;
     EABuffer->Flags = 0;
@@ -228,16 +234,28 @@ WSPSocket(int AddressFamily,
     ourselves after every call to NtDeviceIoControlFile. This is
     because the kernel doesn't support overlapping synchronous I/O
     requests (made from multiple threads) at this time (Sep 2005) */
-    ZwCreateFile(&Sock,
-                 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
-                 &Object,
-                 &IOSB,
-                 NULL,
-                 0,
-                 FILE_SHARE_READ | FILE_SHARE_WRITE, FILE_OPEN_IF,
-                 0,
-                 EABuffer,
-                 SizeOfEA);
+    Status = NtCreateFile(&Sock,
+                          GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
+                          &Object,
+                          &IOSB,
+                          NULL,
+                          0,
+                          FILE_SHARE_READ | FILE_SHARE_WRITE,
+                          FILE_OPEN_IF,
+                          0,
+                          EABuffer,
+                          SizeOfEA);
+
+    HeapFree(GlobalHeap, 0, EABuffer);
+
+    if (Status != STATUS_SUCCESS)
+    {
+        AFD_DbgPrint(MIN_TRACE, ("Failed to open socket\n"));
+
+        HeapFree(GlobalHeap, 0, Socket);
+
+        return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
+    }
 
     /* Save Handle */
     Socket->Handle = (SOCKET)Sock;
@@ -273,8 +291,8 @@ WSPSocket(int AddressFamily,
                           NULL);
 
     /* Save in Process Sockets List */
-    Sockets[SocketCount] = Socket;
-    SocketCount ++;
+    Socket->NextSocket = SocketListHead;
+    SocketListHead = Socket;
 
     /* Create the Socket Context */
     CreateContext(Socket);
@@ -290,75 +308,104 @@ WSPSocket(int AddressFamily,
 error:
     AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
 
+    if( Socket )
+        HeapFree(GlobalHeap, 0, Socket);
+
     if( lpErrno )
         *lpErrno = Status;
 
     return INVALID_SOCKET;
 }
 
+INT
+TranslateNtStatusError(NTSTATUS Status)
+{
+    switch (Status)
+    {
+       case STATUS_CANT_WAIT:
+          return WSAEWOULDBLOCK;
+
+       case STATUS_TIMEOUT:
+          return WSAETIMEDOUT;
+
+       case STATUS_SUCCESS:
+          return NO_ERROR;
+
+       case STATUS_FILE_CLOSED:
+       case STATUS_END_OF_FILE:
+          return WSAESHUTDOWN;
+
+       case STATUS_PENDING:
+          return WSA_IO_PENDING;
+
+       case STATUS_BUFFER_TOO_SMALL:
+       case STATUS_BUFFER_OVERFLOW:
+          DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
+          return WSAEMSGSIZE;
+
+       case STATUS_NO_MEMORY:
+       case STATUS_INSUFFICIENT_RESOURCES:
+          DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
+          return WSAENOBUFS;
+
+       case STATUS_INVALID_CONNECTION:
+          DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
+          return WSAEAFNOSUPPORT;
+
+       case STATUS_INVALID_ADDRESS:
+          DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
+          return WSAEADDRNOTAVAIL;
+
+       case STATUS_REMOTE_NOT_LISTENING:
+          DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
+          return WSAECONNREFUSED;
+
+       case STATUS_NETWORK_UNREACHABLE:
+          DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
+          return WSAENETUNREACH;
+
+       case STATUS_INVALID_PARAMETER:
+          DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
+          return WSAEINVAL;
+
+       case STATUS_CANCELLED:
+          DbgPrint("MSAFD: STATUS_CANCELLED\n");
+          return WSA_OPERATION_ABORTED;
+
+       case STATUS_ADDRESS_ALREADY_EXISTS:
+          DbgPrint("MSAFD: STATUS_ADDRESS_ALREADY_EXISTS\n");
+          return WSAEADDRINUSE;
+
+       case STATUS_LOCAL_DISCONNECT:
+          DbgPrint("MSAFD: STATUS_LOCAL_DISCONNECT\n");
+          return WSAECONNABORTED;
+
+       case STATUS_REMOTE_DISCONNECT:
+          DbgPrint("MSAFD: STATUS_REMOTE_DISCONNECT\n");
+          return WSAECONNRESET;
+
+       default:
+          DbgPrint("MSAFD: Unhandled NTSTATUS value: 0x%x\n", Status);
+          return WSAENETDOWN;
+    }
+}
 
 DWORD MsafdReturnWithErrno(NTSTATUS Status,
                            LPINT Errno,
                            DWORD Received,
                            LPDWORD ReturnedBytes)
 {
-    if( ReturnedBytes )
-        *ReturnedBytes = 0;
-    if( Errno )
+    *Errno = TranslateNtStatusError(Status);
+
+    if (ReturnedBytes)
     {
-        switch (Status)
-        {
-        case STATUS_CANT_WAIT: 
-            *Errno = WSAEWOULDBLOCK;
-            break;
-        case STATUS_TIMEOUT:
-            *Errno = WSAETIMEDOUT;
-            break;
-        case STATUS_SUCCESS: 
-            /* Return Number of bytes Read */
-            if( ReturnedBytes ) 
-                *ReturnedBytes = Received;
-            break;
-        case STATUS_END_OF_FILE:
-            *Errno = WSAESHUTDOWN;
-            break;
-        case STATUS_PENDING: 
-            *Errno = WSA_IO_PENDING;
-            break;
-        case STATUS_BUFFER_OVERFLOW:
-            DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
-            *Errno = WSAEMSGSIZE;
-            break;
-        case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
-        case STATUS_INSUFFICIENT_RESOURCES:
-            DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
-            *Errno = WSA_NOT_ENOUGH_MEMORY;
-            break;
-        case STATUS_INVALID_CONNECTION:
-            DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
-            *Errno = WSAEAFNOSUPPORT;
-            break;
-        case STATUS_REMOTE_NOT_LISTENING:
-            DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
-            *Errno = WSAECONNRESET;
-            break;
-        case STATUS_FILE_CLOSED:
-            DbgPrint("MSAFD: STATUS_FILE_CLOSED\n");
-            *Errno = WSAENOTSOCK;
-            break;
-        case STATUS_INVALID_PARAMETER:
-            DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
-            *Errno = WSAEINVAL;
-            break;
-        default:
-            DbgPrint("MSAFD: Error %x is unknown\n", Status);
-            *Errno = WSAEINVAL;
-            break;
-        }
+        if (!*Errno)
+            *ReturnedBytes = Received;
+        else
+            *ReturnedBytes = 0;
     }
 
-    /* Success */
-    return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
+    return *Errno ? SOCKET_ERROR : 0;
 }
 
 /*
@@ -375,7 +422,7 @@ WSPCloseSocket(IN SOCKET Handle,
                OUT LPINT lpErrno)
 {
     IO_STATUS_BLOCK IoStatusBlock;
-    PSOCKET_INFORMATION Socket = NULL;
+    PSOCKET_INFORMATION Socket = NULL, CurrentSocket;
     NTSTATUS Status;
     HANDLE SockEvent;
     AFD_DISCONNECT_INFO DisconnectInfo;
@@ -393,10 +440,33 @@ WSPCloseSocket(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
+
+    if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
+    {
+        Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
+                                               Socket->Handle,
+                                               Socket->TdiAddressHandle,
+                                               Socket->TdiConnectionHandle,
+                                               WSH_NOTIFY_CLOSE);
+
+        if (Status)
+        {
+            if (lpErrno) *lpErrno = Status;
+            NtClose(SockEvent);
+            return SOCKET_ERROR;
+        }
+    }
 
     /* If a Close is already in Process, give up */
     if (Socket->SharedData.State == SocketClosed)
     {
+        NtClose(SockEvent);
         *lpErrno = WSAENOTSOCK;
         return SOCKET_ERROR;
     }
@@ -442,6 +512,7 @@ WSPCloseSocket(IN SOCKET Handle,
              */
             if (Socket->SharedData.NonBlocking)
             {
+                NtClose(SockEvent);
                 Socket->SharedData.State = OldState;
                 *lpErrno = WSAEWOULDBLOCK;
                 return SOCKET_ERROR;
@@ -483,12 +554,11 @@ WSPCloseSocket(IN SOCKET Handle,
             if (Status == STATUS_PENDING)
             {
                 WaitForSingleObject(SockEvent, INFINITE);
+                Status = IoStatusBlock.Status;
             }
         }
     }
 
-    /* FIXME: We should notify the Helper DLL of WSH_NOTIFY_CLOSE */
-
     /* Cleanup Time! */
     Socket->HelperContext = NULL;
     Socket->SharedData.AsyncDisabledEvents = -1;
@@ -497,10 +567,32 @@ WSPCloseSocket(IN SOCKET Handle,
     NtClose(Socket->TdiConnectionHandle);
     Socket->TdiConnectionHandle = NULL;
 
+    if (SocketListHead == Socket)
+    {
+        SocketListHead = SocketListHead->NextSocket;
+    }
+    else
+    {
+        CurrentSocket = SocketListHead;
+        while (CurrentSocket->NextSocket)
+        {
+            if (CurrentSocket->NextSocket == Socket)
+            {
+                CurrentSocket->NextSocket = CurrentSocket->NextSocket->NextSocket;
+                break;
+            }
+
+            CurrentSocket = CurrentSocket->NextSocket;
+        }
+    }
+
+    HeapFree(GlobalHeap, 0, Socket);
+
     /* Close the handle */
     NtClose((HANDLE)Handle);
+    NtClose(SockEvent);
 
-    return NO_ERROR;
+    return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
 }
 
 
@@ -525,24 +617,36 @@ WSPBind(SOCKET Handle,
     PAFD_BIND_DATA          BindData;
     PSOCKET_INFORMATION     Socket = NULL;
     NTSTATUS                Status;
-    UCHAR                   BindBuffer[0x1A];
     SOCKADDR_INFO           SocketInfo;
     HANDLE                  SockEvent;
 
+    /* See below */
+    BindData = HeapAlloc(GlobalHeap, 0, 0xA + SocketAddressLength);
+    if (!BindData)
+    {
+        return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+    }
+
     Status = NtCreateEvent(&SockEvent,
                            GENERIC_READ | GENERIC_WRITE,
                            NULL,
                            1,
                            FALSE);
 
-    if( !NT_SUCCESS(Status) )
-        return -1;
+    if (!NT_SUCCESS(Status))
+    {
+        HeapFree(GlobalHeap, 0, BindData);
+        return SOCKET_ERROR;
+    }
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
-
-    /* Dynamic Structure...ugh */
-    BindData = (PAFD_BIND_DATA)BindBuffer;
+    if (!Socket)
+    {
+       HeapFree(GlobalHeap, 0, BindData);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set up Address in TDI Format */
     BindData->Address.TAAddressCount = 1;
@@ -594,11 +698,30 @@ WSPBind(SOCKET Handle,
         Status = IOSB.Status;
     }
 
+    NtClose( SockEvent );
+    HeapFree(GlobalHeap, 0, BindData);
+
+    if (Status != STATUS_SUCCESS)
+        return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
+
     /* Set up Socket Data */
     Socket->SharedData.State = SocketBound;
     Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
 
-    NtClose( SockEvent );
+    if (Socket->HelperEvents & WSH_NOTIFY_BIND)
+    {
+        Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
+                                               Socket->Handle,
+                                               Socket->TdiAddressHandle,
+                                               Socket->TdiConnectionHandle,
+                                               WSH_NOTIFY_BIND);
+
+        if (Status)
+        {
+            if (lpErrno) *lpErrno = Status;
+            return SOCKET_ERROR;
+        }
+    }
 
     return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
 }
@@ -615,6 +738,17 @@ WSPListen(SOCKET Handle,
     HANDLE                  SockEvent;
     NTSTATUS                Status;
 
+    /* Get the Socket Structure associate to this Socket*/
+    Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
+
+    if (Socket->SharedData.Listening)
+        return 0;
+
     Status = NtCreateEvent(&SockEvent,
                            GENERIC_READ | GENERIC_WRITE,
                            NULL,
@@ -624,9 +758,6 @@ WSPListen(SOCKET Handle,
     if( !NT_SUCCESS(Status) )
         return -1;
 
-    /* Get the Socket Structure associate to this Socket*/
-    Socket = GetSocketStructure(Handle);
-
     /* Set Up Listen Structure */
     ListenData.UseSAN = FALSE;
     ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
@@ -649,12 +780,30 @@ WSPListen(SOCKET Handle,
     {
         WaitForSingleObject(SockEvent, INFINITE);
         Status = IOSB.Status;
-    }         
+    }
+
+    NtClose( SockEvent );
+
+    if (Status != STATUS_SUCCESS)
+       return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
 
     /* Set to Listening */
     Socket->SharedData.Listening = TRUE;
 
-    NtClose( SockEvent );
+    if (Socket->HelperEvents & WSH_NOTIFY_LISTEN)
+    {
+        Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
+                                               Socket->Handle,
+                                               Socket->TdiAddressHandle,
+                                               Socket->TdiConnectionHandle,
+                                               WSH_NOTIFY_LISTEN);
+
+        if (Status)
+        {
+           if (lpErrno) *lpErrno = Status;
+           return SOCKET_ERROR;
+        }
+    }
 
     return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
 }
@@ -666,7 +815,7 @@ WSPSelect(int nfds,
           fd_set *readfds,
           fd_set *writefds,
           fd_set *exceptfds,
-          struct timeval *timeout,
+          const LPTIMEVAL timeout,
           LPINT lpErrno)
 {
     IO_STATUS_BLOCK     IOSB;
@@ -687,10 +836,15 @@ WSPSelect(int nfds,
                   ( writefds ? writefds->fd_count : 0 ) +
                   ( exceptfds ? exceptfds->fd_count : 0 );
 
-    if( HandleCount < 0 || nfds != 0 )
-        HandleCount = nfds * 3;
+    if ( HandleCount == 0 )
+    {
+        AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
+                     HandleCount));
+        if (lpErrno) *lpErrno = WSAEINVAL;
+        return SOCKET_ERROR;
+    }
 
-    PollBufferSize = sizeof(*PollInfo) + (HandleCount * sizeof(AFD_HANDLE));
+    PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
 
     AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n", 
                  HandleCount, PollBufferSize));
@@ -755,6 +909,7 @@ WSPSelect(int nfds,
             PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
                                           AFD_EVENT_DISCONNECT |
                                           AFD_EVENT_ABORT |
+                                          AFD_EVENT_CLOSE |
                                           AFD_EVENT_ACCEPT;
         }
     }
@@ -796,6 +951,7 @@ WSPSelect(int nfds,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     /* Clear the Structures */
@@ -892,7 +1048,7 @@ WSPAccept(SOCKET Handle,
           struct sockaddr *SocketAddress,
           int *SocketAddressLength,
           LPCONDITIONPROC lpfnCondition,
-          DWORD_PTR dwCallbackData,
+          DWORD dwCallbackData,
           LPINT lpErrno)
 {
     IO_STATUS_BLOCK             IOSB;
@@ -933,6 +1089,12 @@ WSPAccept(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return INVALID_SOCKET;
+    }
 
     /* If this is non-blocking, make sure there's something for us to accept */
     FD_ZERO(&ReadSet);
@@ -940,10 +1102,18 @@ WSPAccept(SOCKET Handle,
     Timeout.tv_sec=0;
     Timeout.tv_usec=0;
 
-    WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
+    if (WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, lpErrno) == SOCKET_ERROR)
+    {
+        NtClose(SockEvent);
+        return INVALID_SOCKET;
+    }
 
     if (ReadSet.fd_array[0] != Socket->Handle)
-        return 0;
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAEWOULDBLOCK;
+        return INVALID_SOCKET;
+    }
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
@@ -1012,6 +1182,11 @@ WSPAccept(SOCKET Handle,
             {
                 /* Allocate needed space */
                 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
+                if (!PendingData)
+                {
+                    MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
+                    return INVALID_SOCKET;
+                }
 
                 /* We want the data now */
                 PendingAcceptData.ReturnSize = FALSE;
@@ -1053,6 +1228,13 @@ WSPAccept(SOCKET Handle,
         CalleeID.buf = (PVOID)Socket->LocalAddress;
         CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
 
+        RemoteAddress = HeapAlloc(GlobalHeap, 0, sizeof(*RemoteAddress));
+        if (!RemoteAddress)
+        {
+            MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+            return INVALID_SOCKET;
+        }
+
         /* Set up Address in SOCKADDR Format */
         RtlCopyMemory (RemoteAddress, 
                        &ListenReceiveData->Address.Address[0].AddressType, 
@@ -1071,6 +1253,10 @@ WSPAccept(SOCKET Handle,
         {
             /* Allocate Buffer for Callee Data */
             CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
+            if (!CalleeDataBuffer) {
+                MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
+                return INVALID_SOCKET;
+            }
             CalleeData.buf = CalleeDataBuffer;
             CalleeData.len = 4096;
         } 
@@ -1164,10 +1350,12 @@ WSPAccept(SOCKET Handle,
                               &ProtocolInfo,
                               GroupID,
                               Socket->SharedData.CreateFlags,
-                              NULL);
+                              lpErrno);
+    if (AcceptSocket == INVALID_SOCKET)
+        return INVALID_SOCKET;
 
     /* Set up the Accept Structure */
-    AcceptData.ListenHandle = AcceptSocket;
+    AcceptData.ListenHandle = (HANDLE)AcceptSocket;
     AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
 
     /* Send IOCTL to Accept */
@@ -1191,6 +1379,7 @@ WSPAccept(SOCKET Handle,
 
     if (!NT_SUCCESS(Status))
     {
+        NtClose(SockEvent);
         WSPCloseSocket( AcceptSocket, lpErrno );
         MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
         return INVALID_SOCKET;
@@ -1213,6 +1402,21 @@ WSPAccept(SOCKET Handle,
 
     AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
 
+    if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_ACCEPT))
+    {
+        Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
+                                               Socket->Handle,
+                                               Socket->TdiAddressHandle,
+                                               Socket->TdiConnectionHandle,
+                                               WSH_NOTIFY_ACCEPT);
+
+        if (Status)
+        {
+            if (lpErrno) *lpErrno = Status;
+            return INVALID_SOCKET;
+        }
+    }
+
     *lpErrno = 0;
 
     /* Return Socket */
@@ -1254,6 +1458,12 @@ WSPConnect(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Bind us First */
     if (Socket->SharedData.State == SocketOpen)
@@ -1261,11 +1471,17 @@ WSPConnect(SOCKET Handle,
         /* Get the Wildcard Address */
         BindAddressLength = Socket->HelperData->MaxWSAddressLength;
         BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
+        if (!BindAddress)
+        {
+            MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
+            return INVALID_SOCKET;
+        }
         Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext, 
                                                     BindAddress, 
                                                     &BindAddressLength);
         /* Bind it */
-        WSPBind(Handle, BindAddress, BindAddressLength, NULL);
+        if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR)
+            return INVALID_SOCKET;
     }
 
     /* Set the Connect Data */
@@ -1288,6 +1504,9 @@ WSPConnect(SOCKET Handle,
             WaitForSingleObject(SockEvent, INFINITE);
             Status = IOSB.Status;
         }
+
+        if (Status != STATUS_SUCCESS)
+            goto notify;
     }
 
     /* Dynamic Structure...ugh */
@@ -1333,6 +1552,9 @@ WSPConnect(SOCKET Handle,
             WaitForSingleObject(SockEvent, INFINITE);
             Status = IOSB.Status;
         }
+
+        if (Status != STATUS_SUCCESS)
+            goto notify;
     }
 
     /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
@@ -1364,6 +1586,11 @@ WSPConnect(SOCKET Handle,
         Status = IOSB.Status;
     }
 
+    if (Status != STATUS_SUCCESS)
+        goto notify;
+
+    Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
+
     /* Get any pending connect data */
     if (lpCalleeData != NULL)
     {
@@ -1385,16 +1612,46 @@ WSPConnect(SOCKET Handle,
         }
     }
 
+    AFD_DbgPrint(MID_TRACE,("Ending\n"));
+
+notify:
     /* Re-enable Async Event */
     SockReenableAsyncSelectEvent(Socket, FD_WRITE);
 
     /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
     SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
 
-    AFD_DbgPrint(MID_TRACE,("Ending\n"));
-
     NtClose( SockEvent );
 
+    if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT))
+    {
+        Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
+                                               Socket->Handle,
+                                               Socket->TdiAddressHandle,
+                                               Socket->TdiConnectionHandle,
+                                               WSH_NOTIFY_CONNECT);
+
+        if (Status)
+        {
+            if (lpErrno) *lpErrno = Status;
+            return SOCKET_ERROR;
+        }
+    }
+    else if (Status != STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT_ERROR))
+    {
+        Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
+                                               Socket->Handle,
+                                               Socket->TdiAddressHandle,
+                                               Socket->TdiConnectionHandle,
+                                               WSH_NOTIFY_CONNECT_ERROR);
+
+        if (Status)
+        {
+            if (lpErrno) *lpErrno = Status;
+            return SOCKET_ERROR;
+        }
+    }
+
     return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
 }
 int
@@ -1423,6 +1680,12 @@ WSPShutdown(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set AFD Disconnect Type */
     switch (HowTo)
@@ -1480,7 +1743,7 @@ WSPGetSockName(IN SOCKET Handle,
 {
     IO_STATUS_BLOCK         IOSB;
     ULONG                   TdiAddressSize;
-    PTDI_ADDRESS_INFO       TdiAddress;
+       PTDI_ADDRESS_INFO       TdiAddress;
     PTRANSPORT_ADDRESS      SocketAddress;
     PSOCKET_INFORMATION     Socket = NULL;
     NTSTATUS                Status;
@@ -1497,11 +1760,16 @@ WSPGetSockName(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
-    TdiAddressSize = FIELD_OFFSET(TDI_ADDRESS_INFO, Address.Address[0].Address) +
-                     Socket->SharedData.SizeOfLocalAddress;
-
+    TdiAddressSize = 
+               sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
     TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
 
     if ( TdiAddress == NULL )
@@ -1570,7 +1838,6 @@ WSPGetPeerName(IN SOCKET s,
 {
     IO_STATUS_BLOCK         IOSB;
     ULONG                   TdiAddressSize;
-    PTDI_ADDRESS_INFO       TdiAddress;
     PTRANSPORT_ADDRESS      SocketAddress;
     PSOCKET_INFORMATION     Socket = NULL;
     NTSTATUS                Status;
@@ -1587,21 +1854,24 @@ WSPGetPeerName(IN SOCKET s,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(s);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
-    TdiAddressSize = FIELD_OFFSET(TDI_ADDRESS_INFO, Address.Address[0].Address) +
-                     Socket->SharedData.SizeOfLocalAddress;
-    TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
+    TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
+    SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
 
-    if ( TdiAddress == NULL )
+    if ( SocketAddress == NULL )
     {
         NtClose( SockEvent );
         *lpErrno = WSAENOBUFS;
         return SOCKET_ERROR;
     }
 
-    SocketAddress = &TdiAddress->Address;
-
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
                                    SockEvent,
@@ -1611,7 +1881,7 @@ WSPGetPeerName(IN SOCKET s,
                                    IOCTL_AFD_GET_PEER_NAME,
                                    NULL,
                                    0,
-                                   TdiAddress,
+                                   SocketAddress,
                                    TdiAddressSize);
 
     /* Wait for return */
@@ -1635,12 +1905,12 @@ WSPGetPeerName(IN SOCKET s,
             AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
                           *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
                           ((struct sockaddr_in *)Name)->sin_port));
-            HeapFree(GlobalHeap, 0, TdiAddress);
+            HeapFree(GlobalHeap, 0, SocketAddress);
             return 0;
         }
         else
         {
-            HeapFree(GlobalHeap, 0, TdiAddress);
+            HeapFree(GlobalHeap, 0, SocketAddress);
             *lpErrno = WSAEFAULT;
             return SOCKET_ERROR;
         }
@@ -1667,16 +1937,28 @@ WSPIoctl(IN  SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     switch( dwIoControlCode )
     {
         case FIONBIO:
-            if( cbInBuffer < sizeof(INT) )
+            if( cbInBuffer < sizeof(INT) || IS_INTRESOURCE(lpvInBuffer) )
+            {
+                *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
-            Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
-            AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n", Handle, Socket->SharedData.NonBlocking));
-            return 0;
+            }
+            Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
+            return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
         case FIONREAD:
+            if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
+            {
+                *lpErrno = WSAEFAULT;
+                return SOCKET_ERROR;
+            }
             return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
         default:
             *lpErrno = WSAEINVAL;
@@ -1779,11 +2061,52 @@ WSPGetSockOpt(IN SOCKET Handle,
 
         case IPPROTO_TCP: /* FIXME */
         default:
-            *lpErrno = WSAEINVAL;
-            return SOCKET_ERROR;
+            *lpErrno = Socket->HelperData->WSHGetSocketInformation(Socket->HelperContext,
+                                                                   Handle,
+                                                                   Socket->TdiAddressHandle,
+                                                                   Socket->TdiConnectionHandle,
+                                                                   Level,
+                                                                   OptionName,
+                                                                   OptionValue,
+                                                                   (LPINT)OptionLength);
+            return (*lpErrno == 0) ? 0 : SOCKET_ERROR;
     }
 }
 
+INT
+WSPAPI
+WSPSetSockOpt(
+    IN  SOCKET s,
+    IN  INT level,
+    IN  INT optname,
+    IN  CONST CHAR FAR* optval,
+    IN  INT optlen,
+    OUT LPINT lpErrno)
+{
+    PSOCKET_INFORMATION Socket;
+
+    /* Get the Socket Structure associate to this Socket*/
+    Socket = GetSocketStructure(s);
+    if (Socket == NULL)
+    {
+        *lpErrno = WSAENOTSOCK;
+        return SOCKET_ERROR;
+    }
+
+
+    /* FIXME: We should handle some cases here */
+
+
+    *lpErrno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
+                                                           s,
+                                                           Socket->TdiAddressHandle,
+                                                           Socket->TdiConnectionHandle,
+                                                           level,
+                                                           optname,
+                                                           (PCHAR)optval,
+                                                           optlen);
+    return (*lpErrno == 0) ? 0 : SOCKET_ERROR;
+}
 
 /*
  * FUNCTION: Initialize service provider for a client
@@ -1913,10 +2236,17 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
+    if (Status != STATUS_SUCCESS)
+        return -1;
+
     /* Return Information */
-    *Ulong = InfoData.Information.Ulong;
+    if (Ulong != NULL)
+    {
+        *Ulong = InfoData.Information.Ulong;
+    }
     if (LargeInteger != NULL)
     {
         *LargeInteger = InfoData.Information.LargeInteger;
@@ -1953,7 +2283,10 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     InfoData.InformationClass = AfdInformationClass;
 
     /* Set Information */
-    InfoData.Information.Ulong = *Ulong;
+    if (Ulong != NULL)
+    {
+        InfoData.Information.Ulong = *Ulong;
+    }
     if (LargeInteger != NULL)
     {
         InfoData.Information.LargeInteger = *LargeInteger;
@@ -1968,7 +2301,7 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
                                    NULL,
                                    NULL,
                                    &IOSB,
-                                   IOCTL_AFD_GET_INFO,
+                                   IOCTL_AFD_SET_INFO,
                                    &InfoData,
                                    sizeof(InfoData),
                                    NULL,
@@ -1978,27 +2311,37 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     NtClose( SockEvent );
 
-    return 0;
+    return Status == STATUS_SUCCESS ? 0 : -1;
 
 }
 
 PSOCKET_INFORMATION
 GetSocketStructure(SOCKET Handle)
 {
-    ULONG i;
+    PSOCKET_INFORMATION CurrentSocket;
+
+    if (!SocketListHead)
+        return NULL;
 
-    for (i=0; i<SocketCount; i++) 
+    /* This is a special case */
+    if (SocketListHead->Handle == Handle)
+        return SocketListHead;
+
+    CurrentSocket = SocketListHead;
+    while (CurrentSocket->NextSocket)
     {
-        if (Sockets[i]->Handle == Handle)
-        {
-            return Sockets[i];
-        }
+        if (CurrentSocket->Handle == Handle)
+            return CurrentSocket;
+
+        CurrentSocket = CurrentSocket->NextSocket;
     }
-    return 0;
+
+    return NULL;
 }
 
 int CreateContext(PSOCKET_INFORMATION Socket)
@@ -2043,11 +2386,12 @@ int CreateContext(PSOCKET_INFORMATION Socket)
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     NtClose( SockEvent );
 
-    return 0;
+    return Status == STATUS_SUCCESS ? 0 : -1;
 }
 
 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
@@ -2272,6 +2616,7 @@ VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBl
 
                 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
             case AFD_EVENT_CONNECT:
+            case AFD_EVENT_CONNECT_FAIL:
                 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
                     0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
                 {
@@ -2375,7 +2720,7 @@ VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
 
     if (lNetworkEvents & FD_CLOSE)
     {
-        AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
+        AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT | AFD_EVENT_CLOSE;
     }
 
     if (lNetworkEvents & FD_QOS)
@@ -2468,6 +2813,7 @@ SockReenableAsyncSelectEvent (IN PSOCKET_INFORMATION Socket,
 
     /* Wait on new events */
     AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
+    if (!AsyncData) return;
 
     /* Create the Asynch Thread if Needed */  
     SockCreateOrReferenceAsyncThread();
@@ -2513,9 +2859,6 @@ DllMain(HANDLE hInstDll,
         /* Heap to use when allocating */
         GlobalHeap = GetProcessHeap();
 
-        /* Allocate Heap for 1024 Sockets, can be expanded later */
-        Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
-
         AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
 
         break;