[KMTESTS]
[reactos.git] / rostests / kmtests / include / kmt_test.h
index 73ea959..2670b9e 100644 (file)
@@ -83,6 +83,7 @@ typedef struct
 extern BOOLEAN KmtIsCheckedBuild;
 extern BOOLEAN KmtIsMultiProcessorBuild;
 extern PCSTR KmtMajorFunctionNames[];
+extern PDRIVER_OBJECT KmtDriverObject;
 
 VOID KmtSetIrql(IN KIRQL NewIrql);
 BOOLEAN KmtAreInterruptsEnabled(VOID);
@@ -129,6 +130,8 @@ VOID KmtVTrace(PCSTR FileAndLine, PCSTR Format, va_list Arguments)
 VOID KmtTrace(PCSTR FileAndLine, PCSTR Format, ...)                                 KMT_FORMAT(ms_printf, 2, 3);
 BOOLEAN KmtVSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, va_list Arguments) KMT_FORMAT(ms_printf, 3, 0);
 BOOLEAN KmtSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, ...)                KMT_FORMAT(ms_printf, 3, 4);
+PVOID KmtAllocateGuarded(SIZE_T SizeRequested);
+VOID KmtFreeGuarded(PVOID Pointer);
 
 #ifdef KMT_KERNEL_MODE
 #define ok_irql(irql)                       ok(KeGetCurrentIrql() == irql, "IRQL is %d, expected %d\n", KeGetCurrentIrql(), irql)
@@ -231,7 +234,7 @@ static PKMT_RESULTBUFFER KmtAllocateResultBuffer(SIZE_T ResultBufferSize)
     Buffer->Failures = 0;
     Buffer->Skipped = 0;
     Buffer->LogBufferLength = 0;
-    Buffer->LogBufferMaxLength = ResultBufferSize - FIELD_OFFSET(KMT_RESULTBUFFER, LogBuffer);
+    Buffer->LogBufferMaxLength = (ULONG)ResultBufferSize - FIELD_OFFSET(KMT_RESULTBUFFER, LogBuffer);
 
     return Buffer;
 }
@@ -257,7 +260,7 @@ static VOID KmtAddToLogBuffer(PKMT_RESULTBUFFER Buffer, PCSTR String, SIZE_T Len
     do
     {
         OldLength = Buffer->LogBufferLength;
-        NewLength = OldLength + Length;
+        NewLength = OldLength + (ULONG)Length;
         if (NewLength > Buffer->LogBufferMaxLength)
             return;
     } while (InterlockedCompareExchange(&Buffer->LogBufferLength, NewLength, OldLength) != OldLength);
@@ -411,6 +414,44 @@ BOOLEAN KmtSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, ...)
     return Ret;
 }
 
+PVOID KmtAllocateGuarded(SIZE_T SizeRequested)
+{
+    NTSTATUS Status;
+    SIZE_T Size = PAGE_ROUND_UP(SizeRequested + PAGE_SIZE);
+    PVOID VirtualMemory = NULL;
+    PCHAR StartOfBuffer;
+
+    Status = ZwAllocateVirtualMemory(ZwCurrentProcess(), &VirtualMemory, 0, &Size, MEM_RESERVE, PAGE_NOACCESS);
+
+    if (!NT_SUCCESS(Status))
+        return NULL;
+
+    Size -= PAGE_SIZE;
+    Status = ZwAllocateVirtualMemory(ZwCurrentProcess(), &VirtualMemory, 0, &Size, MEM_COMMIT, PAGE_READWRITE);
+    if (!NT_SUCCESS(Status))
+    {
+        Size = 0;
+        Status = ZwFreeVirtualMemory(ZwCurrentProcess(), &VirtualMemory, &Size, MEM_RELEASE);
+        ok_eq_hex(Status, STATUS_SUCCESS);
+        return NULL;
+    }
+
+    StartOfBuffer = VirtualMemory;
+    StartOfBuffer += Size - SizeRequested;
+
+    return StartOfBuffer;
+}
+
+VOID KmtFreeGuarded(PVOID Pointer)
+{
+    NTSTATUS Status;
+    PVOID VirtualMemory = (PVOID)PAGE_ROUND_DOWN((SIZE_T)Pointer);
+    SIZE_T Size = 0;
+
+    Status = ZwFreeVirtualMemory(ZwCurrentProcess(), &VirtualMemory, &Size, MEM_RELEASE);
+    ok_eq_hex(Status, STATUS_SUCCESS);
+}
+
 #endif /* defined KMT_DEFINE_TEST_FUNCTIONS */
 
 #endif /* !defined _KMTEST_TEST_H_ */