[WINLOGON] Load the notification dlls on demand
authorEric Kohl <eric.kohl@reactos.org>
Sat, 3 Aug 2019 13:04:50 +0000 (15:04 +0200)
committerEric Kohl <eric.kohl@reactos.org>
Sat, 3 Aug 2019 13:05:20 +0000 (15:05 +0200)
Do not keep the notifications loaded until shutdown. Load and unload them when a notification routine must be called.

base/system/winlogon/notify.c

index 6c20e0f..c0b78dc 100644 (file)
@@ -35,15 +35,14 @@ static PSTR FuncNames[LastHandler] =
 typedef struct _NOTIFICATION_ITEM
 {
     LIST_ENTRY ListEntry;
-
-    HINSTANCE hInstance;
+    PWSTR pszKeyName;
+    PWSTR pszDllName;
     BOOL bEnabled;
     BOOL bAsynchronous;
     BOOL bSafe;
     BOOL bImpersonate;
     BOOL bSmartCardLogon;
     DWORD dwMaxWait;
-    PWLX_NOTIFY_HANDLER Handler[LastHandler];
 } NOTIFICATION_ITEM, *PNOTIFICATION_ITEM;
 
 
@@ -52,140 +51,145 @@ static LIST_ENTRY NotificationDllListHead;
 
 /* FUNCTIONS *****************************************************************/
 
-PWLX_NOTIFY_HANDLER
-GetNotificationHandler(
-    HKEY hDllKey,
-    HINSTANCE hInstance,
-    PSTR pNotification)
-{
-    CHAR szFuncBuffer[128];
-    DWORD dwSize;
-    DWORD dwType;
-    LONG lError;
-
-    dwSize = 128;
-    lError = RegQueryValueExA(hDllKey,
-                              pNotification,
-                              NULL,
-                              &dwType,
-                              (PBYTE)szFuncBuffer,
-                              &dwSize);
-    if (lError == ERROR_SUCCESS)
-    {
-        return (PWLX_NOTIFY_HANDLER)GetProcAddress(hInstance, szFuncBuffer);
-    }
-
-    return NULL;
-}
-
-
 static
 VOID
-LoadNotificationDll(
+AddNotificationDll(
     HKEY hNotifyKey,
-    PWSTR pKeyName)
+    PWSTR pszKeyName)
 {
     HKEY hDllKey = NULL;
     PNOTIFICATION_ITEM NotificationDll = NULL;
-    LONG lError;
-    WCHAR szBuffer[80];
-    DWORD dwSize;
-    DWORD dwType;
-    HINSTANCE hInstance;
-    NOTIFICATION_TYPE i;
+    DWORD dwSize, dwType;
+    DWORD dwError;
 
-    TRACE("LoadNotificationDll(%p %S)\n", hNotifyKey, pKeyName);
+    TRACE("AddNotificationDll(%p %S)\n", hNotifyKey, pszKeyName);
 
-    lError = RegOpenKeyExW(hNotifyKey,
-                           pKeyName,
-                           0,
-                           KEY_READ,
-                           &hDllKey);
-    if (lError != ERROR_SUCCESS)
+    dwError = RegOpenKeyExW(hNotifyKey,
+                            pszKeyName,
+                            0,
+                            KEY_READ,
+                            &hDllKey);
+    if (dwError != ERROR_SUCCESS)
         return;
 
-    dwSize = 80 * sizeof(WCHAR);
-    lError = RegQueryValueExW(hDllKey,
-                              L"DllName",
-                              NULL,
-                              &dwType,
-                              (PBYTE)szBuffer,
-                              &dwSize);
-    if (lError == ERROR_SUCCESS)
+    NotificationDll = RtlAllocateHeap(RtlGetProcessHeap(),
+                                      HEAP_ZERO_MEMORY,
+                                      sizeof(NOTIFICATION_ITEM));
+    if (NotificationDll == NULL)
     {
-        hInstance = LoadLibraryW(szBuffer);
-        if (hInstance == NULL)
-            return;
-
-        NotificationDll = RtlAllocateHeap(RtlGetProcessHeap(),
-                                          HEAP_ZERO_MEMORY,
-                                          sizeof(NOTIFICATION_ITEM));
-        if (NotificationDll == NULL)
-        {
-            FreeLibrary(hInstance);
-            return;
-        }
+        dwError = ERROR_OUTOFMEMORY;
+        goto done;
+    }
+
+    NotificationDll->pszKeyName = RtlAllocateHeap(RtlGetProcessHeap(),
+                                                  HEAP_ZERO_MEMORY,
+                                                  (wcslen(pszKeyName) + 1) * sizeof(WCHAR));
+    if (NotificationDll->pszKeyName == NULL)
+    {
+        dwError = ERROR_OUTOFMEMORY;
+        goto done;
+    }
+
+    wcscpy(NotificationDll->pszKeyName, pszKeyName);
+
+    dwSize = 0;
+    RegQueryValueExW(hDllKey,
+                     L"DllName",
+                     NULL,
+                     &dwType,
+                     NULL,
+                     &dwSize);
+    if (dwSize == 0)
+    {
+        dwError = ERROR_FILE_NOT_FOUND;
+        goto done;
+    }
 
-        NotificationDll->bEnabled = TRUE;
-        NotificationDll->dwMaxWait = 30; /* FIXME: ??? */
-        NotificationDll->hInstance = hInstance;
-
-        dwSize = sizeof(BOOL);
-        RegQueryValueExW(hDllKey,
-                         L"Asynchronous",
-                         NULL,
-                         &dwType,
-                         (PBYTE)&NotificationDll->bAsynchronous,
-                         &dwSize);
-
-        dwSize = sizeof(BOOL);
-        RegQueryValueExW(hDllKey,
-                         L"Enabled",
-                         NULL,
-                         &dwType,
-                         (PBYTE)&NotificationDll->bEnabled,
-                         &dwSize);
-
-        dwSize = sizeof(BOOL);
-        RegQueryValueExW(hDllKey,
-                         L"Impersonate",
-                         NULL,
-                         &dwType,
-                         (PBYTE)&NotificationDll->bImpersonate,
-                         &dwSize);
-
-        dwSize = sizeof(BOOL);
-        RegQueryValueExW(hDllKey,
-                         L"Safe",
-                         NULL,
-                         &dwType,
-                         (PBYTE)&NotificationDll->bSafe,
-                         &dwSize);
-
-        dwSize = sizeof(BOOL);
-        RegQueryValueExW(hDllKey,
-                         L"SmartCardLogonNotify",
-                         NULL,
-                         &dwType,
-                         (PBYTE)&NotificationDll->bSmartCardLogon,
-                         &dwSize);
-
-        dwSize = sizeof(DWORD);
-        RegQueryValueExW(hDllKey,
-                         L"MaxWait",
-                         NULL,
-                         &dwType,
-                         (PBYTE)&NotificationDll->dwMaxWait,
-                         &dwSize);
-
-        for (i = LogonHandler; i < LastHandler; i++)
+    NotificationDll->pszDllName = RtlAllocateHeap(RtlGetProcessHeap(),
+                                                  HEAP_ZERO_MEMORY,
+                                                  dwSize);
+    if (NotificationDll->pszDllName == NULL)
+    {
+        dwError = ERROR_OUTOFMEMORY;
+        goto done;
+    }
+
+    dwError = RegQueryValueExW(hDllKey,
+                               L"DllName",
+                               NULL,
+                               &dwType,
+                               (PBYTE)NotificationDll->pszDllName,
+                               &dwSize);
+    if (dwError != ERROR_SUCCESS)
+        goto done;
+
+    NotificationDll->bEnabled = TRUE;
+    NotificationDll->dwMaxWait = 30; /* FIXME: ??? */
+
+    dwSize = sizeof(BOOL);
+    RegQueryValueExW(hDllKey,
+                     L"Asynchronous",
+                     NULL,
+                     &dwType,
+                     (PBYTE)&NotificationDll->bAsynchronous,
+                     &dwSize);
+
+    dwSize = sizeof(BOOL);
+    RegQueryValueExW(hDllKey,
+                     L"Enabled",
+                     NULL,
+                     &dwType,
+                     (PBYTE)&NotificationDll->bEnabled,
+                     &dwSize);
+
+    dwSize = sizeof(BOOL);
+    RegQueryValueExW(hDllKey,
+                     L"Impersonate",
+                     NULL,
+                     &dwType,
+                     (PBYTE)&NotificationDll->bImpersonate,
+                     &dwSize);
+
+    dwSize = sizeof(BOOL);
+    RegQueryValueExW(hDllKey,
+                     L"Safe",
+                     NULL,
+                     &dwType,
+                     (PBYTE)&NotificationDll->bSafe,
+                     &dwSize);
+
+    dwSize = sizeof(BOOL);
+    RegQueryValueExW(hDllKey,
+                     L"SmartCardLogonNotify",
+                     NULL,
+                     &dwType,
+                     (PBYTE)&NotificationDll->bSmartCardLogon,
+                     &dwSize);
+
+    dwSize = sizeof(DWORD);
+    RegQueryValueExW(hDllKey,
+                     L"MaxWait",
+                     NULL,
+                     &dwType,
+                     (PBYTE)&NotificationDll->dwMaxWait,
+                     &dwSize);
+
+    InsertHeadList(&NotificationDllListHead,
+                   &NotificationDll->ListEntry);
+
+done:
+    if (dwError != ERROR_SUCCESS)
+    {
+        if (NotificationDll != NULL)
         {
-            NotificationDll->Handler[i] = GetNotificationHandler(hDllKey, hInstance, FuncNames[i]);
-            TRACE("%s: %p\n", FuncNames[i], NotificationDll->Handler[i]);
-        }
+            if (NotificationDll->pszKeyName != NULL)
+                RtlFreeHeap(RtlGetProcessHeap(), 0, NotificationDll->pszKeyName);
+
+            if (NotificationDll->pszDllName != NULL)
+                RtlFreeHeap(RtlGetProcessHeap(), 0, NotificationDll->pszDllName);
 
-        InsertHeadList(&NotificationDllListHead,
-                       &NotificationDll->ListEntry);
+            RtlFreeHeap(RtlGetProcessHeap(), 0, NotificationDll);
+        }
     }
 
     RegCloseKey(hDllKey);
@@ -232,7 +236,7 @@ InitNotifications(VOID)
             break;
 
         TRACE("Notification DLL: %S\n", szKeyName);
-        LoadNotificationDll(hNotifyKey, szKeyName);
+        AddNotificationDll(hNotifyKey, szKeyName);
 
         dwIndex++;
     }
@@ -245,6 +249,57 @@ InitNotifications(VOID)
 }
 
 
+static
+VOID
+CallNotificationDll(
+    HKEY hNotifyKey,
+    PNOTIFICATION_ITEM NotificationDll,
+    NOTIFICATION_TYPE Type,
+    PWLX_NOTIFICATION_INFO pInfo)
+{
+    HKEY hDllKey = NULL;
+    HMODULE hModule = NULL;
+    CHAR szFuncBuffer[128];
+    DWORD dwSize;
+    DWORD dwType;
+    DWORD dwError;
+    PWLX_NOTIFY_HANDLER pNotifyHandler;
+
+    dwError = RegOpenKeyExW(hNotifyKey,
+                            NotificationDll->pszKeyName,
+                            0,
+                            KEY_READ,
+                            &hDllKey);
+    if (dwError != ERROR_SUCCESS)
+    {
+        TRACE("RegOpenKeyExW()\n");
+        return;
+    }
+
+    dwSize = sizeof(szFuncBuffer);
+    dwError = RegQueryValueExA(hDllKey,
+                               FuncNames[Type],
+                               NULL,
+                               &dwType,
+                               (PBYTE)szFuncBuffer,
+                               &dwSize);
+    if (dwError == ERROR_SUCCESS)
+    {
+        hModule = LoadLibraryW(NotificationDll->pszDllName);
+        if (hModule != NULL)
+        {
+            pNotifyHandler = (PWLX_NOTIFY_HANDLER)GetProcAddress(hModule, szFuncBuffer);
+            if (pNotifyHandler != NULL)
+                pNotifyHandler(pInfo);
+
+            FreeLibrary(hModule);
+        }
+    }
+
+    RegCloseKey(hDllKey);
+}
+
+
 VOID
 CallNotificationDlls(
     PWLSESSION pSession,
@@ -253,9 +308,22 @@ CallNotificationDlls(
     PLIST_ENTRY ListEntry;
     PNOTIFICATION_ITEM NotificationDll;
     WLX_NOTIFICATION_INFO Info;
+    HKEY hNotifyKey = NULL;
+    DWORD dwError;
 
     TRACE("CallNotificationDlls()\n");
 
+    dwError = RegOpenKeyExW(HKEY_LOCAL_MACHINE,
+                            L"Software\\Microsoft\\Windows NT\\CurrentVersion\\Winlogon\\Notify",
+                            0,
+                            KEY_READ | KEY_ENUMERATE_SUB_KEYS,
+                            &hNotifyKey);
+    if (dwError != ERROR_SUCCESS)
+    {
+        TRACE("RegOpenKeyExW()\n");
+        return;
+    }
+
     Info.Size = sizeof(WLX_NOTIFICATION_INFO);
 
     switch (Type)
@@ -303,14 +371,12 @@ TRACE("ListEntry %p\n", ListEntry);
                                             ListEntry);
 TRACE("NotificationDll: %p\n", NotificationDll);
         if (NotificationDll != NULL && NotificationDll->bEnabled)
-        {
-TRACE("NotificationDll->Handler: %p\n", NotificationDll->Handler[Type]);
-            if (NotificationDll->Handler[Type] != NULL)
-                NotificationDll->Handler[Type](&Info);
-        }
+            CallNotificationDll(hNotifyKey, NotificationDll, Type, &Info);
 
         ListEntry = ListEntry->Flink;
     }
+
+    RegCloseKey(hNotifyKey);
 }
 
 
@@ -328,7 +394,11 @@ CleanupNotifications(VOID)
                                             ListEntry);
         if (NotificationDll != NULL)
         {
-            FreeLibrary(NotificationDll->hInstance);
+            if (NotificationDll->pszKeyName != NULL)
+                RtlFreeHeap(RtlGetProcessHeap(), 0, NotificationDll->pszKeyName);
+
+            if (NotificationDll->pszDllName != NULL)
+                RtlFreeHeap(RtlGetProcessHeap(), 0, NotificationDll->pszDllName);
         }
 
         ListEntry = ListEntry->Flink;