[KMTESTS]
[reactos.git] / kmtests / include / kmt_test.h
index 447d775..cc1e815 100644 (file)
@@ -16,6 +16,7 @@
 #include <stdarg.h>
 
 typedef VOID KMT_TESTFUNC(VOID);
+typedef KMT_TESTFUNC *PKMT_TESTFUNC;
 
 typedef struct
 {
@@ -27,38 +28,158 @@ typedef const KMT_TEST CKMT_TEST, *PCKMT_TEST;
 
 extern const KMT_TEST TestList[];
 
-typedef struct {
+typedef struct
+{
     volatile LONG Successes;
     volatile LONG Failures;
+    volatile LONG Skipped;
     volatile LONG LogBufferLength;
     LONG LogBufferMaxLength;
     CHAR LogBuffer[ANYSIZE_ARRAY];
 } KMT_RESULTBUFFER, *PKMT_RESULTBUFFER;
 
+#ifdef KMT_STANDALONE_DRIVER
+#define KMT_KERNEL_MODE
+
+typedef NTSTATUS (KMT_IRP_HANDLER)(
+    IN PDEVICE_OBJECT DeviceObject,
+    IN PIRP Irp,
+    IN PIO_STACK_LOCATION IoStackLocation);
+typedef KMT_IRP_HANDLER *PKMT_IRP_HANDLER;
+
+NTSTATUS KmtRegisterIrpHandler(IN UCHAR MajorFunction, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_IRP_HANDLER IrpHandler);
+NTSTATUS KmtUnregisterIrpHandler(IN UCHAR MajorFunction, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_IRP_HANDLER IrpHandler);
+
+typedef NTSTATUS (KMT_MESSAGE_HANDLER)(
+    IN PDEVICE_OBJECT DeviceObject,
+    IN ULONG ControlCode,
+    IN PVOID Buffer OPTIONAL,
+    IN SIZE_T InLength,
+    IN OUT PSIZE_T OutLength);
+typedef KMT_MESSAGE_HANDLER *PKMT_MESSAGE_HANDLER;
+
+NTSTATUS KmtRegisterMessageHandler(IN ULONG ControlCode OPTIONAL, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_MESSAGE_HANDLER MessageHandler);
+NTSTATUS KmtUnregisterMessageHandler(IN ULONG ControlCode OPTIONAL, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_MESSAGE_HANDLER MessageHandler);
+
+typedef enum
+{
+    TESTENTRY_NO_CREATE_DEVICE = 1,
+    TESTENTRY_NO_REGISTER_DISPATCH = 2,
+    TESTENTRY_NO_REGISTER_UNLOAD = 4,
+} KMT_TESTENTRY_FLAGS;
+
+NTSTATUS TestEntry(IN PDRIVER_OBJECT DriverObject, IN PCUNICODE_STRING RegistryPath, OUT PCWSTR *DeviceName, OUT INT *Flags);
+VOID TestUnload(IN PDRIVER_OBJECT DriverObject);
+#endif /* defined KMT_STANDALONE_DRIVER */
+
+#ifdef KMT_KERNEL_MODE
+/* Device Extension layout */
+typedef struct
+{
+    PKMT_RESULTBUFFER ResultBuffer;
+    PMDL Mdl;
+} KMT_DEVICE_EXTENSION, *PKMT_DEVICE_EXTENSION;
+
+extern BOOLEAN KmtIsCheckedBuild;
+extern BOOLEAN KmtIsMultiProcessorBuild;
+
+VOID KmtSetIrql(IN KIRQL NewIrql);
+BOOLEAN KmtAreInterruptsEnabled(VOID);
+#elif defined KMT_USER_MODE
+DWORD KmtRunKernelTest(IN PCSTR TestName);
+
+VOID KmtLoadDriver(IN PCWSTR ServiceName, IN BOOLEAN RestartIfRunning);
+VOID KmtUnloadDriver(VOID);
+VOID KmtOpenDriver(VOID);
+VOID KmtCloseDriver(VOID);
+
+DWORD KmtSendToDriver(IN DWORD ControlCode);
+DWORD KmtSendStringToDriver(IN DWORD ControlCode, IN PCSTR String);
+DWORD KmtSendBufferToDriver(IN DWORD ControlCode, IN OUT PVOID Buffer OPTIONAL, IN DWORD InLength, IN OUT PDWORD OutLength);
+#endif /* defined KMT_USER_MODE */
+
 extern PKMT_RESULTBUFFER ResultBuffer;
 
-#define KMT_STRINGIZE(x) #x
-#define ok(test, ...)               ok_(test, __FILE__, __LINE__, __VA_ARGS__)
-#define trace(...)                  trace_(   __FILE__, __LINE__, __VA_ARGS__)
+#ifdef __GNUC__
+#define KMT_FORMAT(type, fmt, first) __attribute__((__format__(type, fmt, first)))
+#elif !defined __GNUC__
+#define KMT_FORMAT(type, fmt, first)
+#endif /* !defined __GNUC__ */
+
+#define START_TEST(name) VOID Test_##name(VOID)
 
-#define ok_(test, file, line, ...)  KmtOk(test, file ":" KMT_STRINGIZE(line), __VA_ARGS__)
-#define trace_(file, line, ...)     KmtTrace(   file ":" KMT_STRINGIZE(line), __VA_ARGS__)
+#ifndef KMT_STRINGIZE
+#define KMT_STRINGIZE(x) #x
+#endif /* !defined KMT_STRINGIZE */
+#define ok(test, ...)                ok_(test,   __FILE__, __LINE__, __VA_ARGS__)
+#define trace(...)                   trace_(     __FILE__, __LINE__, __VA_ARGS__)
+#define skip(test, ...)              skip_(test, __FILE__, __LINE__, __VA_ARGS__)
+
+#define ok_(test, file, line, ...)   KmtOk(test,   file ":" KMT_STRINGIZE(line), __VA_ARGS__)
+#define trace_(file, line, ...)      KmtTrace(     file ":" KMT_STRINGIZE(line), __VA_ARGS__)
+#define skip_(test, file, line, ...) KmtSkip(test, file ":" KMT_STRINGIZE(line), __VA_ARGS__)
+
+VOID KmtVOk(INT Condition, PCSTR FileAndLine, PCSTR Format, va_list Arguments)      KMT_FORMAT(ms_printf, 3, 0);
+VOID KmtOk(INT Condition, PCSTR FileAndLine, PCSTR Format, ...)                     KMT_FORMAT(ms_printf, 3, 4);
+VOID KmtVTrace(PCSTR FileAndLine, PCSTR Format, va_list Arguments)                  KMT_FORMAT(ms_printf, 2, 0);
+VOID KmtTrace(PCSTR FileAndLine, PCSTR Format, ...)                                 KMT_FORMAT(ms_printf, 2, 3);
+BOOLEAN KmtVSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, va_list Arguments) KMT_FORMAT(ms_printf, 3, 0);
+BOOLEAN KmtSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, ...)                KMT_FORMAT(ms_printf, 3, 4);
 
-VOID KmtVOk(INT Condition, PCSTR FileAndLine, PCSTR Format, va_list Arguments);
-VOID KmtOk(INT Condition, PCSTR FileAndLine, PCSTR Format, ...);
-VOID KmtVTrace(PCSTR FileAndLine, PCSTR Format, va_list Arguments);
-VOID KmtTrace(PCSTR FileAndLine, PCSTR Format, ...);
+#ifdef KMT_KERNEL_MODE
+#define ok_irql(irql)                       ok(KeGetCurrentIrql() == irql, "IRQL is %d, expected %d\n", KeGetCurrentIrql(), irql)
+#endif /* defined KMT_KERNEL_MODE */
+#define ok_eq_print(value, expected, spec)  ok((value) == (expected), #value " = " spec ", expected " spec "\n", value, expected)
+#define ok_eq_pointer(value, expected)      ok_eq_print(value, expected, "%p")
+#define ok_eq_int(value, expected)          ok_eq_print(value, expected, "%d")
+#define ok_eq_uint(value, expected)         ok_eq_print(value, expected, "%u")
+#define ok_eq_long(value, expected)         ok_eq_print(value, expected, "%ld")
+#define ok_eq_ulong(value, expected)        ok_eq_print(value, expected, "%lu")
+#define ok_eq_longlong(value, expected)     ok_eq_print(value, expected, "%I64d")
+#define ok_eq_ulonglong(value, expected)    ok_eq_print(value, expected, "%I64u")
+#define ok_eq_hex(value, expected)          ok_eq_print(value, expected, "0x%08lx")
+#define ok_bool_true(value, desc)           ok((value) == TRUE, desc " FALSE, expected TRUE\n")
+#define ok_bool_false(value, desc)          ok((value) == FALSE, desc " TRUE, expected FALSE\n")
+#define ok_eq_bool(value, expected)         ok((value) == (expected), #value " = %s, expected %s\n",    \
+                                                (value) ? "TRUE" : "FALSE",                             \
+                                                (expected) ? "TRUE" : "FALSE")
+#define ok_eq_str(value, expected)          ok(!strcmp(value, expected), #value " = \"%s\", expected \"%s\"\n", value, expected)
+#define ok_eq_wstr(value, expected)         ok(!wcscmp(value, expected), #value " = \"%ls\", expected \"%ls\"\n", value, expected)
+
+#define KMT_MAKE_CODE(ControlCode)  CTL_CODE(FILE_DEVICE_UNKNOWN,           \
+                                             0xC00 + (ControlCode),         \
+                                             METHOD_BUFFERED,               \
+                                             FILE_ANY_ACCESS)
 
 #if defined KMT_DEFINE_TEST_FUNCTIONS
-PKMT_RESULTBUFFER ResultBuffer = NULL;
 
-#if defined KMT_USER_MODE
+#if defined KMT_KERNEL_MODE
+BOOLEAN KmtIsCheckedBuild;
+BOOLEAN KmtIsMultiProcessorBuild;
+
+VOID KmtSetIrql(IN KIRQL NewIrql)
+{
+    KIRQL Irql = KeGetCurrentIrql();
+    if (Irql > NewIrql)
+        KeLowerIrql(NewIrql);
+    else if (Irql < NewIrql)
+        KeRaiseIrql(NewIrql, &Irql);
+}
+
+BOOLEAN KmtAreInterruptsEnabled(VOID)
+{
+    return (__readeflags() & (1 << 9)) != 0;
+}
+
+INT __cdecl KmtVSNPrintF(PSTR Buffer, SIZE_T BufferMaxLength, PCSTR Format, va_list Arguments) KMT_FORMAT(ms_printf, 3, 0);
+#elif defined KMT_USER_MODE
 static PKMT_RESULTBUFFER KmtAllocateResultBuffer(SIZE_T LogBufferMaxLength)
 {
     PKMT_RESULTBUFFER Buffer = HeapAlloc(GetProcessHeap(), 0, FIELD_OFFSET(KMT_RESULTBUFFER, LogBuffer[LogBufferMaxLength]));
 
     Buffer->Successes = 0;
     Buffer->Failures = 0;
+    Buffer->Skipped = 0;
     Buffer->LogBufferLength = 0;
     Buffer->LogBufferMaxLength = LogBufferMaxLength;
 
@@ -69,17 +190,20 @@ static VOID KmtFreeResultBuffer(PKMT_RESULTBUFFER Buffer)
 {
     HeapFree(GetProcessHeap(), 0, Buffer);
 }
+
+#define KmtVSNPrintF vsnprintf
 #endif /* defined KMT_USER_MODE */
 
-#define KmtMemCpy memcpy
-#define KmtStrLen strlen
-#define KmtAssert assert
+PKMT_RESULTBUFFER ResultBuffer = NULL;
 
 static VOID KmtAddToLogBuffer(PKMT_RESULTBUFFER Buffer, PCSTR String, SIZE_T Length)
 {
     LONG OldLength;
     LONG NewLength;
 
+    if (!Buffer)
+        return;
+
     do
     {
         OldLength = Buffer->LogBufferLength;
@@ -92,48 +216,55 @@ static VOID KmtAddToLogBuffer(PKMT_RESULTBUFFER Buffer, PCSTR String, SIZE_T Len
         }
     } while (InterlockedCompareExchange(&Buffer->LogBufferLength, NewLength, OldLength) != OldLength);
 
-    KmtMemCpy(&Buffer->LogBuffer[OldLength], String, Length);
+    memcpy(&Buffer->LogBuffer[OldLength], String, Length);
 }
 
-#ifdef KMT_KERNEL_MODE
-INT __cdecl KmtVSNPrintF(PSTR Buffer, SIZE_T BufferMaxLength, PCSTR Format, va_list Arguments);
-#elif defined KMT_USER_MODE
-#define KmtVSNPrintF vsnprintf
-#endif /* defined KMT_USER_MODE */
-
-static SIZE_T KmtXVSNPrintF(PSTR Buffer, SIZE_T BufferMaxLength, PCSTR Prepend1, PCSTR Prepend2, PCSTR Format, va_list Arguments)
+KMT_FORMAT(ms_printf, 5, 0)
+static SIZE_T KmtXVSNPrintF(PSTR Buffer, SIZE_T BufferMaxLength, PCSTR FileAndLine, PCSTR Prepend, PCSTR Format, va_list Arguments)
 {
     SIZE_T BufferLength = 0;
     SIZE_T Length;
 
-    if (Prepend1)
+    if (FileAndLine)
     {
-        SIZE_T Length = min(BufferMaxLength, KmtStrLen(Prepend1));
-        KmtMemCpy(Buffer, Prepend1, Length);
+        PCSTR Slash;
+        Slash = strrchr(FileAndLine, '\\');
+        if (Slash)
+            FileAndLine = Slash + 1;
+        Slash = strrchr(FileAndLine, '/');
+        if (Slash)
+            FileAndLine = Slash + 1;
+
+        Length = min(BufferMaxLength, strlen(FileAndLine));
+        memcpy(Buffer, FileAndLine, Length);
         Buffer += Length;
         BufferLength += Length;
         BufferMaxLength -= Length;
     }
-    if (Prepend2)
+    if (Prepend)
     {
-        SIZE_T Length = min(BufferMaxLength, KmtStrLen(Prepend2));
-        KmtMemCpy(Buffer, Prepend2, Length);
+        Length = min(BufferMaxLength, strlen(Prepend));
+        memcpy(Buffer, Prepend, Length);
         Buffer += Length;
         BufferLength += Length;
         BufferMaxLength -= Length;
     }
-    Length = KmtVSNPrintF(Buffer, BufferMaxLength, Format, Arguments);
-    /* vsnprintf can return more than maxLength, we don't want to do that */
-    BufferLength += min(Length, BufferMaxLength);
+    if (Format)
+    {
+        Length = KmtVSNPrintF(Buffer, BufferMaxLength, Format, Arguments);
+        /* vsnprintf can return more than maxLength, we don't want to do that */
+        BufferLength += min(Length, BufferMaxLength);
+    }
     return BufferLength;
 }
 
-static SIZE_T KmtXSNPrintF(PSTR Buffer, SIZE_T BufferMaxLength, PCSTR Prepend1, PCSTR Prepend2, PCSTR Format, ...)
+KMT_FORMAT(ms_printf, 5, 6)
+static SIZE_T KmtXSNPrintF(PSTR Buffer, SIZE_T BufferMaxLength, PCSTR FileAndLine, PCSTR Prepend, PCSTR Format, ...)
 {
     SIZE_T BufferLength;
     va_list Arguments;
     va_start(Arguments, Format);
-    BufferLength = KmtXVSNPrintF(Buffer, BufferMaxLength, Prepend1, Prepend2, Format, Arguments);
+    BufferLength = KmtXVSNPrintF(Buffer, BufferMaxLength, FileAndLine, Prepend, Format, Arguments);
     va_end(Arguments);
     return BufferLength;
 }
@@ -143,11 +274,15 @@ VOID KmtFinishTest(PCSTR TestName)
     CHAR MessageBuffer[512];
     SIZE_T MessageLength;
 
+    if (!ResultBuffer)
+        return;
+
     MessageLength = KmtXSNPrintF(MessageBuffer, sizeof MessageBuffer, NULL, NULL,
-                                    "%s: %d tests executed (0 marked as todo, %d failures), 0 skipped.\n",
+                                    "%s: %ld tests executed (0 marked as todo, %ld failures), %ld skipped.\n",
                                     TestName,
                                     ResultBuffer->Successes + ResultBuffer->Failures,
-                                    ResultBuffer->Failures);
+                                    ResultBuffer->Failures,
+                                    ResultBuffer->Skipped);
     KmtAddToLogBuffer(ResultBuffer, MessageBuffer, MessageLength);
 }
 
@@ -156,13 +291,16 @@ VOID KmtVOk(INT Condition, PCSTR FileAndLine, PCSTR Format, va_list Arguments)
     CHAR MessageBuffer[512];
     SIZE_T MessageLength;
 
+    if (!ResultBuffer)
+        return;
+
     if (Condition)
     {
         InterlockedIncrement(&ResultBuffer->Successes);
 
         if (0/*KmtReportSuccess*/)
         {
-            MessageLength = KmtXSNPrintF(MessageBuffer, sizeof MessageBuffer, FileAndLine, ": Test succeeded\n", "");
+            MessageLength = KmtXSNPrintF(MessageBuffer, sizeof MessageBuffer, FileAndLine, ": Test succeeded\n", NULL);
             KmtAddToLogBuffer(ResultBuffer, MessageBuffer, MessageLength);
         }
     }
@@ -199,6 +337,34 @@ VOID KmtTrace(PCSTR FileAndLine, PCSTR Format, ...)
     va_end(Arguments);
 }
 
+BOOLEAN KmtVSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, va_list Arguments)
+{
+    CHAR MessageBuffer[512];
+    SIZE_T MessageLength;
+
+    if (!ResultBuffer)
+        return !Condition;
+
+    if (!Condition)
+    {
+        InterlockedIncrement(&ResultBuffer->Skipped);
+        MessageLength = KmtXVSNPrintF(MessageBuffer, sizeof MessageBuffer, FileAndLine, ": Tests skipped: ", Format, Arguments);
+        KmtAddToLogBuffer(ResultBuffer, MessageBuffer, MessageLength);
+    }
+
+    return !Condition;
+}
+
+BOOLEAN KmtSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, ...)
+{
+    BOOLEAN Ret;
+    va_list Arguments;
+    va_start(Arguments, Format);
+    Ret = KmtVSkip(Condition, FileAndLine, Format, Arguments);
+    va_end(Arguments);
+    return Ret;
+}
+
 #endif /* defined KMT_DEFINE_TEST_FUNCTIONS */
 
 #endif /* !defined _KMTEST_TEST_H_ */