[KMTESTS]
[reactos.git] / rostests / kmtests / ntos_ke / KeGuardedMutex.c
index e0d4e10..9b2a79e 100644 (file)
 #define NDEBUG
 #include <debug.h>
 
+static
+_IRQL_requires_min_(PASSIVE_LEVEL)
+_IRQL_requires_max_(DISPATCH_LEVEL)
+BOOLEAN
+(NTAPI
+*pKeAreAllApcsDisabled)(VOID);
+
+static
+_Acquires_lock_(_Global_critical_region_)
+_Requires_lock_not_held_(*Mutex)
+_Acquires_lock_(*Mutex)
+_IRQL_requires_max_(APC_LEVEL)
+_IRQL_requires_min_(PASSIVE_LEVEL)
+VOID
+(FASTCALL
+*pKeAcquireGuardedMutex)(
+    _Inout_ PKGUARDED_MUTEX GuardedMutex);
+
+static
+_Requires_lock_not_held_(*FastMutex)
+_Acquires_lock_(*FastMutex)
+_IRQL_requires_max_(APC_LEVEL)
+_IRQL_requires_min_(PASSIVE_LEVEL)
+VOID
+(FASTCALL
+*pKeAcquireGuardedMutexUnsafe)(
+    _Inout_ PKGUARDED_MUTEX GuardedMutex);
+
+static
+_Acquires_lock_(_Global_critical_region_)
+_IRQL_requires_max_(APC_LEVEL)
+VOID
+(NTAPI
+*pKeEnterGuardedRegion)(VOID);
+
+static
+_Releases_lock_(_Global_critical_region_)
+_IRQL_requires_max_(APC_LEVEL)
+VOID
+(NTAPI
+*pKeLeaveGuardedRegion)(VOID);
+
+static
+_IRQL_requires_max_(APC_LEVEL)
+_IRQL_requires_min_(PASSIVE_LEVEL)
+VOID
+(FASTCALL
+*pKeInitializeGuardedMutex)(
+    _Out_ PKGUARDED_MUTEX GuardedMutex);
+
+static
+_Requires_lock_held_(*FastMutex)
+_Releases_lock_(*FastMutex)
+_IRQL_requires_max_(APC_LEVEL)
+VOID
+(FASTCALL
+*pKeReleaseGuardedMutexUnsafe)(
+    _Inout_ PKGUARDED_MUTEX GuardedMutex);
+
+static
+_Releases_lock_(_Global_critical_region_)
+_Requires_lock_held_(*Mutex)
+_Releases_lock_(*Mutex)
+_IRQL_requires_max_(APC_LEVEL)
+VOID
+(FASTCALL
+*pKeReleaseGuardedMutex)(
+    _Inout_ PKGUARDED_MUTEX GuardedMutex);
+
+static
+_Must_inspect_result_
+_Success_(return != FALSE)
+_IRQL_requires_max_(APC_LEVEL)
+_Post_satisfies_(return == 1 || return == 0)
+BOOLEAN
+(FASTCALL
+*pKeTryToAcquireGuardedMutex)(
+    _When_ (return, _Requires_lock_not_held_(*_Curr_) _Acquires_exclusive_lock_(*_Curr_)) _Acquires_lock_(_Global_critical_region_)
+        _Inout_ PKGUARDED_MUTEX GuardedMutex);
+
 #define CheckMutex(Mutex, ExpectedCount, ExpectedOwner, ExpectedContention,     \
                    ExpectedKernelApcDisable, ExpectedSpecialApcDisable,         \
                    KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled,    \
         ok_eq_int((Mutex)->SpecialApcDisable, 0x5555);                          \
     ok_eq_bool(KeAreApcsDisabled(), KernelApcsDisabled || SpecialApcsDisabled); \
     ok_eq_int(Thread->KernelApcDisable, KernelApcsDisabled);                    \
-    ok_eq_bool(KeAreAllApcsDisabled(), AllApcsDisabled);                        \
+    ok_eq_bool(pKeAreAllApcsDisabled(), AllApcsDisabled);                       \
     ok_eq_int(Thread->SpecialApcDisable, SpecialApcsDisabled);                  \
     ok_irql(ExpectedIrql);                                                      \
 } while (0)
@@ -48,17 +128,17 @@ TestGuardedMutex(
     if (!KmtIsCheckedBuild || OriginalIrql <= APC_LEVEL)
     {
         /* acquire/release normally */
-        KeAcquireGuardedMutex(Mutex);
+        pKeAcquireGuardedMutex(Mutex);
         CheckMutex(Mutex, 0L, Thread, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled - 1, TRUE, OriginalIrql);
-        ok_bool_false(KeTryToAcquireGuardedMutex(Mutex), "KeTryToAcquireGuardedMutex returned");
+        ok_bool_false(pKeTryToAcquireGuardedMutex(Mutex), "KeTryToAcquireGuardedMutex returned");
         CheckMutex(Mutex, 0L, Thread, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled - 1, TRUE, OriginalIrql);
-        KeReleaseGuardedMutex(Mutex);
+        pKeReleaseGuardedMutex(Mutex);
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
 
         /* try to acquire */
-        ok_bool_true(KeTryToAcquireGuardedMutex(Mutex), "KeTryToAcquireGuardedMutex returned");
+        ok_bool_true(pKeTryToAcquireGuardedMutex(Mutex), "KeTryToAcquireGuardedMutex returned");
         CheckMutex(Mutex, 0L, Thread, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled - 1, TRUE, OriginalIrql);
-        KeReleaseGuardedMutex(Mutex);
+        pKeReleaseGuardedMutex(Mutex);
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
     }
     else
@@ -69,9 +149,9 @@ TestGuardedMutex(
     if (!KmtIsCheckedBuild || OriginalIrql == APC_LEVEL || SpecialApcsDisabled < 0)
     {
         /* acquire/release unsafe */
-        KeAcquireGuardedMutexUnsafe(Mutex);
+        pKeAcquireGuardedMutexUnsafe(Mutex);
         CheckMutex(Mutex, 0L, Thread, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
-        KeReleaseGuardedMutexUnsafe(Mutex);
+        pKeReleaseGuardedMutexUnsafe(Mutex);
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
     }
 
@@ -79,29 +159,29 @@ TestGuardedMutex(
     if (!KmtIsCheckedBuild)
     {
         /* mismatched acquire/release */
-        KeAcquireGuardedMutex(Mutex);
+        pKeAcquireGuardedMutex(Mutex);
         CheckMutex(Mutex, 0L, Thread, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled - 1, TRUE, OriginalIrql);
-        KeReleaseGuardedMutexUnsafe(Mutex);
+        pKeReleaseGuardedMutexUnsafe(Mutex);
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled - 1, TRUE, OriginalIrql);
-        KeLeaveGuardedRegion();
+        pKeLeaveGuardedRegion();
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
 
-        KeAcquireGuardedMutexUnsafe(Mutex);
+        pKeAcquireGuardedMutexUnsafe(Mutex);
         CheckMutex(Mutex, 0L, Thread, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
-        KeReleaseGuardedMutex(Mutex);
+        pKeReleaseGuardedMutex(Mutex);
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled + 1, OriginalIrql >= APC_LEVEL || SpecialApcsDisabled != -1, OriginalIrql);
-        KeEnterGuardedRegion();
+        pKeEnterGuardedRegion();
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
 
         /* release without acquire */
-        KeReleaseGuardedMutexUnsafe(Mutex);
+        pKeReleaseGuardedMutexUnsafe(Mutex);
         CheckMutex(Mutex, 0L, NULL, 0LU, 0x5555, SpecialApcsDisabled - 1, KernelApcsDisabled, SpecialApcsDisabled, AllApcsDisabled, OriginalIrql);
-        KeReleaseGuardedMutex(Mutex);
+        pKeReleaseGuardedMutex(Mutex);
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled, KernelApcsDisabled, SpecialApcsDisabled + 1, OriginalIrql >= APC_LEVEL || SpecialApcsDisabled != -1, OriginalIrql);
-        KeReleaseGuardedMutex(Mutex);
+        pKeReleaseGuardedMutex(Mutex);
         /* TODO: here we see that Mutex->Count isn't actually just a count. Test the bits correctly! */
         CheckMutex(Mutex, 0L, NULL, 0LU, 0x5555, SpecialApcsDisabled, KernelApcsDisabled, SpecialApcsDisabled + 2, OriginalIrql >= APC_LEVEL || SpecialApcsDisabled != -2, OriginalIrql);
-        KeReleaseGuardedMutex(Mutex);
+        pKeReleaseGuardedMutex(Mutex);
         CheckMutex(Mutex, 1L, NULL, 0LU, 0x5555, SpecialApcsDisabled, KernelApcsDisabled, SpecialApcsDisabled + 3, OriginalIrql >= APC_LEVEL || SpecialApcsDisabled != -3, OriginalIrql);
         Thread->SpecialApcDisable -= 3;
     }
@@ -242,10 +322,10 @@ TestGuardedMutexConcurrent(
     LARGE_INTEGER Timeout;
     Timeout.QuadPart = -50 * 1000 * 10; /* 50 ms */
 
-    InitThreadData(&ThreadData, Mutex, KeAcquireGuardedMutex, NULL, KeReleaseGuardedMutex);
-    InitThreadData(&ThreadData2, Mutex, KeAcquireGuardedMutex, NULL, KeReleaseGuardedMutex);
-    InitThreadData(&ThreadDataUnsafe, Mutex, KeAcquireGuardedMutexUnsafe, NULL, KeReleaseGuardedMutexUnsafe);
-    InitThreadData(&ThreadDataTry, Mutex, NULL, KeTryToAcquireGuardedMutex, KeReleaseGuardedMutex);
+    InitThreadData(&ThreadData, Mutex, pKeAcquireGuardedMutex, NULL, pKeReleaseGuardedMutex);
+    InitThreadData(&ThreadData2, Mutex, pKeAcquireGuardedMutex, NULL, pKeReleaseGuardedMutex);
+    InitThreadData(&ThreadDataUnsafe, Mutex, pKeAcquireGuardedMutexUnsafe, NULL, pKeReleaseGuardedMutexUnsafe);
+    InitThreadData(&ThreadDataTry, Mutex, NULL, pKeTryToAcquireGuardedMutex, pKeReleaseGuardedMutex);
 
     /* have a thread acquire the mutex */
     Status = StartThread(&ThreadData, NULL, PASSIVE_LEVEL, FALSE, FALSE);
@@ -337,6 +417,29 @@ START_TEST(KeGuardedMutex)
     };
     int i;
 
+    pKeAreAllApcsDisabled = KmtGetSystemRoutineAddress(L"KeAreAllApcsDisabled");
+    pKeInitializeGuardedMutex = KmtGetSystemRoutineAddress(L"KeInitializeGuardedMutex");
+    pKeAcquireGuardedMutex = KmtGetSystemRoutineAddress(L"KeAcquireGuardedMutex");
+    pKeAcquireGuardedMutexUnsafe = KmtGetSystemRoutineAddress(L"KeAcquireGuardedMutexUnsafe");
+    pKeEnterGuardedRegion = KmtGetSystemRoutineAddress(L"KeEnterGuardedRegion");
+    pKeLeaveGuardedRegion = KmtGetSystemRoutineAddress(L"KeLeaveGuardedRegion");
+    pKeReleaseGuardedMutex = KmtGetSystemRoutineAddress(L"KeReleaseGuardedMutex");
+    pKeReleaseGuardedMutexUnsafe = KmtGetSystemRoutineAddress(L"KeReleaseGuardedMutexUnsafe");
+    pKeTryToAcquireGuardedMutex = KmtGetSystemRoutineAddress(L"KeTryToAcquireGuardedMutex");
+
+    if (skip(pKeAreAllApcsDisabled &&
+             pKeInitializeGuardedMutex &&
+             pKeAcquireGuardedMutex &&
+             pKeAcquireGuardedMutexUnsafe &&
+             pKeEnterGuardedRegion &&
+             pKeLeaveGuardedRegion &&
+             pKeReleaseGuardedMutex &&
+             pKeReleaseGuardedMutexUnsafe &&
+             pKeTryToAcquireGuardedMutex, "No guarded mutexes\n"))
+    {
+        return;
+    }
+
     for (i = 0; i < sizeof TestIterations / sizeof TestIterations[0]; ++i)
     {
         trace("Run %d\n", i);
@@ -345,7 +448,7 @@ START_TEST(KeGuardedMutex)
         Thread->SpecialApcDisable = TestIterations[i].SpecialApcsDisabled;
 
         RtlFillMemory(&Mutex, sizeof Mutex, 0x55);
-        KeInitializeGuardedMutex(&Mutex);
+        pKeInitializeGuardedMutex(&Mutex);
         CheckMutex(&Mutex, 1L, NULL, 0LU, 0x5555, 0x5555, TestIterations[i].KernelApcsDisabled, TestIterations[i].SpecialApcsDisabled, TestIterations[i].AllApcsDisabled, TestIterations[i].Irql);
         TestGuardedMutex(&Mutex, TestIterations[i].KernelApcsDisabled, TestIterations[i].SpecialApcsDisabled, TestIterations[i].AllApcsDisabled, TestIterations[i].Irql);
 
@@ -356,6 +459,6 @@ START_TEST(KeGuardedMutex)
 
     trace("Concurrent test\n");
     RtlFillMemory(&Mutex, sizeof Mutex, 0x55);
-    KeInitializeGuardedMutex(&Mutex);
+    pKeInitializeGuardedMutex(&Mutex);
     TestGuardedMutexConcurrent(&Mutex);
 }