- Update to r53061
[reactos.git] / dll / win32 / msafd / misc / dllmain.c
index 7ce6976..87e5dcd 100644 (file)
@@ -264,7 +264,7 @@ WSPSocket(int AddressFamily,
     /* Save Group Info */
     if (g != 0)
     {
-        GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
+        GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, NULL, NULL, &GroupData);
         Socket->SharedData.GroupID = GroupData.u.LowPart;
         Socket->SharedData.GroupType = GroupData.u.HighPart;
     }
@@ -272,11 +272,13 @@ WSPSocket(int AddressFamily,
     /* Get Window Sizes and Save them */
     GetSocketInformation (Socket,
                           AFD_INFO_SEND_WINDOW_SIZE,
+                          NULL,
                           &Socket->SharedData.SizeOfSendBuffer,
                           NULL);
 
     GetSocketInformation (Socket,
                           AFD_INFO_RECEIVE_WINDOW_SIZE,
+                          NULL,
                           &Socket->SharedData.SizeOfRecvBuffer,
                           NULL);
 
@@ -467,6 +469,7 @@ WSPCloseSocket(IN SOCKET Handle,
             /* Find out how many Sends are in Progress */
             if (GetSocketInformation(Socket,
                                      AFD_INFO_SENDS_IN_PROGRESS,
+                                     NULL,
                                      &SendsInProgress,
                                      NULL))
             {
@@ -514,24 +517,28 @@ WSPCloseSocket(IN SOCKET Handle,
         {
             DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
             DisconnectInfo.DisconnectType = LingerWait < 0 ? AFD_DISCONNECT_SEND : AFD_DISCONNECT_ABORT;
-
-            /* Send IOCTL */
-            Status = NtDeviceIoControlFile((HANDLE)Handle,
-                                           SockEvent,
-                                           NULL,
-                                           NULL,
-                                           &IoStatusBlock,
-                                           IOCTL_AFD_DISCONNECT,
-                                           &DisconnectInfo,
-                                           sizeof(DisconnectInfo),
-                                           NULL,
-                                           0);
-
-            /* Wait for return */
-            if (Status == STATUS_PENDING)
+            
+            if (((DisconnectInfo.DisconnectType & AFD_DISCONNECT_SEND) && (!Socket->SharedData.SendShutdown)) ||
+                ((DisconnectInfo.DisconnectType & AFD_DISCONNECT_ABORT) && (!Socket->SharedData.ReceiveShutdown)))
             {
-                WaitForSingleObject(SockEvent, INFINITE);
-                Status = IoStatusBlock.Status;
+                /* Send IOCTL */
+                Status = NtDeviceIoControlFile((HANDLE)Handle,
+                                               SockEvent,
+                                               NULL,
+                                               NULL,
+                                               &IoStatusBlock,
+                                               IOCTL_AFD_DISCONNECT,
+                                               &DisconnectInfo,
+                                               sizeof(DisconnectInfo),
+                                               NULL,
+                                               0);
+                
+                /* Wait for return */
+                if (Status == STATUS_PENDING)
+                {
+                    WaitForSingleObject(SockEvent, INFINITE);
+                    Status = IoStatusBlock.Status;
+                }
             }
         }
     }
@@ -799,7 +806,8 @@ WSPSelect(IN int nfds,
     IO_STATUS_BLOCK     IOSB;
     PAFD_POLL_INFO      PollInfo;
     NTSTATUS            Status;
-    LONG                HandleCount, OutCount = 0;
+    ULONG               HandleCount;
+    LONG                OutCount = 0;
     ULONG               PollBufferSize;
     PVOID               PollBuffer;
     ULONG               i, j = 0, x;
@@ -816,7 +824,7 @@ WSPSelect(IN int nfds,
 
     if ( HandleCount == 0 )
     {
-        AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
+        AFD_DbgPrint(MAX_TRACE,("HandleCount: %u. Return SOCKET_ERROR\n",
                      HandleCount));
         if (lpErrno) *lpErrno = WSAEINVAL;
         return SOCKET_ERROR;
@@ -824,7 +832,7 @@ WSPSelect(IN int nfds,
 
     PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
 
-    AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n", 
+    AFD_DbgPrint(MID_TRACE,("HandleCount: %u BufferSize: %u\n", 
                  HandleCount, PollBufferSize));
 
     /* Convert Timeout to NT Format */
@@ -1696,7 +1704,7 @@ WSPShutdown(SOCKET Handle,
             break;
     }
 
-    DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
+    DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1000000);
 
     /* Send IOCTL */
     Status = NtDeviceIoControlFile((HANDLE)Handle,
@@ -1802,13 +1810,13 @@ WSPGetSockName(IN SOCKET Handle,
 
     if (NT_SUCCESS(Status))
     {
-        if (*NameLength >= SocketAddress->Address[0].AddressLength)
+        if (*NameLength >= Socket->SharedData.SizeOfLocalAddress)
         {
             Name->sa_family = SocketAddress->Address[0].AddressType;
             RtlCopyMemory (Name->sa_data,
                            SocketAddress->Address[0].Address, 
                            SocketAddress->Address[0].AddressLength);
-            *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+            *NameLength = Socket->SharedData.SizeOfLocalAddress;
             AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
                           *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
                           ((struct sockaddr_in *)Name)->sin_port));
@@ -1874,7 +1882,7 @@ WSPGetPeerName(IN SOCKET s,
     }
 
     /* Allocate a buffer for the address */
-    TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
+    TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfRemoteAddress;
     SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
 
     if ( SocketAddress == NULL )
@@ -1907,13 +1915,13 @@ WSPGetPeerName(IN SOCKET s,
 
     if (NT_SUCCESS(Status))
     {
-        if (*NameLength >= SocketAddress->Address[0].AddressLength)
+        if (*NameLength >= Socket->SharedData.SizeOfRemoteAddress)
         {
             Name->sa_family = SocketAddress->Address[0].AddressType;
             RtlCopyMemory (Name->sa_data,
                            SocketAddress->Address[0].Address, 
                            SocketAddress->Address[0].AddressLength);
-            *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+            *NameLength = Socket->SharedData.SizeOfRemoteAddress;
             AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
                           *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
                           ((struct sockaddr_in *)Name)->sin_port));
@@ -1947,6 +1955,7 @@ WSPIoctl(IN  SOCKET Handle,
 {
     PSOCKET_INFORMATION Socket = NULL;
        BOOLEAN NeedsCompletion;
+    BOOLEAN NonBlocking;
 
     /* Get the Socket Structure associate to this Socket*/
     Socket = GetSocketStructure(Handle);
@@ -1966,8 +1975,9 @@ WSPIoctl(IN  SOCKET Handle,
                 *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
             }
-            Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
-            *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+            NonBlocking = *((PULONG)lpvInBuffer) ? TRUE : FALSE;
+            Socket->SharedData.NonBlocking = NonBlocking ? 1 : 0;
+            *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, &NonBlocking, NULL, NULL);
                        if (*lpErrno != NO_ERROR)
                                return SOCKET_ERROR;
                        else
@@ -1978,7 +1988,7 @@ WSPIoctl(IN  SOCKET Handle,
                 *lpErrno = WSAEFAULT;
                 return SOCKET_ERROR;
             }
-            *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
+            *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, NULL, (PULONG)lpvOutBuffer, NULL);
                        if (*lpErrno != NO_ERROR)
                                return SOCKET_ERROR;
                        else
@@ -2254,6 +2264,7 @@ WSPCleanup(OUT LPINT lpErrno)
 int 
 GetSocketInformation(PSOCKET_INFORMATION Socket, 
                      ULONG AfdInformationClass, 
+                     PBOOLEAN Boolean OPTIONAL,
                      PULONG Ulong OPTIONAL, 
                      PLARGE_INTEGER LargeInteger OPTIONAL)
 {
@@ -2305,6 +2316,10 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
     {
         *LargeInteger = InfoData.Information.LargeInteger;
     }
+    if (Boolean != NULL)
+    {
+        *Boolean = InfoData.Information.Boolean;
+    }
 
     NtClose( SockEvent );
 
@@ -2315,7 +2330,8 @@ GetSocketInformation(PSOCKET_INFORMATION Socket,
 
 int 
 SetSocketInformation(PSOCKET_INFORMATION Socket, 
-                     ULONG AfdInformationClass, 
+                     ULONG AfdInformationClass,
+                     PBOOLEAN Boolean OPTIONAL,
                      PULONG Ulong OPTIONAL, 
                      PLARGE_INTEGER LargeInteger OPTIONAL)
 {
@@ -2345,6 +2361,10 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
     {
         InfoData.Information.LargeInteger = *LargeInteger;
     }
+    if (Boolean != NULL)
+    {
+        InfoData.Information.Boolean = *Boolean;
+    }
 
     AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
         AfdInformationClass, *Ulong));