Git conversion: Make reactos the root directory, move rosapps, rostests, wallpapers...
[reactos.git] / dll / win32 / msafd / misc / helpers.c
diff --git a/dll/win32/msafd/misc/helpers.c b/dll/win32/msafd/misc/helpers.c
new file mode 100644 (file)
index 0000000..aab52ba
--- /dev/null
@@ -0,0 +1,524 @@
+/*
+ * COPYRIGHT:   See COPYING in the top level directory
+ * PROJECT:     ReactOS Ancillary Function Driver DLL
+ * FILE:        dll/win32/msafd/misc/helpers.c
+ * PURPOSE:     Helper DLL management
+ * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
+ *                             Alex Ionescu (alex@relsoft.net)
+ * REVISIONS:
+ *   CSH 01/09-2000 Created
+ *      Alex 16/07/2004 - Complete Rewrite
+ */
+
+#include <msafd.h>
+
+#include <winreg.h>
+
+#include <wine/debug.h>
+WINE_DEFAULT_DEBUG_CHANNEL(msafd);
+
+CRITICAL_SECTION HelperDLLDatabaseLock;
+LIST_ENTRY HelperDLLDatabaseListHead;
+
+
+INT
+SockGetTdiName(
+    PINT AddressFamily,
+    PINT SocketType,
+    PINT Protocol,
+    GROUP Group,
+    DWORD Flags,
+    PUNICODE_STRING TransportName,
+    PVOID *HelperDllContext,
+    PHELPER_DATA *HelperDllData,
+    PDWORD Events)
+{
+    PHELPER_DATA        HelperData;
+    PWSTR               Transports;
+    PWSTR               Transport;
+    PWINSOCK_MAPPING   Mapping;
+    PLIST_ENTRY                Helpers;
+    INT                 Status;
+
+    TRACE("AddressFamily %p, SocketType %p, Protocol %p, Group %u, Flags %lx, TransportName %p, HelperDllContext %p, HelperDllData %p, Events %p\n",
+        AddressFamily, SocketType, Protocol, Group, Flags, TransportName, HelperDllContext, HelperDllData, Events);
+
+    /* Check in our Current Loaded Helpers */
+    for (Helpers = SockHelpersListHead.Flink;
+         Helpers != &SockHelpersListHead;
+         Helpers = Helpers->Flink ) {
+
+        HelperData = CONTAINING_RECORD(Helpers, HELPER_DATA, Helpers);
+
+        /* See if this Mapping works for us */
+        if (SockIsTripleInMapping (HelperData->Mapping,
+                                   *AddressFamily,
+                                   *SocketType,
+                                   *Protocol)) {
+
+            /* Call the Helper Dll function get the Transport Name */
+            if (HelperData->WSHOpenSocket2 == NULL ) {
+
+                /* DLL Doesn't support WSHOpenSocket2, call the old one */
+                HelperData->WSHOpenSocket(AddressFamily,
+                                          SocketType,
+                                          Protocol,
+                                          TransportName,
+                                          HelperDllContext,
+                                          Events
+                                          );
+            } else {
+                HelperData->WSHOpenSocket2(AddressFamily,
+                                           SocketType,
+                                           Protocol,
+                                           Group,
+                                           Flags,
+                                           TransportName,
+                                           HelperDllContext,
+                                           Events
+                                           );
+            }
+
+            /* Return the Helper Pointers */
+            *HelperDllData = HelperData;
+            return NO_ERROR;
+        }
+    }
+
+    /* Get the Transports available */
+    Status = SockLoadTransportList(&Transports);
+
+    /* Check for error */
+    if (Status) {
+        WARN("Can't get transport list\n");
+        return Status;
+    }
+
+    /* Loop through each transport until we find one that can satisfy us */
+    for (Transport = Transports;
+         *Transports != 0;
+         Transport += wcslen(Transport) + 1) {
+        TRACE("Transport: %S\n", Transports);
+
+        /* See what mapping this Transport supports */
+        Status = SockLoadTransportMapping(Transport, &Mapping);
+
+        /* Check for error */
+        if (Status) {
+            ERR("Can't get mapping for %S\n", Transports);
+            HeapFree(GlobalHeap, 0, Transports);
+            return Status;
+        }
+
+        /* See if this Mapping works for us */
+        if (SockIsTripleInMapping(Mapping, *AddressFamily, *SocketType, *Protocol)) {
+
+            /* It does, so load the DLL associated with it */
+            Status = SockLoadHelperDll(Transport, Mapping, &HelperData);
+
+            /* Check for error */
+            if (Status) {
+                ERR("Can't load helper DLL for Transport %S.\n", Transport);
+                HeapFree(GlobalHeap, 0, Transports);
+                HeapFree(GlobalHeap, 0, Mapping);
+                return Status;
+            }
+
+            /* Call the Helper Dll function get the Transport Name */
+            if (HelperData->WSHOpenSocket2 == NULL) {
+                /* DLL Doesn't support WSHOpenSocket2, call the old one */
+                HelperData->WSHOpenSocket(AddressFamily,
+                                          SocketType,
+                                          Protocol,
+                                          TransportName,
+                                          HelperDllContext,
+                                          Events
+                                          );
+            } else {
+                HelperData->WSHOpenSocket2(AddressFamily,
+                                           SocketType,
+                                           Protocol,
+                                           Group,
+                                           Flags,
+                                           TransportName,
+                                           HelperDllContext,
+                                           Events
+                                           );
+            }
+
+            /* Return the Helper Pointers */
+            *HelperDllData = HelperData;
+            /* We actually cache these ... the can't be freed yet */
+            /*HeapFree(GlobalHeap, 0, Transports);*/
+            /*HeapFree(GlobalHeap, 0, Mapping);*/
+            return NO_ERROR;
+        }
+
+        HeapFree(GlobalHeap, 0, Mapping);
+    }
+    HeapFree(GlobalHeap, 0, Transports);
+    return WSAEINVAL;
+}
+
+INT
+SockLoadTransportMapping(
+    PWSTR TransportName,
+    PWINSOCK_MAPPING *Mapping)
+{
+    PWSTR               TransportKey;
+    HKEY                KeyHandle;
+    ULONG               MappingSize;
+    LONG                Status;
+
+    TRACE("TransportName %ws\n", TransportName);
+
+    /* Allocate a Buffer */
+    TransportKey = HeapAlloc(GlobalHeap, 0, (54 + wcslen(TransportName)) * sizeof(WCHAR));
+
+    /* Check for error */
+    if (TransportKey == NULL) {
+        ERR("Buffer allocation failed\n");
+        return WSAEINVAL;
+    }
+
+    /* Generate the right key name */
+    wcscpy(TransportKey, L"System\\CurrentControlSet\\Services\\");
+    wcscat(TransportKey, TransportName);
+    wcscat(TransportKey, L"\\Parameters\\Winsock");
+
+    /* Open the Key */
+    Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, TransportKey, 0, KEY_READ, &KeyHandle);
+
+    /* We don't need the Transport Key anymore */
+    HeapFree(GlobalHeap, 0, TransportKey);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading transport mapping registry\n");
+        return WSAEINVAL;
+    }
+
+    /* Find out how much space we need for the Mapping */
+    Status = RegQueryValueExW(KeyHandle, L"Mapping", NULL, NULL, NULL, &MappingSize);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading transport mapping registry\n");
+        return WSAEINVAL;
+    }
+
+    /* Allocate Memory for the Mapping */
+    *Mapping = HeapAlloc(GlobalHeap, 0, MappingSize);
+
+    /* Check for error */
+    if (*Mapping == NULL) {
+        ERR("Buffer allocation failed\n");
+        return WSAEINVAL;
+    }
+
+    /* Read the Mapping */
+    Status = RegQueryValueExW(KeyHandle, L"Mapping", NULL, NULL, (LPBYTE)*Mapping, &MappingSize);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading transport mapping registry\n");
+        HeapFree(GlobalHeap, 0, *Mapping);
+        return WSAEINVAL;
+    }
+
+    /* Close key and return */
+    RegCloseKey(KeyHandle);
+    return 0;
+}
+
+INT
+SockLoadTransportList(
+    PWSTR *TransportList)
+{
+    ULONG      TransportListSize;
+    HKEY       KeyHandle;
+    LONG       Status;
+
+    TRACE("Called\n");
+
+    /* Open the Transports Key */
+    Status = RegOpenKeyExW (HKEY_LOCAL_MACHINE,
+                            L"SYSTEM\\CurrentControlSet\\Services\\Winsock\\Parameters",
+                            0,
+                            KEY_READ,
+                            &KeyHandle);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading transport list registry\n");
+        return WSAEINVAL;
+    }
+
+    /* Get the Transport List Size */
+    Status = RegQueryValueExW(KeyHandle,
+                              L"Transports",
+                              NULL,
+                              NULL,
+                              NULL,
+                              &TransportListSize);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading transport list registry\n");
+        return WSAEINVAL;
+    }
+
+    /* Allocate Memory for the Transport List */
+    *TransportList = HeapAlloc(GlobalHeap, 0, TransportListSize);
+
+    /* Check for error */
+    if (*TransportList == NULL) {
+        ERR("Buffer allocation failed\n");
+        return WSAEINVAL;
+    }
+
+    /* Get the Transports */
+    Status = RegQueryValueExW (KeyHandle,
+                               L"Transports",
+                               NULL,
+                               NULL,
+                               (LPBYTE)*TransportList,
+                               &TransportListSize);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading transport list registry\n");
+        HeapFree(GlobalHeap, 0, *TransportList);
+        return WSAEINVAL;
+    }
+
+    /* Close key and return */
+    RegCloseKey(KeyHandle);
+    return 0;
+}
+
+INT
+SockLoadHelperDll(
+    PWSTR TransportName,
+    PWINSOCK_MAPPING Mapping,
+    PHELPER_DATA *HelperDllData)
+{
+    PHELPER_DATA        HelperData;
+    PWSTR               HelperDllName;
+    PWSTR               FullHelperDllName;
+    PWSTR               HelperKey;
+    HKEY                KeyHandle;
+    ULONG               DataSize;
+    LONG                Status;
+
+    /* Allocate space for the Helper Structure and TransportName */
+    HelperData = HeapAlloc(GlobalHeap, 0, sizeof(*HelperData) + (wcslen(TransportName) + 1) * sizeof(WCHAR));
+
+    /* Check for error */
+    if (HelperData == NULL) {
+        ERR("Buffer allocation failed\n");
+        return WSAEINVAL;
+    }
+
+    /* Allocate Space for the Helper DLL Key */
+    HelperKey = HeapAlloc(GlobalHeap, 0, (54 + wcslen(TransportName)) * sizeof(WCHAR));
+
+    /* Check for error */
+    if (HelperKey == NULL) {
+        ERR("Buffer allocation failed\n");
+        HeapFree(GlobalHeap, 0, HelperData);
+        return WSAEINVAL;
+    }
+
+    /* Generate the right key name */
+    wcscpy(HelperKey, L"System\\CurrentControlSet\\Services\\");
+    wcscat(HelperKey, TransportName);
+    wcscat(HelperKey, L"\\Parameters\\Winsock");
+
+    /* Open the Key */
+    Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, HelperKey, 0, KEY_READ, &KeyHandle);
+
+    HeapFree(GlobalHeap, 0, HelperKey);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading helper DLL parameters\n");
+        HeapFree(GlobalHeap, 0, HelperData);
+        return WSAEINVAL;
+    }
+
+    /* Read Size of SockAddr Structures */
+    DataSize = sizeof(HelperData->MinWSAddressLength);
+    HelperData->MinWSAddressLength = 16;
+    RegQueryValueExW (KeyHandle,
+                      L"MinSockaddrLength",
+                      NULL,
+                      NULL,
+                      (LPBYTE)&HelperData->MinWSAddressLength,
+                      &DataSize);
+    DataSize = sizeof(HelperData->MinWSAddressLength);
+    HelperData->MaxWSAddressLength = 16;
+    RegQueryValueExW (KeyHandle,
+                      L"MaxSockaddrLength",
+                      NULL,
+                      NULL,
+                      (LPBYTE)&HelperData->MaxWSAddressLength,
+                      &DataSize);
+
+    /* Size of TDI Structures */
+    HelperData->MinTDIAddressLength = HelperData->MinWSAddressLength + 6;
+    HelperData->MaxTDIAddressLength = HelperData->MaxWSAddressLength + 6;
+
+    /* Read Delayed Acceptance Setting */
+    DataSize = sizeof(DWORD);
+    HelperData->UseDelayedAcceptance = FALSE;
+    RegQueryValueExW (KeyHandle,
+                      L"UseDelayedAcceptance",
+                      NULL,
+                      NULL,
+                      (LPBYTE)&HelperData->UseDelayedAcceptance,
+                      &DataSize);
+
+    /* Allocate Space for the Helper DLL Names */
+    HelperDllName = HeapAlloc(GlobalHeap, 0, 512);
+
+    /* Check for error */
+    if (HelperDllName == NULL) {
+        ERR("Buffer allocation failed\n");
+        HeapFree(GlobalHeap, 0, HelperData);
+        return WSAEINVAL;
+    }
+
+    FullHelperDllName = HeapAlloc(GlobalHeap, 0, 512);
+
+    /* Check for error */
+    if (FullHelperDllName == NULL) {
+        ERR("Buffer allocation failed\n");
+        HeapFree(GlobalHeap, 0, HelperDllName);
+        HeapFree(GlobalHeap, 0, HelperData);
+        return WSAEINVAL;
+    }
+
+    /* Get the name of the Helper DLL*/
+    DataSize = 512;
+    Status = RegQueryValueExW (KeyHandle,
+                               L"HelperDllName",
+                               NULL,
+                               NULL,
+                               (LPBYTE)HelperDllName,
+                               &DataSize);
+
+    /* Check for error */
+    if (Status) {
+        ERR("Error reading helper DLL parameters\n");
+        HeapFree(GlobalHeap, 0, FullHelperDllName);
+        HeapFree(GlobalHeap, 0, HelperDllName);
+        HeapFree(GlobalHeap, 0, HelperData);
+        return WSAEINVAL;
+    }
+
+    /* Get the Full name, expanding Environment Strings */
+    ExpandEnvironmentStringsW (HelperDllName,
+                               FullHelperDllName,
+                               256);
+
+    /* Load the DLL */
+    HelperData->hInstance = LoadLibraryW(FullHelperDllName);
+
+    HeapFree(GlobalHeap, 0, HelperDllName);
+    HeapFree(GlobalHeap, 0, FullHelperDllName);
+
+    if (HelperData->hInstance == NULL) {
+        ERR("Error loading helper DLL\n");
+        HeapFree(GlobalHeap, 0, HelperData);
+        return WSAEINVAL;
+    }
+
+    /* Close Key */
+    RegCloseKey(KeyHandle);
+
+    /* Get the Pointers to the Helper Routines */
+    HelperData->WSHOpenSocket =        (PWSH_OPEN_SOCKET)
+                                                                       GetProcAddress(HelperData->hInstance,
+                                                                       "WSHOpenSocket");
+    HelperData->WSHOpenSocket2 = (PWSH_OPEN_SOCKET2)
+                                                                       GetProcAddress(HelperData->hInstance,
+                                                                       "WSHOpenSocket2");
+    HelperData->WSHJoinLeaf = (PWSH_JOIN_LEAF)
+                                                               GetProcAddress(HelperData->hInstance,
+                                                               "WSHJoinLeaf");
+    HelperData->WSHNotify = (PWSH_NOTIFY)
+                                                               GetProcAddress(HelperData->hInstance, "WSHNotify");
+    HelperData->WSHGetSocketInformation = (PWSH_GET_SOCKET_INFORMATION)
+                                                                                       GetProcAddress(HelperData->hInstance,
+                                                                                       "WSHGetSocketInformation");
+    HelperData->WSHSetSocketInformation = (PWSH_SET_SOCKET_INFORMATION)
+                                                                                       GetProcAddress(HelperData->hInstance,
+                                                                                       "WSHSetSocketInformation");
+    HelperData->WSHGetSockaddrType = (PWSH_GET_SOCKADDR_TYPE)
+                                                                               GetProcAddress(HelperData->hInstance,
+                                                                               "WSHGetSockaddrType");
+    HelperData->WSHGetWildcardSockaddr = (PWSH_GET_WILDCARD_SOCKADDR)
+                                                                                       GetProcAddress(HelperData->hInstance,
+                                                                                       "WSHGetWildcardSockaddr");
+    HelperData->WSHGetBroadcastSockaddr = (PWSH_GET_BROADCAST_SOCKADDR)
+                                                                                       GetProcAddress(HelperData->hInstance,
+                                                                                       "WSHGetBroadcastSockaddr");
+    HelperData->WSHAddressToString = (PWSH_ADDRESS_TO_STRING)
+                                                                               GetProcAddress(HelperData->hInstance,
+                                                                               "WSHAddressToString");
+    HelperData->WSHStringToAddress = (PWSH_STRING_TO_ADDRESS)
+                                                                               GetProcAddress(HelperData->hInstance,
+                                                                               "WSHStringToAddress");
+    HelperData->WSHIoctl = (PWSH_IOCTL)
+                                                       GetProcAddress(HelperData->hInstance,
+                                                       "WSHIoctl");
+
+    /* Save the Mapping Structure and transport name */
+    HelperData->Mapping = Mapping;
+    wcscpy(HelperData->TransportName, TransportName);
+
+    /* Increment Reference Count */
+    HelperData->RefCount = 1;
+
+    /* Add it to our list */
+    InsertHeadList(&SockHelpersListHead, &HelperData->Helpers);
+
+    /* Return Pointers */
+    *HelperDllData = HelperData;
+    return 0;
+}
+
+BOOL
+SockIsTripleInMapping(
+    PWINSOCK_MAPPING Mapping,
+    INT AddressFamily,
+    INT SocketType,
+    INT Protocol)
+{
+    /* The Windows version returns more detailed information on which of the 3 parameters failed...we should do this later */
+    ULONG    Row;
+
+    TRACE("Called, Mapping rows = %d\n", Mapping->Rows);
+
+    /* Loop through Mapping to Find a matching one */
+    for (Row = 0; Row < Mapping->Rows; Row++) {
+        TRACE("Examining: row %d: AF %d type %d proto %d\n",
+                               Row,
+                               (INT)Mapping->Mapping[Row].AddressFamily,
+                               (INT)Mapping->Mapping[Row].SocketType,
+                               (INT)Mapping->Mapping[Row].Protocol);
+
+        /* Check of all three values Match */
+        if (((INT)Mapping->Mapping[Row].AddressFamily == AddressFamily) &&
+            ((INT)Mapping->Mapping[Row].SocketType == SocketType) &&
+            ((INT)Mapping->Mapping[Row].Protocol == Protocol)) {
+            TRACE("Found\n");
+            return TRUE;
+        }
+    }
+    WARN("Not found\n");
+    return FALSE;
+}
+
+/* EOF */