[MSAFD]
authorCameron Gutman <aicommander@gmail.com>
Sat, 11 Jun 2011 17:49:30 +0000 (17:49 +0000)
committerCameron Gutman <aicommander@gmail.com>
Sat, 11 Jun 2011 17:49:30 +0000 (17:49 +0000)
- Fix linger and graceful disconnect
- Fix a crash in WSPGetSockName and WSPGetPeerName exposed by ws2_32 winetest sock

svn path=/trunk/; revision=52193

reactos/dll/win32/msafd/misc/dllmain.c

index b2680e0..9e080e8 100644 (file)
@@ -401,6 +401,7 @@ WSPCloseSocket(IN SOCKET Handle,
     HANDLE SockEvent;
     AFD_DISCONNECT_INFO DisconnectInfo;
     SOCKET_STATE OldState;
+    LONG LingerWait = -1;
 
     /* Create the Wait Event */
     Status = NtCreateEvent(&SockEvent,
@@ -453,7 +454,6 @@ WSPCloseSocket(IN SOCKET Handle,
     /* FIXME: Should we do this on Datagram Sockets too? */
     if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
     {
-        ULONG LingerWait;
         ULONG SendsInProgress;
         ULONG SleepWait;
 
@@ -477,7 +477,11 @@ WSPCloseSocket(IN SOCKET Handle,
 
             /* Bail out if no more sends are pending */
             if (!SendsInProgress)
+            {
+                LingerWait = -1;
                 break;
+            }
+
             /* 
              * We have to execute a sleep, so it's kind of like
              * a block. If the socket is Nonblock, we cannot
@@ -502,15 +506,14 @@ WSPCloseSocket(IN SOCKET Handle,
             Sleep(SleepWait);
             LingerWait -= SleepWait;
         }
+    }
 
-        /*
-        * We have reached the timeout or sends are over.
-        * Disconnect if the timeout has been reached. 
-        */
+    if (OldState == SocketConnected)
+    {
         if (LingerWait <= 0)
         {
             DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
-            DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
+            DisconnectInfo.DisconnectType = LingerWait < 0 ? AFD_DISCONNECT_SEND : AFD_DISCONNECT_ABORT;
 
             /* Send IOCTL */
             Status = NtDeviceIoControlFile((HANDLE)Handle,
@@ -1564,6 +1567,7 @@ WSPConnect(SOCKET Handle,
     if (Status != STATUS_SUCCESS)
         goto notify;
 
+    Socket->SharedData.State = SocketConnected;
     Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
 
     /* Get any pending connect data */
@@ -1742,6 +1746,13 @@ WSPGetSockName(IN SOCKET Handle,
        return SOCKET_ERROR;
     }
 
+    if (!Name || !NameLength)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
+    }
+
     /* Allocate a buffer for the address */
     TdiAddressSize = 
                sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
@@ -1836,6 +1847,20 @@ WSPGetPeerName(IN SOCKET s,
        return SOCKET_ERROR;
     }
 
+    if (Socket->SharedData.State != SocketConnected)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAENOTCONN;
+        return SOCKET_ERROR;
+    }
+
+    if (!Name || !NameLength)
+    {
+        NtClose(SockEvent);
+        *lpErrno = WSAEFAULT;
+        return SOCKET_ERROR;
+    }
+
     /* Allocate a buffer for the address */
     TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
     SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);