[DNSAPI] Answer queries for ip addresses before they are passed to the resolver cache.
authorEric Kohl <eric.kohl@reactos.org>
Sun, 12 Jan 2020 14:15:06 +0000 (15:15 +0100)
committerEric Kohl <eric.kohl@reactos.org>
Sun, 12 Jan 2020 14:15:55 +0000 (15:15 +0100)
dll/win32/dnsapi/query.c

index 0603b32..c92f9a1 100644 (file)
 #define NDEBUG
 #include <debug.h>
 
+static
+BOOL
+ParseIpv4Address(
+    _In_ PCWSTR AddressString,
+    _Out_ PIN_ADDR pAddress)
+{
+    PCWSTR pTerminator = NULL;
+    NTSTATUS Status;
+
+    Status = RtlIpv4StringToAddressW(AddressString,
+                                     TRUE,
+                                     &pTerminator,
+                                     pAddress);
+    if (NT_SUCCESS(Status) && pTerminator != NULL && *pTerminator == L'\0')
+        return TRUE;
+
+    return FALSE;
+}
+
+
+static
+BOOL
+ParseIpv6Address(
+    _In_ PCWSTR AddressString,
+    _Out_ PIN6_ADDR pAddress)
+{
+    PCWSTR pTerminator = NULL;
+    NTSTATUS Status;
+
+    Status = RtlIpv6StringToAddressW(AddressString,
+                                     &pTerminator,
+                                     pAddress);
+    if (NT_SUCCESS(Status) && pTerminator != NULL && *pTerminator == L'\0')
+        return TRUE;
+
+    return FALSE;
+}
+
+
+static
+PDNS_RECORDW
+CreateRecordForIpAddress(
+    _In_ PCWSTR Name,
+    _In_ WORD Type)
+{
+    IN_ADDR Ip4Address;
+    IN6_ADDR Ip6Address;
+    PDNS_RECORDW pRecord = NULL;
+
+    if (Type == DNS_TYPE_A)
+    {
+        if (ParseIpv4Address(Name, &Ip4Address))
+        {
+            pRecord = RtlAllocateHeap(RtlGetProcessHeap(),
+                                      HEAP_ZERO_MEMORY,
+                                      sizeof(DNS_RECORDW));
+            if (pRecord == NULL)
+                return NULL;
+
+            pRecord->pName = RtlAllocateHeap(RtlGetProcessHeap(),
+                                             0,
+                                             (wcslen(Name) + 1) * sizeof(WCHAR));
+            if (pRecord == NULL)
+            {
+                RtlFreeHeap(RtlGetProcessHeap(), 0, pRecord);
+                return NULL;
+            }
+
+            wcscpy(pRecord->pName, Name);
+            pRecord->wType = DNS_TYPE_A;
+            pRecord->wDataLength = sizeof(DNS_A_DATA);
+            pRecord->Flags.S.Section = DnsSectionQuestion;
+            pRecord->Flags.S.CharSet = DnsCharSetUnicode;
+            pRecord->dwTtl = 7 * 24 * 60 * 60;
+
+            pRecord->Data.A.IpAddress = Ip4Address.S_un.S_addr;
+
+            return pRecord;
+        }
+    }
+    else if (Type == DNS_TYPE_AAAA)
+    {
+        if (ParseIpv6Address(Name, &Ip6Address))
+        {
+            pRecord = RtlAllocateHeap(RtlGetProcessHeap(),
+                                      HEAP_ZERO_MEMORY,
+                                      sizeof(DNS_RECORDW));
+            if (pRecord == NULL)
+                return NULL;
+
+            pRecord->pName = RtlAllocateHeap(RtlGetProcessHeap(),
+                                             0,
+                                             (wcslen(Name) + 1) * sizeof(WCHAR));
+            if (pRecord == NULL)
+            {
+                RtlFreeHeap(RtlGetProcessHeap(), 0, pRecord);
+                return NULL;
+            }
+
+            wcscpy(pRecord->pName, Name);
+            pRecord->wType = DNS_TYPE_AAAA;
+            pRecord->wDataLength = sizeof(DNS_AAAA_DATA);
+            pRecord->Flags.S.Section = DnsSectionQuestion;
+            pRecord->Flags.S.CharSet = DnsCharSetUnicode;
+            pRecord->dwTtl = 7 * 24 * 60 * 60;
+
+            CopyMemory(&pRecord->Data.AAAA.Ip6Address,
+                       &Ip6Address.u.Byte,
+                       sizeof(IN6_ADDR));
+
+            return pRecord;
+        }
+    }
+
+    return NULL;
+}
+
 
 /* DnsQuery ****************************
  * Begin a DNS query, and allow the result to be placed in the application
@@ -347,6 +464,56 @@ DnsQuery_UTF8(LPCSTR Name,
     return DnsQuery_CodePage(CP_UTF8, Name, Type, Options, Extra, QueryResultSet, Reserved);
 }
 
+DNS_STATUS
+WINAPI
+DnsQuery_W(LPCWSTR Name,
+           WORD Type,
+           DWORD Options,
+           PVOID Extra,
+           PDNS_RECORD *QueryResultSet,
+           PVOID *Reserved)
+{
+    DWORD dwRecords = 0;
+    PDNS_RECORDW pRecord = NULL;
+    DNS_STATUS Status = ERROR_SUCCESS;
+
+    DPRINT("DnsQuery_W()\n");
+
+    if ((Name == NULL) ||
+        (QueryResultSet == NULL))
+        return ERROR_INVALID_PARAMETER;
+
+    *QueryResultSet = NULL;
+
+    /* Create an A or AAAA record for an IP4 or IP6 address */
+    pRecord = CreateRecordForIpAddress(Name,
+                                       Type);
+    if (pRecord != NULL)
+    {
+        *QueryResultSet = (PDNS_RECORD)pRecord;
+        return ERROR_SUCCESS;
+    }
+
+    RpcTryExcept
+    {
+        Status = R_ResolverQuery(NULL,
+                                 Name,
+                                 Type,
+                                 Options,
+                                 &dwRecords,
+                                 (DNS_RECORDW **)QueryResultSet);
+        DPRINT("R_ResolverQuery() returned %lu\n", Status);
+    }
+    RpcExcept(EXCEPTION_EXECUTE_HANDLER)
+    {
+        Status = RpcExceptionCode();
+        DPRINT("Exception returned %lu\n", Status);
+    }
+    RpcEndExcept;
+
+    return Status;
+}
+
 WCHAR
 *xstrsave(const WCHAR *str)
 {
@@ -455,111 +622,6 @@ CheckForCurrentHostname(CONST CHAR * Name, PFIXED_INFO network_info)
     return ret;
 }
 
-BOOL
-ParseV4Address(LPCSTR AddressString,
-    OUT PDWORD pAddress)
-{
-    CHAR * cp = (CHAR *)AddressString;
-    DWORD val, base;
-    unsigned char c;
-    DWORD parts[4], *pp = parts;
-    if (!AddressString)
-        return FALSE;
-    if (!isdigit(*cp)) return FALSE;
-
-again:
-    /*
-    * Collect number up to ``.''.
-    * Values are specified as for C:
-    * 0x=hex, 0=octal, other=decimal.
-    */
-    val = 0; base = 10;
-    if (*cp == '0') {
-        if (*++cp == 'x' || *cp == 'X')
-            base = 16, cp++;
-        else
-            base = 8;
-    }
-    while ((c = *cp)) {
-        if (isdigit(c)) {
-            val = (val * base) + (c - '0');
-            cp++;
-            continue;
-        }
-        if (base == 16 && isxdigit(c)) {
-            val = (val << 4) + (c + 10 - (islower(c) ? 'a' : 'A'));
-            cp++;
-            continue;
-        }
-        break;
-    }
-    if (*cp == '.') {
-        /*
-        * Internet format:
-        *    a.b.c.d
-        */
-        if (pp >= parts + 4) return FALSE;
-        *pp++ = val;
-        cp++;
-        goto again;
-    }
-    /*
-    * Check for trailing characters.
-    */
-    if (*cp && *cp > ' ') return FALSE;
-
-    if (pp >= parts + 4) return FALSE;
-    *pp++ = val;
-    /*
-    * Concoct the address according to
-    * the number of parts specified.
-    */
-    if ((DWORD)(pp - parts) != 4) return FALSE;
-    if (parts[0] > 0xff || parts[1] > 0xff || parts[2] > 0xff || parts[3] > 0xff) return FALSE;
-    val = (parts[3] << 24) | (parts[2] << 16) | (parts[1] << 8) | parts[0];
-
-    if (pAddress)
-        *pAddress = val;
-
-    return TRUE;
-}
-
-
-DNS_STATUS WINAPI
-DnsQuery_W(LPCWSTR Name,
-           WORD Type,
-           DWORD Options,
-           PVOID Extra,
-           PDNS_RECORD *QueryResultSet,
-           PVOID *Reserved)
-{
-    DWORD dwRecords = 0;
-    DNS_STATUS Status = ERROR_SUCCESS;
-
-    DPRINT("DnsQuery_W()\n");
-
-    *QueryResultSet = NULL;
-
-    RpcTryExcept
-    {
-        Status = R_ResolverQuery(NULL,
-                                 Name,
-                                 Type,
-                                 Options,
-                                 &dwRecords,
-                                 (DNS_RECORDW **)QueryResultSet);
-        DPRINT("R_ResolverQuery() returned %lu\n", Status);
-    }
-    RpcExcept(EXCEPTION_EXECUTE_HANDLER)
-    {
-        Status = RpcExceptionCode();
-        DPRINT("Exception returned %lu\n", Status);
-    }
-    RpcEndExcept;
-
-    return Status;
-}
-
 
 DNS_STATUS
 WINAPI
@@ -617,28 +679,6 @@ Query_Main(LPCWSTR Name,
                             NULL,
                             0);
         NameLen--;
-        /* Is it an IPv4 address? */
-        if (ParseV4Address(AnsiName, &Address))
-        {
-            RtlFreeHeap(RtlGetProcessHeap(), 0, AnsiName);
-            *QueryResultSet = (PDNS_RECORD)RtlAllocateHeap(RtlGetProcessHeap(), 0, sizeof(DNS_RECORD));
-
-            if (NULL == *QueryResultSet)
-            {
-                return ERROR_OUTOFMEMORY;
-            }
-
-            (*QueryResultSet)->pNext = NULL;
-            (*QueryResultSet)->wType = Type;
-            (*QueryResultSet)->wDataLength = sizeof(DNS_A_DATA);
-            (*QueryResultSet)->Flags.S.Section = DnsSectionAnswer;
-            (*QueryResultSet)->Flags.S.CharSet = DnsCharSetUnicode;
-            (*QueryResultSet)->Data.A.IpAddress = Address;
-
-            (*QueryResultSet)->pName = (LPSTR)xstrsave(Name);
-
-            return (*QueryResultSet)->pName ? ERROR_SUCCESS : ERROR_OUTOFMEMORY;
-        }
 
         /* Check allowed characters
         * According to RFC a-z,A-Z,0-9,-,_, but can't start or end with - or _