[USERENV] Use a reference counter when loading and unloading profiles. Unload the...
authorEric Kohl <eric.kohl@reactos.org>
Sat, 16 Mar 2019 06:45:30 +0000 (07:45 +0100)
committerEric Kohl <eric.kohl@reactos.org>
Sat, 16 Mar 2019 06:45:30 +0000 (07:45 +0100)
dll/win32/userenv/profile.c

index 6b4635b..1c18a67 100644 (file)
@@ -197,6 +197,157 @@ CreateProfileMutex(
 }
 
 
 }
 
 
+static
+DWORD
+IncrementRefCount(
+    PWSTR pszSidString,
+    PDWORD pdwRefCount)
+{
+    HKEY hProfilesKey = NULL, hProfileKey = NULL;
+    DWORD dwRefCount = 0, dwLength, dwType;
+    DWORD dwError;
+
+    DPRINT1("IncrementRefCount(%S %p)\n",
+            pszSidString, pdwRefCount);
+
+    dwError = RegOpenKeyExW(HKEY_LOCAL_MACHINE,
+                            L"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\ProfileList",
+                            0,
+                            KEY_QUERY_VALUE,
+                            &hProfilesKey);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("Error: %lu\n", dwError);
+        goto done;
+    }
+
+    dwError = RegOpenKeyExW(hProfilesKey,
+                            pszSidString,
+                            0,
+                            KEY_QUERY_VALUE | KEY_SET_VALUE,
+                            &hProfileKey);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("Error: %lu\n", dwError);
+        goto done;
+    }
+
+    /* Get the reference counter */
+    dwLength = sizeof(dwRefCount);
+    RegQueryValueExW(hProfileKey,
+                     L"RefCount",
+                     NULL,
+                     &dwType,
+                     (PBYTE)&dwRefCount,
+                     &dwLength);
+
+    dwRefCount++;
+
+    dwLength = sizeof(dwRefCount);
+    dwError = RegSetValueExW(hProfileKey,
+                             L"RefCount",
+                             0,
+                             REG_DWORD,
+                             (PBYTE)&dwRefCount,
+                             dwLength);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("Error: %lu\n", dwError);
+        goto done;
+    }
+
+    if (pdwRefCount != NULL)
+        *pdwRefCount = dwRefCount;
+
+done:
+    if (hProfileKey != NULL)
+        RegCloseKey(hProfileKey);
+
+    if (hProfilesKey != NULL)
+        RegCloseKey(hProfilesKey);
+
+    return dwError;
+}
+
+
+static
+DWORD
+DecrementRefCount(
+    PWSTR pszSidString,
+    PDWORD pdwRefCount)
+{
+    HKEY hProfilesKey = NULL, hProfileKey = NULL;
+    DWORD dwRefCount = 0, dwLength, dwType;
+    DWORD dwError;
+
+    DPRINT1("DecrementRefCount(%S %p)\n",
+            pszSidString, pdwRefCount);
+
+    dwError = RegOpenKeyExW(HKEY_LOCAL_MACHINE,
+                            L"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\ProfileList",
+                            0,
+                            KEY_QUERY_VALUE,
+                            &hProfilesKey);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("Error: %lu\n", dwError);
+        goto done;
+    }
+
+    dwError = RegOpenKeyExW(hProfilesKey,
+                            pszSidString,
+                            0,
+                            KEY_QUERY_VALUE | KEY_SET_VALUE,
+                            &hProfileKey);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("Error: %lu\n", dwError);
+        goto done;
+    }
+
+    /* Get the reference counter */
+    dwLength = sizeof(dwRefCount);
+    dwError = RegQueryValueExW(hProfileKey,
+                               L"RefCount",
+                               NULL,
+                               &dwType,
+                               (PBYTE)&dwRefCount,
+                               &dwLength);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("Error: %lu\n", dwError);
+        goto done;
+    }
+
+    dwRefCount--;
+
+    dwLength = sizeof(dwRefCount);
+    dwError = RegSetValueExW(hProfileKey,
+                             L"RefCount",
+                             0,
+                             REG_DWORD,
+                             (PBYTE)&dwRefCount,
+                             dwLength);
+    if (dwError != ERROR_SUCCESS)
+    {
+        DPRINT1("Error: %lu\n", dwError);
+        goto done;
+    }
+
+    if (pdwRefCount != NULL)
+        *pdwRefCount = dwRefCount;
+
+done:
+    if (hProfileKey != NULL)
+        RegCloseKey(hProfileKey);
+
+    if (hProfilesKey != NULL)
+        RegCloseKey(hProfilesKey);
+
+    return dwError;
+}
+
+
 /* PUBLIC FUNCTIONS ********************************************************/
 
 BOOL
 /* PUBLIC FUNCTIONS ********************************************************/
 
 BOOL
@@ -1613,6 +1764,14 @@ LoadUserProfileW(
         goto cleanup;
     }
 
         goto cleanup;
     }
 
+    Error = IncrementRefCount(SidString.Buffer, NULL);
+    if (Error != ERROR_SUCCESS)
+    {
+        DPRINT1("IncrementRefCount() failed (Error %ld)\n", Error);
+        SetLastError((DWORD)Error);
+        goto cleanup;
+    }
+
     ret = TRUE;
 
 cleanup:
     ret = TRUE;
 
 cleanup:
@@ -1640,6 +1799,7 @@ UnloadUserProfile(
 {
     UNICODE_STRING SidString = {0, 0, NULL};
     HANDLE hProfileMutex = NULL;
 {
     UNICODE_STRING SidString = {0, 0, NULL};
     HANDLE hProfileMutex = NULL;
+    DWORD dwRefCount = 0;
     LONG Error;
     BOOL bRet = FALSE;
 
     LONG Error;
     BOOL bRet = FALSE;
 
@@ -1675,44 +1835,57 @@ UnloadUserProfile(
     /* Close the profile handle */
     RegCloseKey(hProfile);
 
     /* Close the profile handle */
     RegCloseKey(hProfile);
 
-    /* Acquire restore privilege */
-    if (!AcquireRemoveRestorePrivilege(TRUE))
+    Error = DecrementRefCount(SidString.Buffer, &dwRefCount);
+    if (Error != ERROR_SUCCESS)
     {
     {
-        DPRINT1("AcquireRemoveRestorePrivilege() failed (Error %ld)\n", GetLastError());
+        DPRINT1("DecrementRefCount() failed (Error %ld)\n", Error);
+        SetLastError((DWORD)Error);
         goto cleanup;
     }
 
         goto cleanup;
     }
 
-    /* HACK */
+    if (dwRefCount == 0)
     {
     {
-        HKEY hUserKey;
+        DPRINT1("RefCount is 0: Unload the Hive!\n");
 
 
-        Error = RegOpenKeyExW(HKEY_USERS,
-                              SidString.Buffer,
-                              0,
-                              KEY_WRITE,
-                              &hUserKey);
-        if (Error == ERROR_SUCCESS)
+        /* Acquire restore privilege */
+        if (!AcquireRemoveRestorePrivilege(TRUE))
         {
         {
-            RegDeleteKeyW(hUserKey,
-                          L"Volatile Environment");
+            DPRINT1("AcquireRemoveRestorePrivilege() failed (Error %ld)\n", GetLastError());
+            goto cleanup;
+        }
+
+        /* HACK */
+        {
+            HKEY hUserKey;
+
+            Error = RegOpenKeyExW(HKEY_USERS,
+                                  SidString.Buffer,
+                                  0,
+                                  KEY_WRITE,
+                                  &hUserKey);
+            if (Error == ERROR_SUCCESS)
+            {
+                RegDeleteKeyW(hUserKey,
+                              L"Volatile Environment");
 
 
-            RegCloseKey(hUserKey);
+                RegCloseKey(hUserKey);
+            }
         }
         }
-    }
-    /* End of HACK */
+        /* End of HACK */
 
 
-    /* Unload the hive */
-    Error = RegUnLoadKeyW(HKEY_USERS,
-                          SidString.Buffer);
+        /* Unload the hive */
+        Error = RegUnLoadKeyW(HKEY_USERS,
+                              SidString.Buffer);
 
 
-    /* Remove restore privilege */
-    AcquireRemoveRestorePrivilege(FALSE);
+        /* Remove restore privilege */
+        AcquireRemoveRestorePrivilege(FALSE);
 
 
-    if (Error != ERROR_SUCCESS)
-    {
-        DPRINT1("RegUnLoadKeyW() failed (Error %ld)\n", Error);
-        SetLastError((DWORD)Error);
-        goto cleanup;
+        if (Error != ERROR_SUCCESS)
+        {
+            DPRINT1("RegUnLoadKeyW() failed (Error %ld)\n", Error);
+            SetLastError((DWORD)Error);
+            goto cleanup;
+        }
     }
 
     bRet = TRUE;
     }
 
     bRet = TRUE;