- Update to r53061
[reactos.git] / dll / win32 / msafd / misc / dllmain.c
index c832a85..87e5dcd 100644 (file)
@@ -22,12 +22,12 @@ DWORD DebugTraceLevel = 0;
 HANDLE GlobalHeap;
 WSPUPCALLTABLE Upcalls;
 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
-ULONG SocketCount = 0;
-PSOCKET_INFORMATION *Sockets = NULL;
+PSOCKET_INFORMATION SocketListHead = NULL;
+CRITICAL_SECTION SocketListLock;
 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
 ULONG SockAsyncThreadRefCount;
 HANDLE SockAsyncHelperAfdHandle;
-HANDLE SockAsyncCompletionPort;
+HANDLE SockAsyncCompletionPort = NULL;
 BOOLEAN SockAsyncSelectCalled;
 
 
@@ -61,7 +61,7 @@ WSPSocket(int AddressFamily,
     ULONG                       SizeOfEA;
     PAFD_CREATE_PACKET          AfdPacket;
     HANDLE                      Sock;
-    PSOCKET_INFORMATION         Socket = NULL, PrevSocket = NULL;
+    PSOCKET_INFORMATION         Socket = NULL;
     PFILE_FULL_EA_INFORMATION   EABuffer = NULL;
     PHELPER_DATA                HelperData;
     PVOID                       HelperDLLContext;
@@ -261,21 +261,10 @@ WSPSocket(int AddressFamily,
     /* Save Handle */
     Socket->Handle = (SOCKET)Sock;
 
-    /* XXX See if there's a structure we can reuse -- We need to do this
-    * more properly. */
-    PrevSocket = GetSocketStructure( (SOCKET)Sock );
-
-    if( PrevSocket )
-    {
-        RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
-        RtlFreeHeap( GlobalHeap, 0, Socket );
-        Socket = PrevSocket;
-    }
-
     /* Save Group Info */
     if (g != 0)
     {
-        GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
+        GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, NULL, NULL, &GroupData);
         Socket->SharedData.GroupID = GroupData.u.LowPart;
         Socket->SharedData.GroupType = GroupData.u.HighPart;
     }
@@ -283,17 +272,21 @@ WSPSocket(int AddressFamily,
     /* Get Window Sizes and Save them */
     GetSocketInformation (Socket,
                           AFD_INFO_SEND_WINDOW_SIZE,
+                          NULL,
                           &Socket->SharedData.SizeOfSendBuffer,
                           NULL);
 
     GetSocketInformation (Socket,
                           AFD_INFO_RECEIVE_WINDOW_SIZE,
+                          NULL,
                           &Socket->SharedData.SizeOfRecvBuffer,
                           NULL);
 
     /* Save in Process Sockets List */
-    Sockets[SocketCount] = Socket;
-    SocketCount ++;
+    EnterCriticalSection(&SocketListLock);
+    Socket->NextSocket = SocketListHead;
+    SocketListHead = Socket;
+    LeaveCriticalSection(&SocketListLock);
 
     /* Create the Socket Context */
     CreateContext(Socket);
@@ -318,79 +311,77 @@ error:
     return INVALID_SOCKET;
 }
 
-
-DWORD MsafdReturnWithErrno(NTSTATUS Status,
-                           LPINT Errno,
-                           DWORD Received,
-                           LPDWORD ReturnedBytes)
+INT
+TranslateNtStatusError(NTSTATUS Status)
 {
-    if( ReturnedBytes )
-        *ReturnedBytes = 0;
-    if( Errno )
+    switch (Status)
     {
-        switch (Status)
-        {
-        case STATUS_CANT_WAIT: 
-            *Errno = WSAEWOULDBLOCK;
-            break;
-        case STATUS_TIMEOUT:
-            *Errno = WSAETIMEDOUT;
-            break;
-        case STATUS_SUCCESS: 
-            /* Return Number of bytes Read */
-            if( ReturnedBytes ) 
-                *ReturnedBytes = Received;
-            break;
-        case STATUS_FILE_CLOSED:
-        case STATUS_END_OF_FILE:
-            *Errno = WSAESHUTDOWN;
-            break;
-        case STATUS_PENDING: 
-            *Errno = WSA_IO_PENDING;
-            break;
-        case STATUS_BUFFER_TOO_SMALL:
-        case STATUS_BUFFER_OVERFLOW:
-            DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
-            *Errno = WSAEMSGSIZE;
-            break;
-        case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
-        case STATUS_INSUFFICIENT_RESOURCES:
-            DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
-            *Errno = WSAENOBUFS;
-            break;
-        case STATUS_INVALID_CONNECTION:
-            DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
-            *Errno = WSAEAFNOSUPPORT;
-            break;
-        case STATUS_INVALID_ADDRESS:
-            DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
-            *Errno = WSAEADDRNOTAVAIL;
-            break;
-        case STATUS_REMOTE_NOT_LISTENING:
-            DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
-            *Errno = WSAECONNREFUSED;
-            break;
-        case STATUS_NETWORK_UNREACHABLE:
-            DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
-            *Errno = WSAENETUNREACH;
-            break;
-        case STATUS_INVALID_PARAMETER:
-            DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
-            *Errno = WSAEINVAL;
-            break;
-        case STATUS_CANCELLED:
-            DbgPrint("MSAFD: STATUS_CANCELLED\n");
-            *Errno = WSA_OPERATION_ABORTED;
-            break;
-        default:
-            DbgPrint("MSAFD: Error %x is unknown\n", Status);
-            *Errno = WSAEINVAL;
-            break;
-        }
-    }
+       case STATUS_CANT_WAIT:
+          return WSAEWOULDBLOCK;
+
+       case STATUS_TIMEOUT:
+          return WSAETIMEDOUT;
+
+       case STATUS_SUCCESS:
+          return NO_ERROR;
+
+       case STATUS_FILE_CLOSED:
+       case STATUS_END_OF_FILE:
+          return WSAESHUTDOWN;
+
+       case STATUS_PENDING:
+          return WSA_IO_PENDING;
+
+       case STATUS_BUFFER_TOO_SMALL:
+       case STATUS_BUFFER_OVERFLOW:
+          DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
+          return WSAEMSGSIZE;
+
+       case STATUS_NO_MEMORY:
+       case STATUS_INSUFFICIENT_RESOURCES:
+          DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
+          return WSAENOBUFS;
+
+       case STATUS_INVALID_CONNECTION:
+          DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
+          return WSAEAFNOSUPPORT;
+
+       case STATUS_INVALID_ADDRESS:
+          DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
+          return WSAEADDRNOTAVAIL;
+
+       case STATUS_REMOTE_NOT_LISTENING:
+          DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
+          return WSAECONNREFUSED;
 
-    /* Success */
-    return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
+       case STATUS_NETWORK_UNREACHABLE:
+          DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
+          return WSAENETUNREACH;
+
+       case STATUS_INVALID_PARAMETER:
+          DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
+          return WSAEINVAL;
+
+       case STATUS_CANCELLED:
+          DbgPrint("MSAFD: STATUS_CANCELLED\n");
+          return WSA_OPERATION_ABORTED;
+
+       case STATUS_ADDRESS_ALREADY_EXISTS:
+          DbgPrint("MSAFD: STATUS_ADDRESS_ALREADY_EXISTS\n");
+          return WSAEADDRINUSE;
+
+       case STATUS_LOCAL_DISCONNECT:
+          DbgPrint("MSAFD: STATUS_LOCAL_DISCONNECT\n");
+          return WSAECONNABORTED;
+
+       case STATUS_REMOTE_DISCONNECT:
+          DbgPrint("MSAFD: STATUS_REMOTE_DISCONNECT\n");
+          return WSAECONNRESET;
+
+       default:
+          DbgPrint("MSAFD: Unhandled NTSTATUS value: 0x%x\n", Status);
+          return WSAENETDOWN;
+    }
 }
 
 /*
@@ -407,11 +398,12 @@ WSPCloseSocket(IN SOCKET Handle,
                OUT LPINT lpErrno)
 {
     IO_STATUS_BLOCK IoStatusBlock;
-    PSOCKET_INFORMATION Socket = NULL;
+    PSOCKET_INFORMATION Socket = NULL, CurrentSocket;
     NTSTATUS Status;
     HANDLE SockEvent;
     AFD_DISCONNECT_INFO DisconnectInfo;
     SOCKET_STATE OldState;
+    LONG LingerWait = -1;
 
     /* Create the Wait Event */
     Status = NtCreateEvent(&SockEvent,
@@ -425,6 +417,12 @@ WSPCloseSocket(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
     {
@@ -458,7 +456,6 @@ WSPCloseSocket(IN SOCKET Handle,
     /* FIXME: Should we do this on Datagram Sockets too? */
     if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
     {
-        ULONG LingerWait;
         ULONG SendsInProgress;
         ULONG SleepWait;
 
@@ -472,6 +469,7 @@ WSPCloseSocket(IN SOCKET Handle,
             /* Find out how many Sends are in Progress */
             if (GetSocketInformation(Socket,
                                      AFD_INFO_SENDS_IN_PROGRESS,
+                                     NULL,
                                      &SendsInProgress,
                                      NULL))
             {
@@ -482,7 +480,11 @@ WSPCloseSocket(IN SOCKET Handle,
 
             /* Bail out if no more sends are pending */
             if (!SendsInProgress)
+            {
+                LingerWait = -1;
                 break;
+            }
+
             /* 
              * We have to execute a sleep, so it's kind of like
              * a block. If the socket is Nonblock, we cannot
@@ -507,33 +509,36 @@ WSPCloseSocket(IN SOCKET Handle,
             Sleep(SleepWait);
             LingerWait -= SleepWait;
         }
+    }
 
-        /*
-        * We have reached the timeout or sends are over.
-        * Disconnect if the timeout has been reached. 
-        */
+    if (OldState == SocketConnected)
+    {
         if (LingerWait <= 0)
         {
             DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
-            DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
-
-            /* Send IOCTL */
-            Status = NtDeviceIoControlFile((HANDLE)Handle,
-                                           SockEvent,
-                                           NULL,
-                                           NULL,
-                                           &IoStatusBlock,
-                                           IOCTL_AFD_DISCONNECT,
-                                           &DisconnectInfo,
-                                           sizeof(DisconnectInfo),
-                                           NULL,
-                                           0);
-
-            /* Wait for return */
-            if (Status == STATUS_PENDING)
+            DisconnectInfo.DisconnectType = LingerWait < 0 ? AFD_DISCONNECT_SEND : AFD_DISCONNECT_ABORT;
+            
+            if (((DisconnectInfo.DisconnectType & AFD_DISCONNECT_SEND) && (!Socket->SharedData.SendShutdown)) ||
+                ((DisconnectInfo.DisconnectType & AFD_DISCONNECT_ABORT) && (!Socket->SharedData.ReceiveShutdown)))
             {
-                WaitForSingleObject(SockEvent, INFINITE);
-                Status = IoStatusBlock.Status;
+                /* Send IOCTL */
+                Status = NtDeviceIoControlFile((HANDLE)Handle,
+                                               SockEvent,
+                                               NULL,
+                                               NULL,
+                                               &IoStatusBlock,
+                                               IOCTL_AFD_DISCONNECT,
+                                               &DisconnectInfo,
+                                               sizeof(DisconnectInfo),
+                                               NULL,
+                                               0);
+                
+                /* Wait for return */
+                if (Status == STATUS_PENDING)
+                {
+                    WaitForSingleObject(SockEvent, INFINITE);
+                    Status = IoStatusBlock.Status;
+                }
             }
         }
     }
@@ -546,11 +551,33 @@ WSPCloseSocket(IN SOCKET Handle,
     NtClose(Socket->TdiConnectionHandle);
     Socket->TdiConnectionHandle = NULL;
 
+    EnterCriticalSection(&SocketListLock);
+    if (SocketListHead == Socket)
+    {
+        SocketListHead = SocketListHead->NextSocket;
+    }
+    else
+    {
+        CurrentSocket = SocketListHead;
+        while (CurrentSocket->NextSocket)
+        {
+            if (CurrentSocket->NextSocket == Socket)
+            {
+                CurrentSocket->NextSocket = CurrentSocket->NextSocket->NextSocket;
+                break;
+            }
+
+            CurrentSocket = CurrentSocket->NextSocket;
+        }
+    }
+    LeaveCriticalSection(&SocketListLock);
+
     /* Close the handle */
     NtClose((HANDLE)Handle);
     NtClose(SockEvent);
 
-    return NO_ERROR;
+    HeapFree(GlobalHeap, 0, Socket);
+    return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
 }
 
 
@@ -599,6 +626,12 @@ WSPBind(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       HeapFree(GlobalHeap, 0, BindData);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set up Address in TDI Format */
     BindData->Address.TAAddressCount = 1;
@@ -650,13 +683,17 @@ WSPBind(SOCKET Handle,
         Status = IOSB.Status;
     }
 
+    NtClose( SockEvent );
+    HeapFree(GlobalHeap, 0, BindData);
+
+    if (Status != STATUS_SUCCESS)
+        return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
+
     /* Set up Socket Data */
     Socket->SharedData.State = SocketBound;
     Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
 
-    NtClose( SockEvent );
-    HeapFree(GlobalHeap, 0, BindData);
-    if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_BIND))
+    if (Socket->HelperEvents & WSH_NOTIFY_BIND)
     {
         Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
                                                Socket->Handle,
@@ -688,6 +725,11 @@ WSPListen(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     if (Socket->SharedData.Listening)
         return 0;
@@ -723,14 +765,17 @@ WSPListen(SOCKET Handle,
     {
         WaitForSingleObject(SockEvent, INFINITE);
         Status = IOSB.Status;
-    }         
+    }
+
+    NtClose( SockEvent );
+
+    if (Status != STATUS_SUCCESS)
+       return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
 
     /* Set to Listening */
     Socket->SharedData.Listening = TRUE;
 
-    NtClose( SockEvent );
-
-    if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_LISTEN))
+    if (Socket->HelperEvents & WSH_NOTIFY_LISTEN)
     {
         Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
                                                Socket->Handle,
@@ -761,7 +806,8 @@ WSPSelect(IN int nfds,
     IO_STATUS_BLOCK     IOSB;
     PAFD_POLL_INFO      PollInfo;
     NTSTATUS            Status;
-    LONG                HandleCount, OutCount = 0;
+    ULONG               HandleCount;
+    LONG                OutCount = 0;
     ULONG               PollBufferSize;
     PVOID               PollBuffer;
     ULONG               i, j = 0, x;
@@ -778,7 +824,7 @@ WSPSelect(IN int nfds,
 
     if ( HandleCount == 0 )
     {
-        AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
+        AFD_DbgPrint(MAX_TRACE,("HandleCount: %u. Return SOCKET_ERROR\n",
                      HandleCount));
         if (lpErrno) *lpErrno = WSAEINVAL;
         return SOCKET_ERROR;
@@ -786,7 +832,7 @@ WSPSelect(IN int nfds,
 
     PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
 
-    AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n", 
+    AFD_DbgPrint(MID_TRACE,("HandleCount: %u BufferSize: %u\n", 
                  HandleCount, PollBufferSize));
 
     /* Convert Timeout to NT Format */
@@ -871,7 +917,7 @@ WSPSelect(IN int nfds,
     }
 
     PollInfo->HandleCount = j;
-    PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
+    PollBufferSize = FIELD_OFFSET(AFD_POLL_INFO, Handles) + PollInfo->HandleCount * sizeof(AFD_HANDLE);
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
@@ -891,6 +937,7 @@ WSPSelect(IN int nfds,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     /* Clear the Structures */
@@ -1008,6 +1055,7 @@ WSPAccept(SOCKET Handle,
     ULONG                       CallBack;
     WSAPROTOCOL_INFOW           ProtocolInfo;
     SOCKET                      AcceptSocket;
+    PSOCKET_INFORMATION         AcceptSocketInfo;
     UCHAR                       ReceiveBuffer[0x1A];
     HANDLE                      SockEvent;
 
@@ -1028,6 +1076,12 @@ WSPAccept(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return INVALID_SOCKET;
+    }
 
     /* If this is non-blocking, make sure there's something for us to accept */
     FD_ZERO(&ReadSet);
@@ -1035,12 +1089,17 @@ WSPAccept(SOCKET Handle,
     Timeout.tv_sec=0;
     Timeout.tv_usec=0;
 
-    WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
+    if (WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, lpErrno) == SOCKET_ERROR)
+    {
+        NtClose(SockEvent);
+        return INVALID_SOCKET;
+    }
 
     if (ReadSet.fd_array[0] != Socket->Handle)
     {
         NtClose(SockEvent);
-        return 0;
+        *lpErrno = WSAEWOULDBLOCK;
+        return INVALID_SOCKET;
     }
 
     /* Send IOCTL */
@@ -1278,7 +1337,9 @@ WSPAccept(SOCKET Handle,
                               &ProtocolInfo,
                               GroupID,
                               Socket->SharedData.CreateFlags,
-                              NULL);
+                              lpErrno);
+    if (AcceptSocket == INVALID_SOCKET)
+        return INVALID_SOCKET;
 
     /* Set up the Accept Structure */
     AcceptData.ListenHandle = (HANDLE)AcceptSocket;
@@ -1310,6 +1371,17 @@ WSPAccept(SOCKET Handle,
         MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
         return INVALID_SOCKET;
     }
+    
+    AcceptSocketInfo = GetSocketStructure(AcceptSocket);
+    if (!AcceptSocketInfo)
+    {
+        NtClose(SockEvent);
+        WSPCloseSocket( AcceptSocket, lpErrno );
+        MsafdReturnWithErrno( STATUS_INVALID_CONNECTION, lpErrno, 0, NULL );
+        return INVALID_SOCKET;
+    }
+    
+    AcceptSocketInfo->SharedData.State = SocketConnected;
 
     /* Return Address in SOCKADDR FORMAT */
     if( SocketAddress )
@@ -1384,6 +1456,12 @@ WSPConnect(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Bind us First */
     if (Socket->SharedData.State == SocketOpen)
@@ -1400,7 +1478,8 @@ WSPConnect(SOCKET Handle,
                                                     BindAddress, 
                                                     &BindAddressLength);
         /* Bind it */
-        WSPBind(Handle, BindAddress, BindAddressLength, NULL);
+        if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR)
+            return INVALID_SOCKET;
     }
 
     /* Set the Connect Data */
@@ -1423,6 +1502,9 @@ WSPConnect(SOCKET Handle,
             WaitForSingleObject(SockEvent, INFINITE);
             Status = IOSB.Status;
         }
+
+        if (Status != STATUS_SUCCESS)
+            goto notify;
     }
 
     /* Dynamic Structure...ugh */
@@ -1468,6 +1550,9 @@ WSPConnect(SOCKET Handle,
             WaitForSingleObject(SockEvent, INFINITE);
             Status = IOSB.Status;
         }
+
+        if (Status != STATUS_SUCCESS)
+            goto notify;
     }
 
     /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
@@ -1499,6 +1584,10 @@ WSPConnect(SOCKET Handle,
         Status = IOSB.Status;
     }
 
+    if (Status != STATUS_SUCCESS)
+        goto notify;
+
+    Socket->SharedData.State = SocketConnected;
     Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
 
     /* Get any pending connect data */
@@ -1522,14 +1611,15 @@ WSPConnect(SOCKET Handle,
         }
     }
 
+    AFD_DbgPrint(MID_TRACE,("Ending\n"));
+
+notify:
     /* Re-enable Async Event */
     SockReenableAsyncSelectEvent(Socket, FD_WRITE);
 
     /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
     SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
 
-    AFD_DbgPrint(MID_TRACE,("Ending\n"));
-
     NtClose( SockEvent );
 
     if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT))
@@ -1589,6 +1679,12 @@ WSPShutdown(SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
 
     /* Set AFD Disconnect Type */
     switch (HowTo)
@@ -1608,7 +1704,7 @@ WSPShutdown(SOCKET Handle,
             break;
     }
 
-    DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
+    DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1000000);
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)Handle,
@@ -1663,6 +1759,19 @@ WSPGetSockName(IN SOCKET Handle,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
+
+    if (!Name || !NameLength)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
     TdiAddressSize = 
@@ -1701,13 +1810,13 @@ WSPGetSockName(IN SOCKET Handle,
 
     if (NT_SUCCESS(Status))
     {
-        if (*NameLength >= SocketAddress->Address[0].AddressLength)
+        if (*NameLength >= Socket->SharedData.SizeOfLocalAddress)
         {
             Name->sa_family = SocketAddress->Address[0].AddressType;
             RtlCopyMemory (Name->sa_data,
                            SocketAddress->Address[0].Address, 
                            SocketAddress->Address[0].AddressLength);
-            *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+            *NameLength = Socket->SharedData.SizeOfLocalAddress;
             AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
                           *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
                           ((struct sockaddr_in *)Name)->sin_port));
@@ -1751,9 +1860,29 @@ WSPGetPeerName(IN SOCKET s,
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(s);
+    if (!Socket)
+    {
+       NtClose(SockEvent);
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
+
+    if (Socket->SharedData.State != SocketConnected)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAENOTCONN;
+        return SOCKET_ERROR;
+    }
+
+    if (!Name || !NameLength)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
+    }
 
     /* Allocate a buffer for the address */
-    TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
+    TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfRemoteAddress;
     SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
 
     if ( SocketAddress == NULL )
@@ -1786,13 +1915,13 @@ WSPGetPeerName(IN SOCKET s,
 
     if (NT_SUCCESS(Status))
     {
-        if (*NameLength >= SocketAddress->Address[0].AddressLength)
+        if (*NameLength >= Socket->SharedData.SizeOfRemoteAddress)
         {
             Name->sa_family = SocketAddress->Address[0].AddressType;
             RtlCopyMemory (Name->sa_data,
                            SocketAddress->Address[0].Address, 
                            SocketAddress->Address[0].AddressLength);
-            *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+            *NameLength = Socket->SharedData.SizeOfRemoteAddress;
             AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
                           *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
                           ((struct sockaddr_in *)Name)->sin_port));
@@ -1825,9 +1954,18 @@ WSPIoctl(IN  SOCKET Handle,
          OUT LPINT lpErrno)
 {
     PSOCKET_INFORMATION Socket = NULL;
+       BOOLEAN NeedsCompletion;
+    BOOLEAN NonBlocking;
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
+    if (!Socket)
+    {
+       *lpErrno = WSAENOTSOCK;
+       return SOCKET_ERROR;
+    }
+       
+       *lpcbBytesReturned = 0;
 
     switch( dwIoControlCode )
     {
@@ -1837,18 +1975,49 @@ WSPIoctl(IN  SOCKET Handle,
                 *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
             }
-            Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
-            return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+            NonBlocking = *((PULONG)lpvInBuffer) ? TRUE : FALSE;
+            Socket->SharedData.NonBlocking = NonBlocking ? 1 : 0;
+            *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, &NonBlocking, NULL, NULL);
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                               return NO_ERROR;
         case FIONREAD:
             if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
             {
                 *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
             }
-            return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
-        default:
+            *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, NULL, (PULONG)lpvOutBuffer, NULL);
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                       {
+                               *lpcbBytesReturned = sizeof(ULONG);
+                               return NO_ERROR;
+                       }
+        case SIO_GET_EXTENSION_FUNCTION_POINTER:
             *lpErrno = WSAEINVAL;
             return SOCKET_ERROR;
+        default:
+                       *lpErrno = Socket->HelperData->WSHIoctl(Socket->HelperContext,
+                                                                                                       Handle,
+                                                                                                       Socket->TdiAddressHandle,
+                                                                                                       Socket->TdiConnectionHandle,
+                                                                                                       dwIoControlCode,
+                                                                                                       lpvInBuffer,
+                                                                                                       cbInBuffer,
+                                                                                                       lpvOutBuffer,
+                                                                                                       cbOutBuffer,
+                                                                                                       lpcbBytesReturned,
+                                                                                                       lpOverlapped,
+                                                                                                       lpCompletionRoutine,
+                                                                                                       (LPBOOL)&NeedsCompletion);
+                       
+                       if (*lpErrno != NO_ERROR)
+                               return SOCKET_ERROR;
+                       else
+                               return NO_ERROR;
     }
 }
 
@@ -1980,7 +2149,16 @@ WSPSetSockOpt(
     }
 
 
-    /* FIXME: We should handle some cases here */
+    /* FIXME: We should handle some more cases here */
+    if (level == SOL_SOCKET)
+    {
+        switch (optname)
+        {
+           case SO_BROADCAST:
+              Socket->SharedData.Broadcast = (*optval != 0) ? 1 : 0;
+              return 0;
+        }
+    }
 
 
     *lpErrno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
@@ -2086,6 +2264,7 @@ WSPCleanup(OUT LPINT lpErrno)
 int 
 GetSocketInformation(PSOCKET_INFORMATION Socket, 
                      ULONG AfdInformationClass, 
+                     PBOOLEAN Boolean OPTIONAL,
                      PULONG Ulong OPTIONAL, 
                      PLARGE_INTEGER LargeInteger OPTIONAL)
 {
@@ -2122,8 +2301,12 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
+    if (Status != STATUS_SUCCESS)
+        return -1;
+
     /* Return Information */
     if (Ulong != NULL)
     {
@@ -2133,6 +2316,10 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
     {
         *LargeInteger = InfoData.Information.LargeInteger;
     }
+    if (Boolean != NULL)
+    {
+        *Boolean = InfoData.Information.Boolean;
+    }
 
     NtClose( SockEvent );
 
@@ -2143,7 +2330,8 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
 
 int 
 SetSocketInformation(PSOCKET_INFORMATION Socket, 
-                     ULONG AfdInformationClass, 
+                     ULONG AfdInformationClass,
+                     PBOOLEAN Boolean OPTIONAL,
                      PULONG Ulong OPTIONAL, 
                      PLARGE_INTEGER LargeInteger OPTIONAL)
 {
@@ -2173,6 +2361,10 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     {
         InfoData.Information.LargeInteger = *LargeInteger;
     }
+    if (Boolean != NULL)
+    {
+        InfoData.Information.Boolean = *Boolean;
+    }
 
     AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
         AfdInformationClass, *Ulong));
@@ -2193,27 +2385,37 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     NtClose( SockEvent );
 
-    return 0;
+    return Status == STATUS_SUCCESS ? 0 : -1;
 
 }
 
 PSOCKET_INFORMATION
 GetSocketStructure(SOCKET Handle)
 {
-    ULONG i;
+    PSOCKET_INFORMATION CurrentSocket;
+
+    EnterCriticalSection(&SocketListLock);
 
-    for (i=0; i<SocketCount; i++) 
+    CurrentSocket = SocketListHead;
+    while (CurrentSocket)
     {
-        if (Sockets[i]->Handle == Handle)
+        if (CurrentSocket->Handle == Handle)
         {
-            return Sockets[i];
+            LeaveCriticalSection(&SocketListLock);
+            return CurrentSocket;
         }
+
+        CurrentSocket = CurrentSocket->NextSocket;
     }
-    return 0;
+
+    LeaveCriticalSection(&SocketListLock);
+
+    return NULL;
 }
 
 int CreateContext(PSOCKET_INFORMATION Socket)
@@ -2258,11 +2460,12 @@ int CreateContext(PSOCKET_INFORMATION Socket)
     if (Status == STATUS_PENDING)
     {
         WaitForSingleObject(SockEvent, INFINITE);
+        Status = IOSB.Status;
     }
 
     NtClose( SockEvent );
 
-    return 0;
+    return Status == STATUS_SUCCESS ? 0 : -1;
 }
 
 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
@@ -2276,6 +2479,7 @@ BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
     /* Check if the Thread Already Exists */
     if (SockAsyncThreadRefCount)
     {
+        ASSERT(SockAsyncCompletionPort);
         return TRUE;
     }
 
@@ -2286,7 +2490,11 @@ BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
                                       IO_COMPLETION_ALL_ACCESS,
                                       NULL,
                                       2); // Allow 2 threads only
-
+        if (!NT_SUCCESS(Status))
+        {
+             AFD_DbgPrint(MID_TRACE,("Failed to create completion port\n"));
+             return FALSE;
+        }
         /* Protect Handle */   
         HandleFlags.ProtectFromClose = TRUE;
         HandleFlags.Inherit = FALSE;
@@ -2730,9 +2938,8 @@ DllMain(HANDLE hInstDll,
         /* Heap to use when allocating */
         GlobalHeap = GetProcessHeap();
 
-        /* Allocate Heap for 1024 Sockets, can be expanded later */
-        Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
-        if (!Sockets) return FALSE;
+        /* Initialize the lock that protects our socket list */
+        InitializeCriticalSection(&SocketListLock);
 
         AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
 
@@ -2745,6 +2952,10 @@ DllMain(HANDLE hInstDll,
         break;
 
     case DLL_PROCESS_DETACH:
+
+        /* Delete the socket list lock */
+        DeleteCriticalSection(&SocketListLock);
+
         break;
     }