[MSWSOCK]
[reactos.git] / reactos / dll / win32 / mswsock / nsplookup.c
index bf491ef..de3efe5 100644 (file)
 #include <svcguid.h>
 #include <iptypes.h>
 #include <strsafe.h>
+#include <winreg.h>
 
 #include "mswhelper.h"
 
+#define NDEBUG
+#include <debug.h>
+
 #define NSP_CALLID_DNS 0x0001
 #define NSP_CALLID_HOSTNAME 0x0002
 #define NSP_CALLID_HOSTBYNAME 0x0003
@@ -459,310 +463,71 @@ NSP_GetHostNameHeapAllocW(_Out_ WCHAR** hostname)
     return ERROR_SUCCESS;
 }
 
-/* This function is far from perfect but it works enough */
-IP4_ADDRESS
-FindEntryInHosts(IN CONST WCHAR FAR* wname)
-{
-    BOOL Found = FALSE;
-    HANDLE HostsFile;
-    CHAR HostsDBData[BUFSIZ] = {0};
-    PCHAR SystemDirectory = HostsDBData;
-    PCHAR HostsLocation = "\\drivers\\etc\\hosts";
-    PCHAR AddressStr, DnsName = NULL, AddrTerm, NameSt, NextLine, ThisLine, Comment;
-    UINT SystemDirSize = sizeof(HostsDBData) - 1, ValidData = 0;
-    DWORD ReadSize;
-    DWORD Address;
-    CHAR name[MAX_HOSTNAME_LEN + 1];
-
-    wcstombs(name, wname, MAX_HOSTNAME_LEN);
-
-    /* We assume that the parameters are valid */
-    if (!GetSystemDirectoryA(SystemDirectory, SystemDirSize))
-    {
-        WSASetLastError(WSANO_RECOVERY);
-        //WS_DbgPrint(MIN_TRACE, ("Could not get windows system directory.\n"));
-        return 0; /* Can't get system directory */
-    }
-
-    strncat(SystemDirectory, HostsLocation, SystemDirSize);
-
-    HostsFile = CreateFileA(SystemDirectory,
-                            GENERIC_READ,
-                            FILE_SHARE_READ,
-                            NULL,
-                            OPEN_EXISTING,
-                            FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN,
-                            NULL);
-    if (HostsFile == INVALID_HANDLE_VALUE)
-    {
-        WSASetLastError(WSANO_RECOVERY);
-        return 0;
-    }
-
-    while (!Found && ReadFile(HostsFile,
-                              HostsDBData + ValidData,
-                              sizeof(HostsDBData) - ValidData,
-                              &ReadSize,
-                              NULL))
-    {
-        ValidData += ReadSize;
-        ReadSize = 0;
-        NextLine = ThisLine = HostsDBData;
-
-        /* Find the beginning of the next line */
-        while ((NextLine < HostsDBData + ValidData) &&
-               (*NextLine != '\r') &&
-               (*NextLine != '\n'))
-        {
-            NextLine++;
-        }
-
-        /* Zero and skip, so we can treat what we have as a string */
-        if (NextLine > HostsDBData + ValidData)
-            break;
-
-        *NextLine = 0;
-        NextLine++;
-
-        Comment = strchr(ThisLine, '#');
-        if (Comment)
-            *Comment = 0; /* Terminate at comment start */
-
-        AddressStr = ThisLine;
-        /* Find the first space separating the IP address from the DNS name */
-        AddrTerm = strchr(ThisLine, ' ');
-        if (AddrTerm)
-        {
-            /* Terminate the address string */
-            *AddrTerm = 0;
-
-            /* Find the last space before the DNS name */
-            NameSt = strrchr(ThisLine, ' ');
-
-            /* If there is only one space (the one we removed above), then just use the address terminator */
-            if (!NameSt)
-                NameSt = AddrTerm;
-
-            /* Move from the space to the first character of the DNS name */
-            NameSt++;
-
-            DnsName = NameSt;
-
-            if (!strcmp(name, DnsName))
-            {
-                Found = TRUE;
-                break;
-            }
-        }
-
-        /* Get rid of everything we read so far */
-        while (NextLine <= HostsDBData + ValidData &&
-               isspace (*NextLine))
-        {
-            NextLine++;
-        }
-
-        if (HostsDBData + ValidData - NextLine <= 0)
-            break;
-
-        //WS_DbgPrint(MAX_TRACE,("About to move %d chars\n",
-        //            HostsDBData + ValidData - NextLine));
-
-        memmove(HostsDBData, NextLine, HostsDBData + ValidData - NextLine);
-        ValidData -= NextLine - HostsDBData;
-        //WS_DbgPrint(MAX_TRACE,("Valid bytes: %d\n", ValidData));
-    }
-
-    CloseHandle(HostsFile);
-
-    if (!Found)
-    {
-        //WS_DbgPrint(MAX_TRACE,("Not found\n"));
-        WSASetLastError(WSANO_DATA);
-        return 0;
-    }
-
-    if (strstr(AddressStr, ":"))
-    {
-       //DbgPrint("AF_INET6 NOT SUPPORTED!\n");
-       WSASetLastError(WSAEINVAL);
-       return 0;
-    }
-
-    Address = inet_addr(AddressStr);
-    if (Address == INADDR_NONE)
-    {
-        WSASetLastError(WSAEINVAL);
-        return 0;
-    }
-
-    return Address;
-}
-
 INT
 NSP_GetHostByNameHeapAllocW(_In_ WCHAR* name,
                             _In_ GUID* lpProviderId,
                             _Out_ PWSHOSTINFOINTERN hostinfo)
 {
     HANDLE hHeap = GetProcessHeap();
-    enum addr_type
-    {
-        GH_INVALID,
-        GH_IPV6,
-        GH_IPV4,
-        GH_RFC1123_DNS
-    };
-    typedef enum addr_type addr_type;
-    addr_type addr;
-    INT ret = 0;
-    WCHAR* found = 0;
     DNS_STATUS dns_status = {0};
     /* include/WinDNS.h -- look up DNS_RECORD on MSDN */
     PDNS_RECORD dp;
     PDNS_RECORD curr;
-    WCHAR* tmpHostnameW;
-    CHAR* tmpHostnameA;
-    IP4_ADDRESS address;
     INT result = ERROR_SUCCESS;
 
     /* needed to be cleaned up if != NULL */
-    tmpHostnameW = NULL;
     dp = NULL;
 
-    addr = GH_INVALID;
-
     if (name == NULL)
     {
         result = ERROR_INVALID_PARAMETER;
         goto cleanup;
     }
 
-    /* Hostname "" / "localhost"
-       - convert to "computername" */
-    if ((wcscmp(L"", name) == 0) /*||
-        (wcsicmp(L"localhost", name) == 0)*/)
+    /* DNS_TYPE_A: include/WinDNS.h */
+    /* DnsQuery -- lib/dnsapi/dnsapi/query.c */
+    dns_status = DnsQuery(name,
+                          DNS_TYPE_A,
+                          DNS_QUERY_STANDARD,
+                          /* extra dns servers */ 0,
+                          &dp,
+                          0);
+
+    if (dns_status == ERROR_INVALID_NAME)
     {
-        ret = NSP_GetHostNameHeapAllocW(&tmpHostnameW);
-        if (ret != ERROR_SUCCESS)
-        {
-            result = ret;
-            goto cleanup;
-        }
-        name = tmpHostnameW;
+        WSASetLastError(WSAEFAULT);
+        result = ERROR_INVALID_PARAMETER;
+        goto cleanup;
     }
 
-    /* Is it an IPv6 address? */
-    found = wcschr(name, L':');
-    if (found != NULL)
+    if ((dns_status != 0) || (dp == NULL))
     {
-        addr = GH_IPV6;
-        goto act;
+        result = WSAHOST_NOT_FOUND;
+        goto cleanup;
     }
 
-    /* Is it an IPv4 address? */
-    if (!iswalpha(name[0]))
+    //ASSERT(dp->wType == DNS_TYPE_A);
+    //ASSERT(dp->wDataLength == sizeof(DNS_A_DATA));
+    curr = dp;
+    while ((curr->pNext != NULL) || (curr->wType != DNS_TYPE_A))
     {
-        addr = GH_IPV4;
-        goto act;
+        curr = curr->pNext;
     }
 
-    addr = GH_RFC1123_DNS;
-
-/* Broken out in case we want to get fancy later */
-act:
-    switch (addr)
+    if (curr->wType != DNS_TYPE_A)
     {
-        case GH_IPV6:
-            WSASetLastError(ERROR_CALL_NOT_IMPLEMENTED);
-            result = ERROR_CALL_NOT_IMPLEMENTED;
-            goto cleanup;
-        break;
-
-        case GH_INVALID:
-            WSASetLastError(WSAEFAULT);
-            result = ERROR_INVALID_PARAMETER;
-            goto cleanup;
-        break;
-
-        /* Note: If passed an IP address, MSDN says that gethostbyname()
-                 treats it as an unknown host.
-           This is different from the unix implementation. Use inet_addr()
-        */
-        case GH_IPV4:
-        case GH_RFC1123_DNS:
-        /* DNS_TYPE_A: include/WinDNS.h */
-        /* DnsQuery -- lib/dnsapi/dnsapi/query.c */
-
-        /* Look for the DNS name in the hosts file */
-        if ((address = FindEntryInHosts(name)) != 0)
-        {
-            hostinfo->hostnameW = StrCpyHeapAllocW(hHeap, name);
-            hostinfo->addr4 = address;
-            result = ERROR_SUCCESS;
-            goto cleanup;
-        }
-
-        tmpHostnameA = StrW2AHeapAlloc(hHeap, name);
-        dns_status = DnsQuery(tmpHostnameA,
-                              DNS_TYPE_A,
-                              DNS_QUERY_STANDARD,
-                              /* extra dns servers */ 0,
-                              &dp,
-                              0);
-        HeapFree(hHeap, 0, tmpHostnameA);
-
-        if ((dns_status != 0) || (dp == NULL))
-        {
-            result = WSAHOST_NOT_FOUND;
-            goto cleanup;
-        }
-
-        //ASSERT(dp->wType == DNS_TYPE_A);
-        //ASSERT(dp->wDataLength == sizeof(DNS_A_DATA));
-        curr = dp;
-        while ((curr->pNext != NULL) || (curr->wType != DNS_TYPE_A))
-        {
-            curr = curr->pNext;
-        }
-
-        if (curr->wType != DNS_TYPE_A)
-        {
-            result = WSASERVICE_NOT_FOUND;
-            goto cleanup;
-        }
-
-        //WS_DbgPrint(MID_TRACE,("populating hostent\n"));
-        //WS_DbgPrint(MID_TRACE,("pName is (%s)\n", curr->pName));
-        //populate_hostent(p->Hostent,
-        //                 (PCHAR)curr->pName,
-        //                 curr->Data.A.IpAddress);
-        hostinfo->hostnameW = StrA2WHeapAlloc(hHeap, curr->pName);
-        hostinfo->addr4 = curr->Data.A.IpAddress;
-        result = ERROR_SUCCESS;
+        result = WSASERVICE_NOT_FOUND;
         goto cleanup;
-
-        //WS_DbgPrint(MID_TRACE,("Called DnsQuery, but host not found. Err: %i\n",
-        //            dns_status));
-        //WSASetLastError(WSAHOST_NOT_FOUND);
-        //return NULL;
-
-        break;
-
-        default:
-            result = WSANO_RECOVERY;
-            goto cleanup;
-        break;
     }
 
-    result = WSANO_RECOVERY;
+    hostinfo->hostnameW = StrCpyHeapAllocW(hHeap, curr->pName);
+    hostinfo->addr4 = curr->Data.A.IpAddress;
+    result = ERROR_SUCCESS;
 
 cleanup:
     if (dp != NULL)
         DnsRecordListFree(dp, DnsFreeRecordList);
 
-    if (tmpHostnameW != NULL)
-        HeapFree(hHeap, 0, tmpHostnameW);
-
     return result;
 }
 
@@ -825,6 +590,99 @@ DecodeServEntFromString(IN PCHAR ServiceString,
     return TRUE;
 }
 
+HANDLE
+WSAAPI
+OpenNetworkDatabase(_In_ LPCWSTR Name)
+{
+    PWSTR ExpandedPath;
+    PWSTR DatabasePath;
+    INT ErrorCode;
+    HKEY DatabaseKey;
+    DWORD RegType;
+    DWORD RegSize = 0;
+    size_t StringLength;
+    HANDLE ret;
+
+    ExpandedPath = HeapAlloc(GetProcessHeap(), 0, MAX_PATH*sizeof(WCHAR));
+    if (!ExpandedPath)
+        return INVALID_HANDLE_VALUE;
+
+    /* Open the database path key */
+    ErrorCode = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                             L"System\\CurrentControlSet\\Services\\Tcpip\\Parameters",
+                             0,
+                             KEY_READ,
+                             &DatabaseKey);
+    if (ErrorCode == NO_ERROR)
+    {
+        /* Read the actual path */
+        ErrorCode = RegQueryValueEx(DatabaseKey,
+                                    L"DatabasePath",
+                                    NULL,
+                                    &RegType,
+                                    NULL,
+                                    &RegSize);
+
+        DatabasePath = HeapAlloc(GetProcessHeap(), 0, RegSize);
+        if (!DatabasePath)
+        {
+            HeapFree(GetProcessHeap(), 0, ExpandedPath);
+            return INVALID_HANDLE_VALUE;
+        }
+
+        /* Read the actual path */
+        ErrorCode = RegQueryValueEx(DatabaseKey,
+                                    L"DatabasePath",
+                                    NULL,
+                                    &RegType,
+                                    (LPBYTE)DatabasePath,
+                                    &RegSize);
+
+        /* Close the key */
+        RegCloseKey(DatabaseKey);
+
+        /* Expand the name */
+        ExpandEnvironmentStrings(DatabasePath, ExpandedPath, MAX_PATH);
+
+        HeapFree(GetProcessHeap(), 0, DatabasePath);
+    }
+    else
+    {
+        /* Use defalt path */
+        GetSystemDirectory(ExpandedPath, MAX_PATH);
+        StringCchLength(ExpandedPath, MAX_PATH, &StringLength);
+        if (ExpandedPath[StringLength - 1] != L'\\')
+        {
+            /* It isn't, so add it ourselves */
+            StringCchCat(ExpandedPath, MAX_PATH, L"\\");
+        }
+        StringCchCat(ExpandedPath, MAX_PATH, L"DRIVERS\\ETC\\");
+    }
+
+    /* Make sure that the path is backslash-terminated */
+    StringCchLength(ExpandedPath, MAX_PATH, &StringLength);
+    if (ExpandedPath[StringLength - 1] != L'\\')
+    {
+        /* It isn't, so add it ourselves */
+        StringCchCat(ExpandedPath, MAX_PATH, L"\\");
+    }
+
+    /* Add the database name */
+    StringCchCat(ExpandedPath, MAX_PATH, Name);
+
+    /* Return a handle to the file */
+    ret = CreateFile(ExpandedPath,
+                      FILE_READ_DATA,
+                      FILE_SHARE_READ,
+                      NULL,
+                      OPEN_EXISTING,
+                      FILE_ATTRIBUTE_NORMAL,
+                      NULL);
+
+    HeapFree(GetProcessHeap(), 0, ExpandedPath);
+    return ret;
+}
+
 INT
 NSP_GetServiceByNameHeapAllocW(_In_ WCHAR* nameW,
                                _In_ GUID* lpProviderId,
@@ -833,14 +691,11 @@ NSP_GetServiceByNameHeapAllocW(_In_ WCHAR* nameW,
     BOOL Found = FALSE;
     HANDLE ServicesFile;
     CHAR ServiceDBData[BUFSIZ * sizeof(WCHAR)] = {0};
-    PWCHAR SystemDirectory = (PWCHAR)ServiceDBData; /* Reuse this stack space */
-    PWCHAR ServicesFileLocation = L"\\drivers\\etc\\services";
     PCHAR ThisLine = 0, NextLine = 0, ServiceName = 0, PortNumberStr = 0,
     ProtocolStr = 0, Comment = 0, EndValid;
     PCHAR Aliases[WS2_INTERNAL_MAX_ALIAS] = {0};
     PCHAR* AliasPtr;
-    UINT i = 0,
-    SystemDirSize = (sizeof(ServiceDBData) / sizeof(WCHAR)) - 1;
+    UINT i = 0;
     DWORD ReadSize = 0;
     HANDLE hHeap;
     PCHAR nameA = NULL;
@@ -872,23 +727,7 @@ NSP_GetServiceByNameHeapAllocW(_In_ WCHAR* nameW,
     StringCbCopyA(nameServiceA, i + 1, nameA);
     nameServiceA[i] = '\0';
 
-    if (!GetSystemDirectoryW(SystemDirectory, SystemDirSize))
-    {
-        /* Can't get system directory */
-        res = WSANO_RECOVERY;
-        goto End;
-    }
-
-    wcsncat(SystemDirectory, ServicesFileLocation, SystemDirSize);
-
-    ServicesFile = CreateFileW(SystemDirectory,
-                               GENERIC_READ,
-                               FILE_SHARE_READ,
-                               NULL,
-                               OPEN_EXISTING,
-                               FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN,
-                               NULL);
-
+    ServicesFile = OpenNetworkDatabase(L"services");
     if (ServicesFile == INVALID_HANDLE_VALUE)
     {
         return WSANO_RECOVERY;
@@ -1055,19 +894,15 @@ NSP_LookupServiceNextW(_In_ PWSHANDLEINTERN data,
         if (result != ERROR_SUCCESS)
             goto End;
     }
-    else if (CallID == NSP_CALLID_SERVICEBYNAME)
+    else
     {
+        ASSERT(CallID == NSP_CALLID_SERVICEBYNAME);
         result = NSP_GetServiceByNameHeapAllocW(data->hostnameW,
                                                 &data->providerId,
                                                 &hostinfo);
         if (result != ERROR_SUCCESS)
             goto End;
     }
-    else
-    {
-        result = WSANO_RECOVERY; // Internal error!
-        goto End;
-    }
 
     if (((LUP_RETURN_BLOB & data->dwControlFlags) != 0) ||
         ((LUP_RETURN_NAME & data->dwControlFlags) != 0))