[RAPPS] Various fixes
[reactos.git] / base / applications / rapps / loaddlg.cpp
index 8e2e76a..30aa4a2 100644 (file)
@@ -331,11 +331,24 @@ HRESULT WINAPI CDownloadDialog_Constructor(HWND Dlg, BOOL *pbCancelled, REFIID r
 }
 
 #ifdef USE_CERT_PINNING
-static BOOL CertIsValid(HINTERNET hFile, LPWSTR lpszHostName)
+typedef CHeapPtr<char, CLocalAllocator> CLocalPtr;
+
+static BOOL CertGetSubjectAndIssuer(HINTERNET hFile, CLocalPtr& subjectInfo, CLocalPtr& issuerInfo)
 {
     DWORD certInfoLength;
     INTERNET_CERTIFICATE_INFOA certInfo;
-    int ValidFlags = 0;
+    DWORD size, flags;
+
+    size = sizeof(flags);
+    if (!InternetQueryOptionA(hFile, INTERNET_OPTION_SECURITY_FLAGS, &flags, &size))
+    {
+        return FALSE;
+    }
+
+    if (!flags & SECURITY_FLAG_SECURE)
+    {
+        return FALSE;
+    }
 
     /* Despite what the header indicates, the implementation of INTERNET_CERTIFICATE_INFO is not Unicode-aware. */
     certInfoLength = sizeof(certInfo);
@@ -347,18 +360,9 @@ static BOOL CertIsValid(HINTERNET hFile, LPWSTR lpszHostName)
         return FALSE;
     }
 
-    if (certInfo.lpszSubjectInfo)
-    {
-        if (strcmp(certInfo.lpszSubjectInfo, CERT_SUBJECT_INFO) == 0)
-            ValidFlags |= 1;
-        LocalFree(certInfo.lpszSubjectInfo);
-    }
-    if (certInfo.lpszIssuerInfo)
-    {
-        if (strcmp(certInfo.lpszIssuerInfo, CERT_ISSUER_INFO) == 0)
-            ValidFlags |= 2;
-        LocalFree(certInfo.lpszIssuerInfo);
-    }
+    subjectInfo.Attach(certInfo.lpszSubjectInfo);
+    issuerInfo.Attach(certInfo.lpszIssuerInfo);
+
     if (certInfo.lpszProtocolName)
         LocalFree(certInfo.lpszProtocolName);
     if (certInfo.lpszSignatureAlgName)
@@ -366,13 +370,13 @@ static BOOL CertIsValid(HINTERNET hFile, LPWSTR lpszHostName)
     if (certInfo.lpszEncryptionAlgName)
         LocalFree(certInfo.lpszEncryptionAlgName);
 
-    return ValidFlags == 3;
+    return certInfo.lpszSubjectInfo && certInfo.lpszIssuerInfo;
 }
 #endif
 
 inline VOID MessageBox_LoadString(HWND hMainWnd, INT StringID)
 {
-    ATL::CString szMsgText;
+    ATL::CStringW szMsgText;
     if (szMsgText.LoadStringW(StringID))
     {
         MessageBoxW(hMainWnd, szMsgText.GetString(), NULL, MB_OK | MB_ICONERROR);
@@ -616,6 +620,19 @@ DWORD WINAPI CDownloadManager::ThreadFunc(LPVOID param)
             SendMessageW(Item, PBM_SETPOS, 0, 0);
         }
 
+        // is this URL an update package for RAPPS? if so store it in a different place
+        if (InfoArray[iAppId].szUrl == APPLICATION_DATABASE_URL)
+        {
+            bCab = TRUE;
+            if (!GetStorageDirectory(Path))
+                goto end;
+        }
+        else
+        {
+            bCab = FALSE;
+            Path = SettingsInfo.szDownloadDir;
+        }
+
         // Change caption to show the currently downloaded app
         if (!bCab)
         {
@@ -644,18 +661,6 @@ DWORD WINAPI CDownloadManager::ThreadFunc(LPVOID param)
         if (q && q > p && (q - p) > 0)
             filenameLength -= wcslen(q - 1) * sizeof(WCHAR);
 
-        // is this URL an update package for RAPPS? if so store it in a different place
-        if (InfoArray[iAppId].szUrl == APPLICATION_DATABASE_URL)
-        {
-            bCab = TRUE;
-            if (!GetStorageDirectory(Path))
-                goto end;
-        }
-        else
-        {
-            Path = SettingsInfo.szDownloadDir;
-        }
-
         // is the path valid? can we access it?
         if (GetFileAttributesW(Path.GetString()) == INVALID_FILE_ATTRIBUTES)
         {
@@ -690,6 +695,7 @@ DWORD WINAPI CDownloadManager::ThreadFunc(LPVOID param)
         switch (SettingsInfo.Proxy)
         {
         case 0: // preconfig
+        default:
             hOpen = InternetOpenW(lpszAgent, INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, 0);
             break;
         case 1: // direct (no proxy)
@@ -698,15 +704,14 @@ DWORD WINAPI CDownloadManager::ThreadFunc(LPVOID param)
         case 2: // use proxy
             hOpen = InternetOpenW(lpszAgent, INTERNET_OPEN_TYPE_PROXY, SettingsInfo.szProxyServer, SettingsInfo.szNoProxyFor, 0);
             break;
-        default: // preconfig
-            hOpen = InternetOpenW(lpszAgent, INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, 0);
-            break;
         }
 
         if (!hOpen)
             goto end;
 
-        hFile = InternetOpenUrlW(hOpen, InfoArray[iAppId].szUrl.GetString(), NULL, 0, INTERNET_FLAG_PRAGMA_NOCACHE | INTERNET_FLAG_KEEP_CONNECTION, 0);
+        hFile = InternetOpenUrlW(hOpen, InfoArray[iAppId].szUrl.GetString(), NULL, 0,
+                                 INTERNET_FLAG_DONT_CACHE | INTERNET_FLAG_PRAGMA_NOCACHE | INTERNET_FLAG_KEEP_CONNECTION,
+                                 0);
 
         if (!hFile)
         {
@@ -740,7 +745,7 @@ DWORD WINAPI CDownloadManager::ThreadFunc(LPVOID param)
         dwContentLen = 0;
 
         if (urlComponents.nScheme == INTERNET_SCHEME_HTTP || urlComponents.nScheme == INTERNET_SCHEME_HTTPS)
-            HttpQueryInfoW(hFile, HTTP_QUERY_CONTENT_LENGTH | HTTP_QUERY_FLAG_NUMBER, &dwContentLen, &dwStatus, 0);
+            HttpQueryInfoW(hFile, HTTP_QUERY_CONTENT_LENGTH | HTTP_QUERY_FLAG_NUMBER, &dwContentLen, &dwStatusLen, 0);
 
         if (urlComponents.nScheme == INTERNET_SCHEME_FTP)
             dwContentLen = FtpGetFileSize(hFile, &dwStatus);
@@ -751,20 +756,42 @@ DWORD WINAPI CDownloadManager::ThreadFunc(LPVOID param)
             SetProgressMarquee(Item, TRUE);
         }
 
+        free(urlComponents.lpszScheme);
+        free(urlComponents.lpszHostName);
+
 #ifdef USE_CERT_PINNING
         // are we using HTTPS to download the RAPPS update package? check if the certificate is original
         if ((urlComponents.nScheme == INTERNET_SCHEME_HTTPS) &&
-            (wcscmp(InfoArray[iAppId].szUrl, APPLICATION_DATABASE_URL) == 0) &&
-            (!CertIsValid(hFile, urlComponents.lpszHostName)))
+            (wcscmp(InfoArray[iAppId].szUrl, APPLICATION_DATABASE_URL) == 0))
         {
-            MessageBox_LoadString(hMainWnd, IDS_CERT_DOES_NOT_MATCH);
-            goto end;
+            CLocalPtr subjectName, issuerName;
+            CStringW szMsgText;
+            bool bAskQuestion = false;
+            if (!CertGetSubjectAndIssuer(hFile, subjectName, issuerName))
+            {
+                szMsgText.LoadStringW(IDS_UNABLE_TO_QUERY_CERT);
+                bAskQuestion = true;
+            }
+            else
+            {
+                if (strcmp(subjectName, CERT_SUBJECT_INFO) ||
+                    strcmp(issuerName, CERT_ISSUER_INFO))
+                {
+                    szMsgText.Format(IDS_MISMATCH_CERT_INFO, (char*)subjectName, (const char*)issuerName);
+                    bAskQuestion = true;
+                }
+            }
+
+            if (bAskQuestion)
+            {
+                if (MessageBoxW(hMainWnd, szMsgText.GetString(), NULL, MB_YESNO | MB_ICONERROR) != IDYES)
+                {
+                    goto end;
+                }
+            }
         }
 #endif
 
-        free(urlComponents.lpszScheme);
-        free(urlComponents.lpszHostName);
-
         hOut = CreateFileW(Path.GetString(), GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE, NULL, CREATE_ALWAYS, 0, NULL);
 
         if (hOut == INVALID_HANDLE_VALUE)