[SHIMENG] Prevent a nullptr dereference
[reactos.git] / dll / appcompat / apphelp / shimeng.c
index e03a573..3520266 100644 (file)
@@ -1,8 +1,8 @@
 /*
  * PROJECT:     ReactOS Application compatibility module
- * LICENSE:     GPL-2.0+ (https://spdx.org/licenses/GPL-2.0+)
+ * LICENSE:     GPL-2.0-or-later (https://spdx.org/licenses/GPL-2.0-or-later)
  * PURPOSE:     Shim engine core
- * COPYRIGHT:   Copyright 2015-2018 Mark Jansen (mark.jansen@reactos.org)
+ * COPYRIGHT:   Copyright 2015-2019 Mark Jansen (mark.jansen@reactos.org)
  */
 
 #define WIN32_NO_STATUS
@@ -25,6 +25,7 @@ extern HMODULE g_hInstance;
 static UNICODE_STRING g_WindowsDirectory;
 static UNICODE_STRING g_System32Directory;
 static UNICODE_STRING g_SxsDirectory;
+static UNICODE_STRING g_LoadingShimDll;
 ULONG g_ShimEngDebugLevel = 0xffffffff;
 BOOL g_bComPlusImage = FALSE;
 BOOL g_bShimDuringInit = FALSE;
@@ -315,6 +316,11 @@ PHOOKMODULEINFO SeiFindHookModuleInfo(PUNICODE_STRING ModuleName, PVOID BaseAddr
 {
     DWORD n;
 
+    if (ModuleName == NULL && BaseAddress == NULL)
+    {
+        BaseAddress = NtCurrentPeb()->ImageBaseAddress;
+    }
+
     for (n = 0; n < ARRAY_Size(&g_pHookArray); ++n)
     {
         PHOOKMODULEINFO pModuleInfo = ARRAY_At(&g_pHookArray, HOOKMODULEINFO, n);
@@ -342,13 +348,15 @@ PHOOKMODULEINFO SeiFindHookModuleInfoForImportDescriptor(PBYTE DllBase, PIMAGE_I
     }
 
     Success = LdrGetDllHandle(NULL, NULL, &DllName, &DllHandle);
-    RtlFreeUnicodeString(&DllName);
 
     if (!NT_SUCCESS(Success))
     {
-        SHIMENG_FAIL("Unable to get module handle for %wZ\n", &DllName);
+        SHIMENG_FAIL("Unable to get module handle for %wZ (%p)\n", &DllName, DllBase);
+        RtlFreeUnicodeString(&DllName);
+
         return NULL;
     }
+    RtlFreeUnicodeString(&DllName);
 
     return SeiFindHookModuleInfo(NULL, DllHandle);
 }
@@ -577,7 +585,8 @@ VOID SeiAddHooks(PHOOKAPIEX hooks, DWORD dwHookCount, PSHIMINFO pShim)
             }
         }
         pHookApi = ARRAY_Append(&HookModuleInfo->HookApis, PHOOKAPIEX);
-        *pHookApi = hook;
+        if (pHookApi)
+            *pHookApi = hook;
     }
 }
 
@@ -592,9 +601,9 @@ FARPROC WINAPI StubGetProcAddress(HINSTANCE hModule, LPCSTR lpProcName)
     PHOOKMODULEINFO HookModuleInfo;
     FARPROC proc = ((GETPROCADDRESSPROC)g_IntHookEx[0].OriginalFunction)(hModule, lpProcName);
 
-    if (!HIWORD(lpProcName))
+    if ((DWORD_PTR)lpProcName <= MAXUSHORT)
     {
-        sprintf(szOrdProcName, "#%lu", (DWORD)lpProcName);
+        sprintf(szOrdProcName, "#%Iu", (DWORD_PTR)lpProcName);
         lpPrintName = szOrdProcName;
     }
 
@@ -671,7 +680,7 @@ VOID SeiPatchNewImport(PIMAGE_THUNK_DATA FirstThunk, PHOOKAPIEX HookApi, PLDR_DA
 {
     ULONG OldProtection = 0;
     PVOID Ptr;
-    ULONG Size;
+    SIZE_T Size;
     NTSTATUS Status;
 
     SHIMENG_INFO("Hooking API \"%s!%s\" for DLL \"%wZ\"\n", HookApi->LibraryName, HookApi->FunctionName, &LdrEntry->BaseDllName);
@@ -863,7 +872,8 @@ VOID SeiHookImports(PLDR_DATA_TABLE_ENTRY LdrEntry)
     PIMAGE_IMPORT_DESCRIPTOR ImportDescriptor;
     PBYTE DllBase = LdrEntry->DllBase;
 
-    if (SE_IsShimDll(DllBase) || g_hInstance == LdrEntry->DllBase)
+    if (SE_IsShimDll(DllBase) || g_hInstance == LdrEntry->DllBase ||
+        (g_LoadingShimDll.Buffer && RtlEqualUnicodeString(&g_LoadingShimDll, &LdrEntry->BaseDllName, TRUE)))
     {
         SHIMENG_INFO("Skipping shim module 0x%p \"%wZ\"\n", LdrEntry->DllBase, &LdrEntry->BaseDllName);
         return;
@@ -1104,6 +1114,7 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
                 continue;
             }
 
+            RtlInitUnicodeString(&g_LoadingShimDll, DllName);
             RtlInitUnicodeString(&UnicodeDllName, FullNameBuffer);
             if (NT_SUCCESS(LdrGetDllHandle(NULL, NULL, &UnicodeDllName, &BaseAddress)))
             {
@@ -1115,6 +1126,7 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
                 SHIMENG_WARN("Failed to load %wZ for %S\n", &UnicodeDllName, ShimName);
                 continue;
             }
+            RtlInitUnicodeString(&g_LoadingShimDll, NULL);
             /* No shim module found (or we just loaded it) */
             if (!pShimModuleInfo)
             {