[ROSTESTS]
[reactos.git] / rostests / kmtests / kmtest_drv / kmtest_drv.c
diff --git a/rostests/kmtests/kmtest_drv/kmtest_drv.c b/rostests/kmtests/kmtest_drv/kmtest_drv.c
new file mode 100644 (file)
index 0000000..05b5709
--- /dev/null
@@ -0,0 +1,377 @@
+/*
+ * PROJECT:         ReactOS kernel-mode tests
+ * LICENSE:         GPLv2+ - See COPYING in the top level directory
+ * PURPOSE:         Kernel-Mode Test Suite Driver
+ * PROGRAMMER:      Thomas Faber <thfabba@gmx.de>
+ */
+
+#include <ntddk.h>
+#include <ntifs.h>
+#include <ndk/ketypes.h>
+#include <ntstrsafe.h>
+#include <limits.h>
+#include <pseh/pseh2.h>
+
+//#define NDEBUG
+#include <debug.h>
+
+#include <kmt_public.h>
+#define KMT_DEFINE_TEST_FUNCTIONS
+#include <kmt_test.h>
+
+/* Prototypes */
+DRIVER_INITIALIZE DriverEntry;
+static DRIVER_UNLOAD DriverUnload;
+static DRIVER_DISPATCH DriverCreate;
+static DRIVER_DISPATCH DriverClose;
+static DRIVER_DISPATCH DriverIoControl;
+
+/* Globals */
+static PDEVICE_OBJECT MainDeviceObject;
+
+/* Entry */
+/**
+ * @name DriverEntry
+ *
+ * Driver Entry point.
+ *
+ * @param DriverObject
+ *        Driver Object
+ * @param RegistryPath
+ *        Driver Registry Path
+ *
+ * @return Status
+ */
+NTSTATUS
+NTAPI
+DriverEntry(
+    IN PDRIVER_OBJECT DriverObject,
+    IN PUNICODE_STRING RegistryPath)
+{
+    NTSTATUS Status = STATUS_SUCCESS;
+    UNICODE_STRING DeviceName;
+    PKMT_DEVICE_EXTENSION DeviceExtension;
+    PKPRCB Prcb;
+
+    PAGED_CODE();
+
+    UNREFERENCED_PARAMETER(RegistryPath);
+
+    DPRINT("DriverEntry\n");
+
+    Prcb = KeGetCurrentPrcb();
+    KmtIsCheckedBuild = (Prcb->BuildType & PRCB_BUILD_DEBUG) != 0;
+    KmtIsMultiProcessorBuild = (Prcb->BuildType & PRCB_BUILD_UNIPROCESSOR) == 0;
+
+    RtlInitUnicodeString(&DeviceName, KMTEST_DEVICE_DRIVER_PATH);
+    Status = IoCreateDevice(DriverObject, sizeof(KMT_DEVICE_EXTENSION),
+                            &DeviceName,
+                            FILE_DEVICE_UNKNOWN,
+                            FILE_DEVICE_SECURE_OPEN | FILE_READ_ONLY_DEVICE,
+                            FALSE, &MainDeviceObject);
+
+    if (!NT_SUCCESS(Status))
+        goto cleanup;
+
+    DPRINT("DriverEntry. Created DeviceObject %p. DeviceExtension %p\n",
+             MainDeviceObject, MainDeviceObject->DeviceExtension);
+    DeviceExtension = MainDeviceObject->DeviceExtension;
+    DeviceExtension->ResultBuffer = NULL;
+    DeviceExtension->Mdl = NULL;
+
+    DriverObject->DriverUnload = DriverUnload;
+    DriverObject->MajorFunction[IRP_MJ_CREATE] = DriverCreate;
+    DriverObject->MajorFunction[IRP_MJ_CLOSE] = DriverClose;
+    DriverObject->MajorFunction[IRP_MJ_DEVICE_CONTROL] = DriverIoControl;
+
+cleanup:
+    if (MainDeviceObject && !NT_SUCCESS(Status))
+    {
+        IoDeleteDevice(MainDeviceObject);
+        MainDeviceObject = NULL;
+    }
+
+    return Status;
+}
+
+/* Dispatch functions */
+/**
+ * @name DriverUnload
+ *
+ * Driver cleanup funtion.
+ *
+ * @param DriverObject
+ *        Driver Object
+ */
+static
+VOID
+NTAPI
+DriverUnload(
+    IN PDRIVER_OBJECT DriverObject)
+{
+    PAGED_CODE();
+
+    UNREFERENCED_PARAMETER(DriverObject);
+
+    DPRINT("DriverUnload\n");
+
+    if (MainDeviceObject)
+    {
+        PKMT_DEVICE_EXTENSION DeviceExtension = MainDeviceObject->DeviceExtension;
+        ASSERT(!DeviceExtension->Mdl);
+        ASSERT(!DeviceExtension->ResultBuffer);
+        ASSERT(!ResultBuffer);
+        IoDeleteDevice(MainDeviceObject);
+    }
+}
+
+/**
+ * @name DriverCreate
+ *
+ * Driver Dispatch function for CreateFile
+ *
+ * @param DeviceObject
+ *        Device Object
+ * @param Irp
+ *        I/O request packet
+ *
+ * @return Status
+ */
+static
+NTSTATUS
+NTAPI
+DriverCreate(
+    IN PDEVICE_OBJECT DeviceObject,
+    IN PIRP Irp)
+{
+    NTSTATUS Status = STATUS_SUCCESS;
+    PIO_STACK_LOCATION IoStackLocation;
+
+    PAGED_CODE();
+
+    IoStackLocation = IoGetCurrentIrpStackLocation(Irp);
+
+    DPRINT("DriverCreate. DeviceObject=%p, RequestorMode=%d, FileObject=%p, FsContext=%p, FsContext2=%p\n",
+             DeviceObject, Irp->RequestorMode, IoStackLocation->FileObject,
+             IoStackLocation->FileObject->FsContext, IoStackLocation->FileObject->FsContext2);
+
+    Irp->IoStatus.Status = Status;
+    Irp->IoStatus.Information = 0;
+
+    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+
+    return Status;
+}
+
+/**
+ * @name DriverClose
+ *
+ * Driver Dispatch function for CloseHandle.
+ *
+ * @param DeviceObject
+ *        Device Object
+ * @param Irp
+ *        I/O request packet
+ *
+ * @return Status
+ */
+static
+NTSTATUS
+NTAPI
+DriverClose(
+    IN PDEVICE_OBJECT DeviceObject,
+    IN PIRP Irp)
+{
+    NTSTATUS Status = STATUS_SUCCESS;
+    PIO_STACK_LOCATION IoStackLocation;
+    PKMT_DEVICE_EXTENSION DeviceExtension;
+
+    PAGED_CODE();
+
+    IoStackLocation = IoGetCurrentIrpStackLocation(Irp);
+
+    DPRINT("DriverClose. DeviceObject=%p, RequestorMode=%d, FileObject=%p, FsContext=%p, FsContext2=%p\n",
+             DeviceObject, Irp->RequestorMode, IoStackLocation->FileObject,
+             IoStackLocation->FileObject->FsContext, IoStackLocation->FileObject->FsContext2);
+
+    ASSERT(IoStackLocation->FileObject->FsContext2 == NULL);
+    DeviceExtension = DeviceObject->DeviceExtension;
+    if (DeviceExtension->Mdl && IoStackLocation->FileObject->FsContext == DeviceExtension->Mdl)
+    {
+        MmUnlockPages(DeviceExtension->Mdl);
+        IoFreeMdl(DeviceExtension->Mdl);
+        DeviceExtension->Mdl = NULL;
+        ResultBuffer = DeviceExtension->ResultBuffer = NULL;
+    }
+    else
+    {
+        ASSERT(IoStackLocation->FileObject->FsContext == NULL);
+    }
+
+    Irp->IoStatus.Status = Status;
+    Irp->IoStatus.Information = 0;
+
+    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+
+    return Status;
+}
+
+/**
+ * @name DriverIoControl
+ *
+ * Driver Dispatch function for DeviceIoControl.
+ *
+ * @param DeviceObject
+ *        Device Object
+ * @param Irp
+ *        I/O request packet
+ *
+ * @return Status
+ */
+static
+NTSTATUS
+NTAPI
+DriverIoControl(
+    IN PDEVICE_OBJECT DeviceObject,
+    IN PIRP Irp)
+{
+    NTSTATUS Status = STATUS_SUCCESS;
+    PIO_STACK_LOCATION IoStackLocation;
+    ULONG Length = 0;
+
+    PAGED_CODE();
+
+    IoStackLocation = IoGetCurrentIrpStackLocation(Irp);
+
+    DPRINT("DriverIoControl. Code=0x%08X, DeviceObject=%p, FileObject=%p, FsContext=%p, FsContext2=%p\n",
+             IoStackLocation->Parameters.DeviceIoControl.IoControlCode,
+             DeviceObject, IoStackLocation->FileObject,
+             IoStackLocation->FileObject->FsContext, IoStackLocation->FileObject->FsContext2);
+
+    switch (IoStackLocation->Parameters.DeviceIoControl.IoControlCode)
+    {
+        case IOCTL_KMTEST_GET_TESTS:
+        {
+            PCKMT_TEST TestEntry;
+            LPSTR OutputBuffer = Irp->AssociatedIrp.SystemBuffer;
+            size_t Remaining = IoStackLocation->Parameters.DeviceIoControl.OutputBufferLength;
+
+            DPRINT("DriverIoControl. IOCTL_KMTEST_GET_TESTS, outlen=%lu\n",
+                     IoStackLocation->Parameters.DeviceIoControl.OutputBufferLength);
+
+            for (TestEntry = TestList; TestEntry->TestName; ++TestEntry)
+            {
+                RtlStringCbCopyExA(OutputBuffer, Remaining, TestEntry->TestName, &OutputBuffer, &Remaining, 0);
+                if (Remaining)
+                {
+                    *OutputBuffer++ = '\0';
+                    --Remaining;
+                }
+            }
+            if (Remaining)
+            {
+                *OutputBuffer++ = '\0';
+                --Remaining;
+            }
+            Length = IoStackLocation->Parameters.DeviceIoControl.OutputBufferLength - Remaining;
+            break;
+        }
+        case IOCTL_KMTEST_RUN_TEST:
+        {
+            ANSI_STRING TestName;
+            PCKMT_TEST TestEntry;
+
+            DPRINT("DriverIoControl. IOCTL_KMTEST_RUN_TEST, inlen=%lu, outlen=%lu\n",
+                     IoStackLocation->Parameters.DeviceIoControl.InputBufferLength,
+                     IoStackLocation->Parameters.DeviceIoControl.OutputBufferLength);
+            TestName.Length = TestName.MaximumLength = (USHORT)min(IoStackLocation->Parameters.DeviceIoControl.InputBufferLength, USHRT_MAX);
+            TestName.Buffer = Irp->AssociatedIrp.SystemBuffer;
+            DPRINT("DriverIoControl. Run test: %Z\n", &TestName);
+
+            for (TestEntry = TestList; TestEntry->TestName; ++TestEntry)
+            {
+                ANSI_STRING EntryName;
+                if (TestEntry->TestName[0] == '-')
+                    RtlInitAnsiString(&EntryName, TestEntry->TestName + 1);
+                else
+                    RtlInitAnsiString(&EntryName, TestEntry->TestName);
+
+                if (!RtlCompareString(&TestName, &EntryName, FALSE))
+                {
+                    DPRINT1("DriverIoControl. Starting test %Z\n", &EntryName);
+                    TestEntry->TestFunction();
+                    DPRINT1("DriverIoControl. Finished test %Z\n", &EntryName);
+                    break;
+                }
+            }
+
+            if (!TestEntry->TestName)
+                Status = STATUS_OBJECT_NAME_INVALID;
+
+            break;
+        }
+        case IOCTL_KMTEST_SET_RESULTBUFFER:
+        {
+            PKMT_DEVICE_EXTENSION DeviceExtension = DeviceObject->DeviceExtension;
+
+            DPRINT("DriverIoControl. IOCTL_KMTEST_SET_RESULTBUFFER, buffer=%p, inlen=%lu, outlen=%lu\n",
+                    IoStackLocation->Parameters.DeviceIoControl.Type3InputBuffer,
+                    IoStackLocation->Parameters.DeviceIoControl.InputBufferLength,
+                    IoStackLocation->Parameters.DeviceIoControl.OutputBufferLength);
+
+            if (DeviceExtension->Mdl)
+            {
+                if (IoStackLocation->FileObject->FsContext != DeviceExtension->Mdl)
+                {
+                    Status = STATUS_ACCESS_DENIED;
+                    break;
+                }
+                MmUnlockPages(DeviceExtension->Mdl);
+                IoFreeMdl(DeviceExtension->Mdl);
+                IoStackLocation->FileObject->FsContext = NULL;
+                ResultBuffer = DeviceExtension->ResultBuffer = NULL;
+            }
+
+            DeviceExtension->Mdl = IoAllocateMdl(IoStackLocation->Parameters.DeviceIoControl.Type3InputBuffer,
+                                                    IoStackLocation->Parameters.DeviceIoControl.InputBufferLength,
+                                                    FALSE, FALSE, NULL);
+            if (!DeviceExtension->Mdl)
+            {
+                Status = STATUS_INSUFFICIENT_RESOURCES;
+                break;
+            }
+
+            _SEH2_TRY
+            {
+                MmProbeAndLockPages(DeviceExtension->Mdl, UserMode, IoModifyAccess);
+            }
+            _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+            {
+                Status = _SEH2_GetExceptionCode();
+                IoFreeMdl(DeviceExtension->Mdl);
+                DeviceExtension->Mdl = NULL;
+                break;
+            } _SEH2_END;
+
+            ResultBuffer = DeviceExtension->ResultBuffer = MmGetSystemAddressForMdlSafe(DeviceExtension->Mdl, NormalPagePriority);
+            IoStackLocation->FileObject->FsContext = DeviceExtension->Mdl;
+
+            DPRINT("DriverIoControl. ResultBuffer: %ld %ld %ld %ld\n",
+                    ResultBuffer->Successes, ResultBuffer->Failures,
+                    ResultBuffer->LogBufferLength, ResultBuffer->LogBufferMaxLength);
+            break;
+        }
+        default:
+            DPRINT1("DriverIoControl. Invalid IoCtl code 0x%08X\n",
+                     IoStackLocation->Parameters.DeviceIoControl.IoControlCode);
+            Status = STATUS_INVALID_DEVICE_REQUEST;
+            break;
+    }
+
+    Irp->IoStatus.Status = Status;
+    Irp->IoStatus.Information = Length;
+
+    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+
+    return Status;
+}