[MSAFD][WS2_32] Better WSASocket parameters check
authorPeter Hater <7element@mail.bg>
Sat, 29 Oct 2016 18:22:22 +0000 (18:22 +0000)
committerPeter Hater <7element@mail.bg>
Sat, 29 Oct 2016 18:22:22 +0000 (18:22 +0000)
CORE-12104

svn path=/trunk/; revision=73068

reactos/dll/win32/msafd/misc/dllmain.c
reactos/dll/win32/ws2_32/src/dcatalog.c

index cf99774..b156980 100644 (file)
@@ -95,13 +95,33 @@ WSPSocket(int AddressFamily,
         Protocol = SharedData->Protocol;
     }
 
-    if (AddressFamily == AF_UNSPEC && SocketType == 0 && Protocol == 0)
+    if (lpProtocolInfo)
+    {
+        if (lpProtocolInfo->iAddressFamily && AddressFamily <= 0)
+            AddressFamily = lpProtocolInfo->iAddressFamily;
+        if (lpProtocolInfo->iSocketType && SocketType <= 0)
+            SocketType = lpProtocolInfo->iSocketType;
+        if (lpProtocolInfo->iProtocol && Protocol <= 0)
+            Protocol = lpProtocolInfo->iProtocol;
+    }
+
+    /* FIXME: AF_NETDES should be AF_MAX */
+    if (AddressFamily < AF_UNSPEC || AddressFamily > AF_NETDES)
         return WSAEINVAL;
 
-    /* Set the defaults */
-    if (AddressFamily == AF_UNSPEC)
-        AddressFamily = AF_INET;
+    if (SocketType < 0 && SocketType > SOCK_SEQPACKET)
+        return WSAEINVAL;
+
+    if (Protocol < 0 && Protocol > IPPROTO_MAX)
+        return WSAEINVAL;
+
+    /* when no protocol and socket type are specified the first entry
+    * from WSAEnumProtocols that has the flag PFL_MATCHES_PROTOCOL_ZERO
+    * is returned */
+    if (SocketType == 0 && Protocol == 0 && lpProtocolInfo && (lpProtocolInfo->dwProviderFlags & PFL_MATCHES_PROTOCOL_ZERO) == 0)
+        return WSAEINVAL;
 
+    /* Set the defaults */
     if (SocketType == 0)
     {
         switch (Protocol)
@@ -117,8 +137,7 @@ WSPSocket(int AddressFamily,
             break;
         default:
             TRACE("Unknown Protocol (%d). We will try SOCK_STREAM.\n", Protocol);
-            SocketType = SOCK_STREAM;
-            break;
+            return WSAEINVAL;
         }
     }
 
@@ -137,11 +156,13 @@ WSPSocket(int AddressFamily,
             break;
         default:
             TRACE("Unknown SocketType (%d). We will try IPPROTO_TCP.\n", SocketType);
-            Protocol = IPPROTO_TCP;
-            break;
+            return WSAEINVAL;
         }
     }
 
+    if (AddressFamily == AF_UNSPEC)
+        return WSAEINVAL;
+
     /* Get Helper Data and Transport */
     Status = SockGetTdiName (&AddressFamily,
                              &SocketType,
index a2a9e0b..951f306 100644 (file)
@@ -490,6 +490,20 @@ WsTcGetEntryFromTriplet(IN PTCATALOG Catalog,
     /* Assume failure */
     *CatalogEntry = NULL;
 
+    /* Params can't be all wildcards */
+    if (af == AF_UNSPEC && type == 0 && protocol == 0)
+        return WSAEINVAL;
+
+    /* FIXME: AF_NETDES should be AF_MAX */
+    if (af < AF_UNSPEC || af > AF_NETDES)
+        return WSAEINVAL;
+
+    if (type < 0 && type > SOCK_SEQPACKET)
+        return WSAEINVAL;
+
+    if (protocol < 0 && protocol > IPPROTO_MAX)
+        return WSAEINVAL;
+
     /* Lock the catalog */
     WsTcLock();
 
@@ -527,6 +541,13 @@ WsTcGetEntryFromTriplet(IN PTCATALOG Catalog,
                       Entry->ProtocolInfo.iProtocolMaxOffset) >= protocol)) ||
                     (protocol == 0))
                 {
+                    /* Check that if type and protocol are 0 provider entry has PFL_MATCHES_PROTOCOL_ZERO flag set */
+                    if (type == 0 && protocol == 0 && (Entry->ProtocolInfo.dwProviderFlags & PFL_MATCHES_PROTOCOL_ZERO) == 0)
+                    {
+                        ErrorCode = WSAEPROTONOSUPPORT;
+                        continue;
+                    }
+
                     /* Check if it doesn't already have a provider */
                     if (!Entry->Provider)
                     {
@@ -550,12 +571,14 @@ WsTcGetEntryFromTriplet(IN PTCATALOG Catalog,
             } 
             else 
             {
-                ErrorCode = WSAESOCKTNOSUPPORT;
+                if (ErrorCode != WSAEPROTONOSUPPORT)
+                    ErrorCode = WSAESOCKTNOSUPPORT;
             }
         } 
         else 
         {
-            ErrorCode = WSAEAFNOSUPPORT;
+            if (ErrorCode != WSAEPROTONOSUPPORT && ErrorCode != WSAESOCKTNOSUPPORT)
+                ErrorCode = WSAEAFNOSUPPORT;
         }
     }