[SHIMENG] Don't crash on a shim not found
[reactos.git] / dll / appcompat / apphelp / shimeng.c
index 0d95d61..d1ba71d 100644 (file)
@@ -1,8 +1,8 @@
 /*
  * PROJECT:     ReactOS Application compatibility module
- * LICENSE:     GPL-2.0+ (https://spdx.org/licenses/GPL-2.0+)
+ * LICENSE:     GPL-2.0-or-later (https://spdx.org/licenses/GPL-2.0-or-later)
  * PURPOSE:     Shim engine core
- * COPYRIGHT:   Copyright 2015-2017 Mark Jansen (mark.jansen@reactos.org)
+ * COPYRIGHT:   Copyright 2015-2019 Mark Jansen (mark.jansen@reactos.org)
  */
 
 #define WIN32_NO_STATUS
@@ -25,6 +25,7 @@ extern HMODULE g_hInstance;
 static UNICODE_STRING g_WindowsDirectory;
 static UNICODE_STRING g_System32Directory;
 static UNICODE_STRING g_SxsDirectory;
+static UNICODE_STRING g_LoadingShimDll;
 ULONG g_ShimEngDebugLevel = 0xffffffff;
 BOOL g_bComPlusImage = FALSE;
 BOOL g_bShimDuringInit = FALSE;
@@ -315,6 +316,11 @@ PHOOKMODULEINFO SeiFindHookModuleInfo(PUNICODE_STRING ModuleName, PVOID BaseAddr
 {
     DWORD n;
 
+    if (ModuleName == NULL && BaseAddress == NULL)
+    {
+        BaseAddress = NtCurrentPeb()->ImageBaseAddress;
+    }
+
     for (n = 0; n < ARRAY_Size(&g_pHookArray); ++n)
     {
         PHOOKMODULEINFO pModuleInfo = ARRAY_At(&g_pHookArray, HOOKMODULEINFO, n);
@@ -342,13 +348,15 @@ PHOOKMODULEINFO SeiFindHookModuleInfoForImportDescriptor(PBYTE DllBase, PIMAGE_I
     }
 
     Success = LdrGetDllHandle(NULL, NULL, &DllName, &DllHandle);
-    RtlFreeUnicodeString(&DllName);
 
     if (!NT_SUCCESS(Success))
     {
-        SHIMENG_FAIL("Unable to get module handle for %wZ\n", &DllName);
+        SHIMENG_FAIL("Unable to get module handle for %wZ (%p)\n", &DllName, DllBase);
+        RtlFreeUnicodeString(&DllName);
+
         return NULL;
     }
+    RtlFreeUnicodeString(&DllName);
 
     return SeiFindHookModuleInfo(NULL, DllHandle);
 }
@@ -368,9 +376,17 @@ static DWORD SeiGetDWORD(PDB pdb, TAGID tag, TAG type)
     if (tagEntry == TAGID_NULL)
         return 0;
 
-    return SdbReadDWORDTag(pdb, tagEntry, TAGID_NULL);
+    return SdbReadDWORDTag(pdb, tagEntry, 0);
 }
 
+static QWORD SeiGetQWORD(PDB pdb, TAGID tag, TAG type)
+{
+    TAGID tagEntry = SdbFindFirstTag(pdb, tag, type);
+    if (tagEntry == TAGID_NULL)
+        return 0;
+
+    return SdbReadQWORDTag(pdb, tagEntry, 0);
+}
 
 static VOID SeiAddShim(TAGREF trShimRef, PARRAY pShimRef)
 {
@@ -383,6 +399,22 @@ static VOID SeiAddShim(TAGREF trShimRef, PARRAY pShimRef)
     *Data = trShimRef;
 }
 
+static VOID SeiAddFlag(PDB pdb, TAGID tiFlagRef, PFLAGINFO pFlagInfo)
+{
+    ULARGE_INTEGER Flag;
+
+    /* Resolve the FLAG_REF to the real FLAG node */
+    TAGID FlagTag = SeiGetDWORD(pdb, tiFlagRef, TAG_FLAG_TAGID);
+
+    if (FlagTag == TAGID_NULL)
+        return;
+
+    pFlagInfo->AppCompatFlags.QuadPart |= SeiGetQWORD(pdb, FlagTag, TAG_FLAG_MASK_KERNEL);
+    pFlagInfo->AppCompatFlagsUser.QuadPart |= SeiGetQWORD(pdb, FlagTag, TAG_FLAG_MASK_USER);
+    Flag.QuadPart = SeiGetQWORD(pdb, FlagTag, TAG_FLAG_PROCESSPARAM);
+    pFlagInfo->ProcessParameters_Flags |= Flag.LowPart;
+}
+
 /* Propagate layers to child processes */
 static VOID SeiSetLayerEnvVar(LPCWSTR wszLayer)
 {
@@ -402,7 +434,7 @@ static VOID SeiSetLayerEnvVar(LPCWSTR wszLayer)
 #define MAX_LAYER_LENGTH            256
 
 /* Translate all Exe and Layer entries to Shims, and propagate all layers */
-static VOID SeiBuildShimRefArray(HSDB hsdb, SDBQUERYRESULT* pQuery, PARRAY pShimRef)
+static VOID SeiBuildShimRefArray(HSDB hsdb, SDBQUERYRESULT* pQuery, PARRAY pShimRef, PFLAGINFO pFlagInfo)
 {
     WCHAR wszLayerEnvVar[MAX_LAYER_LENGTH] = { 0 };
     DWORD n;
@@ -415,6 +447,7 @@ static VOID SeiBuildShimRefArray(HSDB hsdb, SDBQUERYRESULT* pQuery, PARRAY pShim
         {
             LPCWSTR ExeName = SeiGetStringPtr(pdb, tag, TAG_NAME);
             TAGID ShimRef = SdbFindFirstTag(pdb, tag, TAG_SHIM_REF);
+            TAGID FlagRef = SdbFindFirstTag(pdb, tag, TAG_FLAG_REF);
 
             if (ExeName)
                 SeiDbgPrint(SEI_MSG, NULL, "ShimInfo(Exe(%S))\n", ExeName);
@@ -425,11 +458,15 @@ static VOID SeiBuildShimRefArray(HSDB hsdb, SDBQUERYRESULT* pQuery, PARRAY pShim
                 if (SdbTagIDToTagRef(hsdb, pdb, ShimRef, &trShimRef))
                     SeiAddShim(trShimRef, pShimRef);
 
-
                 ShimRef = SdbFindNextTag(pdb, tag, ShimRef);
             }
 
-            /* Handle FLAG_REF */
+            while (FlagRef != TAGID_NULL)
+            {
+                SeiAddFlag(pdb, FlagRef, pFlagInfo);
+
+                FlagRef = SdbFindNextTag(pdb, tag, FlagRef);
+            }
         }
     }
 
@@ -442,6 +479,8 @@ static VOID SeiBuildShimRefArray(HSDB hsdb, SDBQUERYRESULT* pQuery, PARRAY pShim
         {
             LPCWSTR LayerName = SeiGetStringPtr(pdb, tag, TAG_NAME);
             TAGID ShimRef = SdbFindFirstTag(pdb, tag, TAG_SHIM_REF);
+            TAGID FlagRef = SdbFindFirstTag(pdb, tag, TAG_FLAG_REF);
+
             if (LayerName)
             {
                 HRESULT hr;
@@ -464,7 +503,12 @@ static VOID SeiBuildShimRefArray(HSDB hsdb, SDBQUERYRESULT* pQuery, PARRAY pShim
                 ShimRef = SdbFindNextTag(pdb, tag, ShimRef);
             }
 
-            /* Handle FLAG_REF */
+            while (FlagRef != TAGID_NULL)
+            {
+                SeiAddFlag(pdb, FlagRef, pFlagInfo);
+
+                FlagRef = SdbFindNextTag(pdb, tag, FlagRef);
+            }
         }
     }
     if (wszLayerEnvVar[0])
@@ -541,7 +585,8 @@ VOID SeiAddHooks(PHOOKAPIEX hooks, DWORD dwHookCount, PSHIMINFO pShim)
             }
         }
         pHookApi = ARRAY_Append(&HookModuleInfo->HookApis, PHOOKAPIEX);
-        *pHookApi = hook;
+        if (pHookApi)
+            *pHookApi = hook;
     }
 }
 
@@ -556,9 +601,9 @@ FARPROC WINAPI StubGetProcAddress(HINSTANCE hModule, LPCSTR lpProcName)
     PHOOKMODULEINFO HookModuleInfo;
     FARPROC proc = ((GETPROCADDRESSPROC)g_IntHookEx[0].OriginalFunction)(hModule, lpProcName);
 
-    if (!HIWORD(lpProcName))
+    if ((DWORD_PTR)lpProcName <= MAXUSHORT)
     {
-        sprintf(szOrdProcName, "#%lu", (DWORD)lpProcName);
+        sprintf(szOrdProcName, "#%Iu", (DWORD_PTR)lpProcName);
         lpPrintName = szOrdProcName;
     }
 
@@ -569,7 +614,7 @@ FARPROC WINAPI StubGetProcAddress(HINSTANCE hModule, LPCSTR lpProcName)
         return proc;
     }
 
-    SHIMENG_MSG("(GetProcAddress(%p!%s) => %p\n", hModule, lpPrintName, proc);
+    SHIMENG_INFO("(GetProcAddress(%p!%s) => %p\n", hModule, lpPrintName, proc);
 
     HookModuleInfo = SeiFindHookModuleInfo(NULL, hModule);
 
@@ -635,7 +680,7 @@ VOID SeiPatchNewImport(PIMAGE_THUNK_DATA FirstThunk, PHOOKAPIEX HookApi, PLDR_DA
 {
     ULONG OldProtection = 0;
     PVOID Ptr;
-    ULONG Size;
+    SIZE_T Size;
     NTSTATUS Status;
 
     SHIMENG_INFO("Hooking API \"%s!%s\" for DLL \"%wZ\"\n", HookApi->LibraryName, HookApi->FunctionName, &LdrEntry->BaseDllName);
@@ -827,7 +872,8 @@ VOID SeiHookImports(PLDR_DATA_TABLE_ENTRY LdrEntry)
     PIMAGE_IMPORT_DESCRIPTOR ImportDescriptor;
     PBYTE DllBase = LdrEntry->DllBase;
 
-    if (SE_IsShimDll(DllBase) || g_hInstance == LdrEntry->DllBase)
+    if (SE_IsShimDll(DllBase) || g_hInstance == LdrEntry->DllBase ||
+        (g_LoadingShimDll.Buffer && RtlEqualUnicodeString(&g_LoadingShimDll, &LdrEntry->BaseDllName, TRUE)))
     {
         SHIMENG_INFO("Skipping shim module 0x%p \"%wZ\"\n", LdrEntry->DllBase, &LdrEntry->BaseDllName);
         return;
@@ -954,15 +1000,12 @@ VOID SeiInitPaths(VOID)
 #undef WINSXS
 }
 
-/*
-Level(INFO) Using USER apphack flags 0x2080000
-*/
-
 VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
 {
     DWORD n;
     ARRAY ShimRefArray;
     DWORD dwTotalHooks = 0;
+    FLAGINFO ShimFlags;
 
     PPEB Peb = NtCurrentPeb();
 
@@ -973,6 +1016,7 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
     ARRAY_Init(&g_pShimInfo, PSHIMMODULE);
     ARRAY_Init(&g_pHookArray, HOOKMODULEINFO);
     ARRAY_Init(&g_InExclude, INEXCLUDE);
+    RtlZeroMemory(&ShimFlags, sizeof(ShimFlags));
 
     SeiInitPaths();
 
@@ -984,7 +1028,22 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
     */
 
     SeiDbgPrint(SEI_MSG, NULL, "ShimInfo(ExePath(%wZ))\n", ProcessImage);
-    SeiBuildShimRefArray(hsdb, pQuery, &ShimRefArray);
+    SeiBuildShimRefArray(hsdb, pQuery, &ShimRefArray, &ShimFlags);
+    if (ShimFlags.AppCompatFlags.QuadPart)
+    {
+        SeiDbgPrint(SEI_MSG, NULL, "Using KERNEL apphack flags 0x%I64x\n", ShimFlags.AppCompatFlags.QuadPart);
+        Peb->AppCompatFlags.QuadPart |= ShimFlags.AppCompatFlags.QuadPart;
+    }
+    if (ShimFlags.AppCompatFlagsUser.QuadPart)
+    {
+        SeiDbgPrint(SEI_MSG, NULL, "Using USER apphack flags 0x%I64x\n", ShimFlags.AppCompatFlagsUser.QuadPart);
+        Peb->AppCompatFlagsUser.QuadPart |= ShimFlags.AppCompatFlagsUser.QuadPart;
+    }
+    if (ShimFlags.ProcessParameters_Flags)
+    {
+        SeiDbgPrint(SEI_MSG, NULL, "Using ProcessParameters flags 0x%x\n", ShimFlags.ProcessParameters_Flags);
+        Peb->ProcessParameters->Flags |= ShimFlags.ProcessParameters_Flags;
+    }
     SeiDbgPrint(SEI_MSG, NULL, "ShimInfo(Complete)\n");
 
     SHIMENG_INFO("Got %d shims\n", ARRAY_Size(&ShimRefArray));
@@ -1040,7 +1099,7 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
                 continue;
             }
 
-            if (!SdbGetAppPatchDir(NULL, FullNameBuffer, ARRAYSIZE(FullNameBuffer)))
+            if (!SUCCEEDED(SdbGetAppPatchDir(NULL, FullNameBuffer, ARRAYSIZE(FullNameBuffer))))
             {
                 SHIMENG_WARN("Failed to get the AppPatch dir\n");
                 continue;
@@ -1055,6 +1114,7 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
                 continue;
             }
 
+            RtlInitUnicodeString(&g_LoadingShimDll, DllName);
             RtlInitUnicodeString(&UnicodeDllName, FullNameBuffer);
             if (NT_SUCCESS(LdrGetDllHandle(NULL, NULL, &UnicodeDllName, &BaseAddress)))
             {
@@ -1066,6 +1126,7 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
                 SHIMENG_WARN("Failed to load %wZ for %S\n", &UnicodeDllName, ShimName);
                 continue;
             }
+            RtlInitUnicodeString(&g_LoadingShimDll, NULL);
             /* No shim module found (or we just loaded it) */
             if (!pShimModuleInfo)
             {
@@ -1081,10 +1142,13 @@ VOID SeiInit(PUNICODE_STRING ProcessImage, HSDB hsdb, SDBQUERYRESULT* pQuery)
             SHIMENG_INFO("Using SHIM \"%S!%S\"\n", DllName, ShimName);
 
             /* Ask this shim what hooks it needs (and pass along the commandline) */
+            dwHookCount = 0;
             pHookApi = pShimModuleInfo->pGetHookAPIs(AnsiCommandLine.Buffer, ShimName, &dwHookCount);
             SHIMENG_INFO("GetHookAPIs returns %d hooks for DLL \"%wZ\" SHIM \"%S\"\n", dwHookCount, &UnicodeDllName, ShimName);
-            if (dwHookCount)
+            if (dwHookCount && pHookApi)
                 pShimInfo = SeiAppendHookInfo(pShimModuleInfo, pHookApi, dwHookCount, ShimName);
+            else
+                dwHookCount = 0;
 
             /* If this shim has hooks, create the include / exclude lists */
             if (pShimInfo)
@@ -1165,7 +1229,7 @@ VOID NTAPI SE_InstallBeforeInit(PUNICODE_STRING ProcessImage, PVOID pShimData)
 {
     HSDB hsdb = NULL;
     SDBQUERYRESULT QueryResult = { { 0 } };
-    SHIMENG_MSG("(%wZ, %p)\n", ProcessImage, pShimData);
+    SHIMENG_INFO("(%wZ, %p)\n", ProcessImage, pShimData);
 
     if (!SeiGetShimData(ProcessImage, pShimData, &hsdb, &QueryResult))
     {