[FLTMGR] Latest from my branch (#135)
authorGed Murphy <gedmurphy@reactos.org>
Tue, 21 Nov 2017 16:36:29 +0000 (16:36 +0000)
committerGitHub <noreply@github.com>
Tue, 21 Nov 2017 16:36:29 +0000 (16:36 +0000)
[FLTMGR][KMTEST]
Squash and push my local branch across to master as the patch is getting a bit large. This is still WIP and none of this code is run in ros yet, so don't fret if you see ugly/unfinished code or int3's dotted around.

[FLTMGR] Improve loading/reg of filters and start to implement client connections
- Implement handling of connections from clients
- Implement closing of client ports
- Add a basic message waiter queue using CSQ's (untested)
- Hand off messages for the comms object to be handled by the comms file
- Initialize the connection list
- Add a registry file which will contain lib functions for accessing filter service entries

- [KMTEST] Initial usermode support for testing FS mini-filters
- Add base routines to wrap the win32 'Filter' APis
- Add support routines to be used when testing FS filter drivers
- Move KmtCreateService to a private routine so it can be shared with KmtFltCreateService
- Completely untested at the mo, so likely contains bugs at this point
- Add support for adding altitude and flags registry entries for minifilters
- Allow minifilters to setup without requiring instance attach/detach callbacks
- Add tests for FltRegisterFilter and FltUnregisterFilter and start to add associated tests

24 files changed:
drivers/filters/fltmgr/CMakeLists.txt
drivers/filters/fltmgr/Filter.c
drivers/filters/fltmgr/Interface.c
drivers/filters/fltmgr/Messaging.c
drivers/filters/fltmgr/Misc.c
drivers/filters/fltmgr/Registry.c [new file with mode: 0644]
drivers/filters/fltmgr/Registry.h [new file with mode: 0644]
drivers/filters/fltmgr/fltmgr.h
drivers/filters/fltmgr/fltmgrint.h
modules/rostests/kmtests/CMakeLists.txt
modules/rostests/kmtests/fltmgr/CMakeLists.txt
modules/rostests/kmtests/fltmgr/fltmgr_load/CMakeLists.txt
modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_load.c
modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_user.c [new file with mode: 0644]
modules/rostests/kmtests/fltmgr/fltmgr_register/CMakeLists.txt [new file with mode: 0644]
modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_reg_user.c [new file with mode: 0644]
modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_register.c [new file with mode: 0644]
modules/rostests/kmtests/include/kmt_platform.h
modules/rostests/kmtests/include/kmt_test.h
modules/rostests/kmtests/kmtest/filter.c
modules/rostests/kmtests/kmtest/fltsupport.c
modules/rostests/kmtests/kmtest/kmtest.h
modules/rostests/kmtests/kmtest/testlist.c
modules/rostests/kmtests/kmtest_drv/kmtest_fsminifilter.c

index 5a02cab..cb6ebbd 100644 (file)
@@ -7,6 +7,7 @@ list(APPEND SOURCE
     Messaging.c
     Misc.c
     Object.c
     Messaging.c
     Misc.c
     Object.c
+    Registry.c
     Volume.c
     ${CMAKE_CURRENT_BINARY_DIR}/fltmgr.def
     fltmgr.h)
     Volume.c
     ${CMAKE_CURRENT_BINARY_DIR}/fltmgr.def
     fltmgr.h)
index 71e45ae..92a824c 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "fltmgr.h"
 #include "fltmgrint.h"
 
 #include "fltmgr.h"
 #include "fltmgrint.h"
+#include "Registry.h"
 
 #define NDEBUG
 #include <debug.h>
 
 #define NDEBUG
 #include <debug.h>
@@ -25,6 +26,17 @@ FltpStartingToDrainObject(
     _Inout_ PFLT_OBJECT Object
 );
 
     _Inout_ PFLT_OBJECT Object
 );
 
+VOID
+FltpMiniFilterDriverUnload(
+);
+
+static
+NTSTATUS
+GetFilterAltitude(
+    _In_ PFLT_FILTER Filter,
+    _Inout_ PUNICODE_STRING AltitudeString
+);
+
 
 /* EXPORTED FUNCTIONS ******************************************************/
 
 
 /* EXPORTED FUNCTIONS ******************************************************/
 
@@ -56,8 +68,26 @@ NTSTATUS
 NTAPI
 FltUnloadFilter(_In_ PCUNICODE_STRING FilterName)
 {
 NTAPI
 FltUnloadFilter(_In_ PCUNICODE_STRING FilterName)
 {
-    UNREFERENCED_PARAMETER(FilterName);
-    return STATUS_NOT_IMPLEMENTED;
+    //
+    //FIXME: This is a temp hack, it needs properly implementing
+    //
+
+    UNICODE_STRING DriverServiceName;
+    UNICODE_STRING ServicesKey;
+    CHAR Buffer[MAX_KEY_LENGTH];
+
+    /* Setup the base services key */
+    RtlInitUnicodeString(&ServicesKey, SERVICES_KEY);
+
+    /* Initialize the string data */
+    DriverServiceName.Length = 0;
+    DriverServiceName.Buffer = (PWCH)Buffer;
+    DriverServiceName.MaximumLength = MAX_KEY_LENGTH;
+
+    /* Create the full service key for this filter */
+    RtlCopyUnicodeString(&DriverServiceName, &ServicesKey);
+    RtlAppendUnicodeStringToString(&DriverServiceName, FilterName);
+    return ZwUnloadDriver(&DriverServiceName);
 }
 
 NTSTATUS
 }
 
 NTSTATUS
@@ -74,6 +104,8 @@ FltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject,
     PCHAR Ptr;
     NTSTATUS Status;
 
     PCHAR Ptr;
     NTSTATUS Status;
 
+    *RetFilter = NULL;
+
     /* Make sure we're targeting the correct major revision */
     if ((Registration->Version & 0xFF00) != FLT_MAJOR_VERSION)
     {
     /* Make sure we're targeting the correct major revision */
     if ((Registration->Version & 0xFF00) != FLT_MAJOR_VERSION)
     {
@@ -145,6 +177,10 @@ FltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject,
     InitializeListHead(&Filter->ActiveOpens.mList);
     Filter->ActiveOpens.mCount = 0;
 
     InitializeListHead(&Filter->ActiveOpens.mList);
     Filter->ActiveOpens.mCount = 0;
 
+    ExInitializeFastMutex(&Filter->ConnectionList.mLock);
+    InitializeListHead(&Filter->ConnectionList.mList);
+    Filter->ConnectionList.mCount = 0;
+
     /* Initialize the usermode port list */
     ExInitializeFastMutex(&Filter->PortList.mLock);
     InitializeListHead(&Filter->PortList.mList);
     /* Initialize the usermode port list */
     ExInitializeFastMutex(&Filter->PortList.mLock);
     InitializeListHead(&Filter->PortList.mList);
@@ -199,15 +235,42 @@ FltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject,
     Filter->Name.Buffer = (PWCH)Ptr;
     RtlCopyUnicodeString(&Filter->Name, &DriverObject->DriverExtension->ServiceKeyName);
 
     Filter->Name.Buffer = (PWCH)Ptr;
     RtlCopyUnicodeString(&Filter->Name, &DriverObject->DriverExtension->ServiceKeyName);
 
+    Status = GetFilterAltitude(Filter, &Filter->DefaultAltitude);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
     //
     //
-    // - Get the altitude string
     // - Slot the filter into the correct altitude location
     // - More stuff??
     //
 
     // - Slot the filter into the correct altitude location
     // - More stuff??
     //
 
+    /* Store any existing driver unload routine before we make any changes */
+    Filter->OldDriverUnload = (PFLT_FILTER_UNLOAD_CALLBACK)DriverObject->DriverUnload;
+
+    /* Check we opted not to have an unload routine, or if we want to stop the driver from being unloaded */
+    if (!FlagOn(Filter->Flags, FLTFL_REGISTRATION_DO_NOT_SUPPORT_SERVICE_STOP))
+    {
+        DriverObject->DriverUnload = (PDRIVER_UNLOAD)FltpMiniFilterDriverUnload;
+    }
+    else
+    {
+        DriverObject->DriverUnload = (PDRIVER_UNLOAD)NULL;
+    }
+
+
 Quit:
 Quit:
-    if (!NT_SUCCESS(Status))
+
+    if (NT_SUCCESS(Status))
+    {
+        DPRINT1("Loaded FS mini-filter %wZ\n", &DriverObject->DriverExtension->ServiceKeyName);
+        *RetFilter = Filter;
+    }
+    else
     {
     {
+        DPRINT1("Failed to load FS mini-filter %wZ : 0x%X\n", &DriverObject->DriverExtension->ServiceKeyName, Status);
+
         // Add cleanup for context resources
 
         ExDeleteResourceLite(&Filter->InstanceList.rLock);
         // Add cleanup for context resources
 
         ExDeleteResourceLite(&Filter->InstanceList.rLock);
@@ -319,3 +382,155 @@ FltpStartingToDrainObject(_Inout_ PFLT_OBJECT Object)
 
     return STATUS_SUCCESS;
 }
 
     return STATUS_SUCCESS;
 }
+
+VOID
+FltpMiniFilterDriverUnload()
+{
+    __debugbreak();
+}
+
+/* PRIVATE FUNCTIONS ******************************************************/
+
+static
+NTSTATUS
+GetFilterAltitude(
+    _In_ PFLT_FILTER Filter,
+    _Inout_ PUNICODE_STRING AltitudeString)
+{
+    UNICODE_STRING InstancesKey = RTL_CONSTANT_STRING(L"Instances");
+    UNICODE_STRING DefaultInstance = RTL_CONSTANT_STRING(L"DefaultInstance");
+    UNICODE_STRING Altitude = RTL_CONSTANT_STRING(L"Altitude");
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    UNICODE_STRING FilterInstancePath;
+    ULONG BytesRequired;
+    HANDLE InstHandle = NULL;
+    HANDLE RootHandle;
+    PWCH InstBuffer = NULL;
+    PWCH AltBuffer = NULL;
+    NTSTATUS Status;
+
+    /* Get a handle to the instances key in the filter's services key */
+    Status = FltpOpenFilterServicesKey(Filter,
+                                       KEY_QUERY_VALUE,
+                                       &InstancesKey,
+                                       &RootHandle);
+    if (!NT_SUCCESS(Status))
+    {
+        return Status;
+    }
+
+    /* Read the size 'default instances' string value */
+    Status = FltpReadRegistryValue(RootHandle,
+                                   &DefaultInstance,
+                                   REG_SZ,
+                                   NULL,
+                                   0,
+                                   &BytesRequired);
+
+    /* We should get a buffer too small error */
+    if (Status == STATUS_BUFFER_TOO_SMALL)
+    {
+        /* Allocate the buffer we need to hold the string */
+        InstBuffer = ExAllocatePoolWithTag(PagedPool, BytesRequired, FM_TAG_UNICODE_STRING);
+        if (InstBuffer == NULL)
+        {
+            Status = STATUS_INSUFFICIENT_RESOURCES;
+            goto Quit;
+        }
+
+        /* Now read the string value */
+        Status = FltpReadRegistryValue(RootHandle,
+                                       &DefaultInstance,
+                                       REG_SZ,
+                                       InstBuffer,
+                                       BytesRequired,
+                                       &BytesRequired);
+    }
+
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Convert the string to a unicode_string */
+    RtlInitUnicodeString(&FilterInstancePath, InstBuffer);
+
+    /* Setup the attributes using the root key handle */
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &FilterInstancePath,
+                               OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE,
+                               RootHandle,
+                               NULL);
+
+    /* Now open the key name which was stored in the default instance */
+    Status = ZwOpenKey(&InstHandle, KEY_QUERY_VALUE, &ObjectAttributes);
+    if (NT_SUCCESS(Status))
+    {
+        /* Get the size of the buffer that holds the altitude */
+        Status = FltpReadRegistryValue(InstHandle,
+                                       &Altitude,
+                                       REG_SZ,
+                                       NULL,
+                                       0,
+                                       &BytesRequired);
+        if (Status == STATUS_BUFFER_TOO_SMALL)
+        {
+            /* Allocate the required buffer */
+            AltBuffer = ExAllocatePoolWithTag(PagedPool, BytesRequired, FM_TAG_UNICODE_STRING);
+            if (AltBuffer == NULL)
+            {
+                Status = STATUS_INSUFFICIENT_RESOURCES;
+                goto Quit;
+            }
+
+            /* And now finally read in the actual altitude string */
+            Status = FltpReadRegistryValue(InstHandle,
+                                           &Altitude,
+                                           REG_SZ,
+                                           AltBuffer,
+                                           BytesRequired,
+                                           &BytesRequired);
+            if (NT_SUCCESS(Status))
+            {
+                /* We made it, setup the return buffer */
+                AltitudeString->Length = BytesRequired;
+                AltitudeString->MaximumLength = BytesRequired;
+                AltitudeString->Buffer = AltBuffer;
+            }
+        }
+    }
+
+Quit:
+    if (!NT_SUCCESS(Status))
+    {
+        if (AltBuffer)
+        {
+            ExFreePoolWithTag(AltBuffer, FM_TAG_UNICODE_STRING);
+        }
+    }
+
+    if (InstBuffer)
+    {
+        ExFreePoolWithTag(InstBuffer, FM_TAG_UNICODE_STRING);
+    }
+
+    if (InstHandle)
+    {
+        ZwClose(InstHandle);
+    }
+    ZwClose(RootHandle);
+
+    return Status;
+}
+
+
+
+NTSTATUS
+FltpReadRegistryValue(
+    _In_ HANDLE KeyHandle,
+    _In_ PUNICODE_STRING ValueName,
+    _In_opt_ ULONG Type,
+    _Out_writes_bytes_(BufferSize) PVOID Buffer,
+    _In_ ULONG BufferSize,
+    _Out_opt_ PULONG BytesRequired
+);
\ No newline at end of file
index be1288a..158303b 100644 (file)
@@ -28,6 +28,8 @@
     ((_devObj)->DriverObject == Dispatcher::DriverObject) && \
       ((_devObj)->DeviceExtension != NULL))
 
     ((_devObj)->DriverObject == Dispatcher::DriverObject) && \
       ((_devObj)->DeviceExtension != NULL))
 
+extern PDEVICE_OBJECT CommsDeviceObject;
+
 
 DRIVER_INITIALIZE DriverEntry;
 NTSTATUS
 
 DRIVER_INITIALIZE DriverEntry;
 NTSTATUS
@@ -454,6 +456,13 @@ FltpDispatch(_In_ PDEVICE_OBJECT DeviceObject,
         return Status;
     }
 
         return Status;
     }
 
+    /* Check if this is a request for a the messaging device */
+    if (DeviceObject == CommsDeviceObject)
+    {
+        /* Hand off to our internal routine */
+        return FltpMsgDispatch(DeviceObject, Irp);
+    }
+
     FLT_ASSERT(DeviceExtension &&
                DeviceExtension->AttachedToDeviceObject);
 
     FLT_ASSERT(DeviceExtension &&
                DeviceExtension->AttachedToDeviceObject);
 
@@ -494,6 +503,13 @@ FltpCreate(_In_ PDEVICE_OBJECT DeviceObject,
         return STATUS_SUCCESS;
     }
 
         return STATUS_SUCCESS;
     }
 
+    /* Check if this is a request for a the new comms connection */
+    if (DeviceObject == CommsDeviceObject)
+    {
+        /* Hand off to our internal routine */
+        return FltpMsgCreate(DeviceObject, Irp);
+    }
+
     FLT_ASSERT(DeviceExtension &&
                DeviceExtension->AttachedToDeviceObject);
 
     FLT_ASSERT(DeviceExtension &&
                DeviceExtension->AttachedToDeviceObject);
 
@@ -2104,9 +2120,9 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
     Status = SetupDispatchAndCallbacksTables(DriverObject);
     if (!NT_SUCCESS(Status)) goto Cleanup;
 
     Status = SetupDispatchAndCallbacksTables(DriverObject);
     if (!NT_SUCCESS(Status)) goto Cleanup;
 
-    //
-    // TODO: Create fltmgr message device
-    //
+    /* Initialize the comms objects */
+    Status = FltpSetupCommunicationObjects(DriverObject);
+    if (!NT_SUCCESS(Status)) goto Cleanup;
 
     /* Register for notifications when a new file system is loaded. This also enumerates any existing file systems */
     Status = IoRegisterFsRegistrationChange(DriverObject, FltpFsNotification);
 
     /* Register for notifications when a new file system is loaded. This also enumerates any existing file systems */
     Status = IoRegisterFsRegistrationChange(DriverObject, FltpFsNotification);
index 97e6a94..6a40e0d 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "fltmgr.h"
 #include "fltmgrint.h"
 
 #include "fltmgr.h"
 #include "fltmgrint.h"
+#include <fltmgr_shared.h>
 
 #define NDEBUG
 #include <debug.h>
 
 #define NDEBUG
 #include <debug.h>
@@ -29,6 +30,32 @@ FltpDisconnectPort(
     _In_ PFLT_PORT_OBJECT PortObject
 );
 
     _In_ PFLT_PORT_OBJECT PortObject
 );
 
+static
+NTSTATUS
+CreateClientPort(
+    _In_ PFILE_OBJECT FileObject,
+    _Inout_ PIRP Irp
+);
+
+static
+NTSTATUS
+CloseClientPort(
+    _In_ PFILE_OBJECT FileObject,
+    _Inout_ PIRP Irp
+);
+
+static
+NTSTATUS
+InitializeMessageWaiterQueue(
+    _Inout_ PFLT_MESSAGE_WAITER_QUEUE MsgWaiterQueue
+);
+
+static
+PPORT_CCB
+CreatePortCCB(
+    _In_ PFLT_PORT_OBJECT PortObject
+);
+
 
 
 /* EXPORTED FUNCTIONS ******************************************************/
 
 
 /* EXPORTED FUNCTIONS ******************************************************/
@@ -72,8 +99,8 @@ FltCreateCommunicationPort(_In_ PFLT_FILTER Filter,
         return Status;
     }
 
         return Status;
     }
 
-    /* Create our new server port object */
-    Status = ObCreateObject(0,
+    /* Create the server port object for this filter */
+    Status = ObCreateObject(KernelMode,
                             ServerPortObjectType,
                             ObjectAttributes,
                             KernelMode,
                             ServerPortObjectType,
                             ObjectAttributes,
                             KernelMode,
@@ -191,6 +218,72 @@ FltSendMessage(_In_ PFLT_FILTER Filter,
 
 /* INTERNAL FUNCTIONS ******************************************************/
 
 
 /* INTERNAL FUNCTIONS ******************************************************/
 
+
+NTSTATUS
+FltpMsgCreate(_In_ PDEVICE_OBJECT DeviceObject,
+              _Inout_ PIRP Irp)
+{
+    PIO_STACK_LOCATION StackPtr;
+    NTSTATUS Status;
+
+    /* Get the stack location */
+    StackPtr = IoGetCurrentIrpStackLocation(Irp);
+
+    FLT_ASSERT(StackPtr->MajorFunction == IRP_MJ_CREATE);
+
+    /* Check if this is a caller wanting to connect */
+    if (StackPtr->MajorFunction == IRP_MJ_CREATE)
+    {
+        /* Create the client port for this connection and exit */
+        Status = CreateClientPort(StackPtr->FileObject, Irp);
+    }
+    else
+    {
+        Status = STATUS_INVALID_PARAMETER;
+    }
+
+    if (Status != STATUS_PENDING)
+    {
+        Irp->IoStatus.Status = Status;
+        Irp->IoStatus.Information = 0;
+        IoCompleteRequest(Irp, 0);
+    }
+
+    return Status;
+}
+
+NTSTATUS
+FltpMsgDispatch(_In_ PDEVICE_OBJECT DeviceObject,
+                _Inout_ PIRP Irp)
+{
+    PIO_STACK_LOCATION StackPtr;
+    NTSTATUS Status;
+
+    /* Get the stack location */
+    StackPtr = IoGetCurrentIrpStackLocation(Irp);
+
+    /* Check if this is a caller wanting to connect */
+    if (StackPtr->MajorFunction == IRP_MJ_CLOSE)
+    {
+        /* Create the client port for this connection and exit */
+        Status = CloseClientPort(StackPtr->FileObject, Irp);
+    }
+    else
+    {
+        // We don't support anything else yet
+        Status = STATUS_NOT_IMPLEMENTED;
+    }
+
+    if (Status != STATUS_PENDING)
+    {
+        Irp->IoStatus.Status = Status;
+        Irp->IoStatus.Information = 0;
+        IoCompleteRequest(Irp, 0);
+    }
+
+    return Status;
+}
+
 VOID
 NTAPI
 FltpServerPortClose(_In_opt_ PEPROCESS Process,
 VOID
 NTAPI
 FltpServerPortClose(_In_opt_ PEPROCESS Process,
@@ -369,5 +462,370 @@ Quit:
     return Status;
 }
 
     return Status;
 }
 
+/* CSQ IRP CALLBACKS *******************************************************/
+
+
+NTSTATUS
+NTAPI
+FltpAddMessageWaiter(_In_ PIO_CSQ Csq,
+                     _In_ PIRP Irp,
+                     _In_ PVOID InsertContext)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Insert the IRP at the end of the queue */
+    InsertTailList(&MessageWaiterQueue->WaiterQ.mList,
+                   &Irp->Tail.Overlay.ListEntry);
+
+    /* return success */
+    return STATUS_SUCCESS;
+}
+
+VOID
+NTAPI
+FltpRemoveMessageWaiter(_In_ PIO_CSQ Csq,
+                        _In_ PIRP Irp)
+{
+    /* Remove the IRP from the queue */
+    RemoveEntryList(&Irp->Tail.Overlay.ListEntry);
+}
+
+PIRP
+NTAPI
+FltpGetNextMessageWaiter(_In_ PIO_CSQ Csq,
+                         _In_ PIRP Irp,
+                         _In_ PVOID PeekContext)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+    PIRP NextIrp = NULL;
+    PLIST_ENTRY NextEntry;
+    PIO_STACK_LOCATION IrpStack;
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Is the IRP valid? */
+    if (Irp == NULL)
+    {
+        /* Start peeking from the listhead */
+        NextEntry = MessageWaiterQueue->WaiterQ.mList.Flink;
+    }
+    else
+    {
+        /* Start peeking from that IRP onwards */
+        NextEntry = Irp->Tail.Overlay.ListEntry.Flink;
+    }
+
+    /* Loop through the queue */
+    while (NextEntry != &MessageWaiterQueue->WaiterQ.mList)
+    {
+        /* Store the next IRP in the list */
+        NextIrp = CONTAINING_RECORD(NextEntry, IRP, Tail.Overlay.ListEntry);
+
+        /* Did we supply a PeekContext on insert? */
+        if (!PeekContext)
+        {
+            /* We already have the next IRP */
+            break;
+        }
+        else
+        {
+            /* Get the stack of the next IRP */
+            IrpStack = IoGetCurrentIrpStackLocation(NextIrp);
+
+            /* Does the PeekContext match the object? */
+            if (IrpStack->FileObject == (PFILE_OBJECT)PeekContext)
+            {
+                /* We have a match */
+                break;
+            }
+
+            /* Move to the next IRP */
+            NextIrp = NULL;
+            NextEntry = NextEntry->Flink;
+        }
+    }
+
+    return NextIrp;
+}
+
+_Acquires_lock_(((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, FLT_MESSAGE_WAITER_QUEUE, Csq))->WaiterQ.mLock)
+_IRQL_saves_global_(Irql, ((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock)
+_IRQL_raises_(DISPATCH_LEVEL)
+VOID
+NTAPI
+FltpAcquireMessageWaiterLock(_In_ PIO_CSQ Csq,
+                             _Out_ PKIRQL Irql)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+
+    UNREFERENCED_PARAMETER(Irql);
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Acquire the IRP queue lock */
+    ExAcquireFastMutex(&MessageWaiterQueue->WaiterQ.mLock);
+}
+
+_Releases_lock_(((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock)
+_IRQL_restores_global_(Irql, ((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock)
+_IRQL_requires_(DISPATCH_LEVEL)
+VOID
+NTAPI
+FltpReleaseMessageWaiterLock(_In_ PIO_CSQ Csq,
+                             _In_ KIRQL Irql)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+
+    UNREFERENCED_PARAMETER(Irql);
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Release the IRP queue lock */
+    ExReleaseFastMutex(&MessageWaiterQueue->WaiterQ.mLock);
+}
+
+VOID
+NTAPI
+FltpCancelMessageWaiter(_In_ PIO_CSQ Csq,
+                        _In_ PIRP Irp)
+{
+    /* Cancel the IRP */
+    Irp->IoStatus.Status = STATUS_CANCELLED;
+    Irp->IoStatus.Information = 0;
+    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+}
+
+
 /* PRIVATE FUNCTIONS ******************************************************/
 
 /* PRIVATE FUNCTIONS ******************************************************/
 
+static
+NTSTATUS
+CreateClientPort(_In_ PFILE_OBJECT FileObject,
+                   _Inout_ PIRP Irp)
+{
+    PFLT_SERVER_PORT_OBJECT ServerPortObject = NULL;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    PFILTER_PORT_DATA FilterPortData;
+    PFLT_PORT_OBJECT ClientPortObject = NULL;
+    PFLT_PORT PortHandle = NULL;
+    PPORT_CCB PortCCB = NULL;
+    //ULONG BufferLength;
+    LONG NumConns;
+    NTSTATUS Status;
+
+    /* We received the buffer via FilterConnectCommunicationPort, cast it back to its original form */
+    FilterPortData = Irp->AssociatedIrp.SystemBuffer;
+
+    /* Get a reference to the server port the filter created */
+    Status = ObReferenceObjectByName(&FilterPortData->PortName,
+                                     0,
+                                     0,
+                                     FLT_PORT_ALL_ACCESS,
+                                     ServerPortObjectType,
+                                     ExGetPreviousMode(),
+                                     0,
+                                     (PVOID *)&ServerPortObject);
+    if (!NT_SUCCESS(Status))
+    {
+        return Status;
+    }
+
+    /* Increment the number of connections on the server port */
+    NumConns = InterlockedIncrement(&ServerPortObject->NumberOfConnections);
+    if (NumConns > ServerPortObject->MaxConnections)
+    {
+        Status = STATUS_CONNECTION_COUNT_LIMIT;
+        goto Quit;
+    }
+
+    /* Initialize a basic kernel handle request */
+    InitializeObjectAttributes(&ObjectAttributes,
+                               NULL,
+                               OBJ_KERNEL_HANDLE,
+                               NULL,
+                               NULL);
+
+    /* Now create the new client port object */
+    Status = ObCreateObject(KernelMode,
+                            ClientPortObjectType,
+                            &ObjectAttributes,
+                            KernelMode,
+                            NULL,
+                            sizeof(FLT_PORT_OBJECT),
+                            0,
+                            0,
+                            (PVOID *)&ClientPortObject);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Clear out the buffer */
+    RtlZeroMemory(ClientPortObject, sizeof(FLT_PORT_OBJECT));
+
+    /* Initialize the locks */
+    ExInitializeRundownProtection(&ClientPortObject->MsgNotifRundownRef);
+    ExInitializeFastMutex(&ClientPortObject->Lock);
+
+    /* Set the server port object this belongs to */
+    ClientPortObject->ServerPort = ServerPortObject;
+
+    /* Setup the message queue */
+    Status = InitializeMessageWaiterQueue(&ClientPortObject->MsgQ);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Create the CCB which we'll attach to the file object */
+    PortCCB = CreatePortCCB(ClientPortObject);
+    if (PortCCB == NULL)
+    {
+        Status = STATUS_INSUFFICIENT_RESOURCES;
+        goto Quit;
+    }
+
+    /* Now insert the new client port into the object manager*/
+    Status = ObInsertObject(ClientPortObject, 0, FLT_PORT_ALL_ACCESS, 1, 0, (PHANDLE)&PortHandle);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Add a reference to the filter to keep it alive while we do some work with it */
+    Status = FltObjectReference(ServerPortObject->Filter);
+    if (NT_SUCCESS(Status))
+    {
+        /* Invoke the callback to let the filter know we have a connection */
+        Status = ServerPortObject->ConnectNotify(PortHandle,
+                                                 ServerPortObject->Cookie,
+                                                 NULL, //ConnectionContext
+                                                 0, //SizeOfContext
+                                                 &ClientPortObject->Cookie);
+        if (NT_SUCCESS(Status))
+        {
+            /* Add the client port CCB to the file object */
+            FileObject->FsContext2 = PortCCB;
+
+            /* Lock the port list on the filter and add this new port object to the list */
+            ExAcquireFastMutex(&ServerPortObject->Filter->PortList.mLock);
+            InsertTailList(&ServerPortObject->Filter->PortList.mList, &ClientPortObject->FilterLink);
+            ExReleaseFastMutex(&ServerPortObject->Filter->PortList.mLock);
+        }
+
+        /* We're done with the filter object, decremement the count */
+        FltObjectDereference(ServerPortObject->Filter);
+    }
+
+
+Quit:
+    if (!NT_SUCCESS(Status))
+    {
+        if (ClientPortObject)
+        {
+            ObfDereferenceObject(ClientPortObject);
+        }
+
+        if (PortHandle)
+        {
+            ZwClose(PortHandle);
+        }
+        else if (ServerPortObject)
+        {
+            InterlockedDecrement(&ServerPortObject->NumberOfConnections);
+            ObfDereferenceObject(ServerPortObject);
+        }
+
+        if (PortCCB)
+        {
+            ExFreePoolWithTag(PortCCB, FM_TAG_CCB);
+        }
+    }
+
+    return Status;
+}
+
+static
+NTSTATUS
+CloseClientPort(_In_ PFILE_OBJECT FileObject,
+                _Inout_ PIRP Irp)
+{
+    PFLT_CCB Ccb;
+
+    Ccb = (PFLT_CCB)FileObject->FsContext2;
+
+    /* Remove the reference on the filter we added when we opened the port */
+    ObDereferenceObject(Ccb->Data.Port.Port);
+
+    // FIXME: Free the CCB
+
+    return STATUS_SUCCESS;
+}
+
+static
+NTSTATUS
+InitializeMessageWaiterQueue(_Inout_ PFLT_MESSAGE_WAITER_QUEUE MsgWaiterQueue)
+{
+    NTSTATUS Status;
+
+    /* Setup the IRP queue */
+    Status = IoCsqInitializeEx(&MsgWaiterQueue->Csq,
+                               FltpAddMessageWaiter,
+                               FltpRemoveMessageWaiter,
+                               FltpGetNextMessageWaiter,
+                               FltpAcquireMessageWaiterLock,
+                               FltpReleaseMessageWaiterLock,
+                               FltpCancelMessageWaiter);
+    if (!NT_SUCCESS(Status))
+    {
+        return Status;
+    }
+
+    /* Initialize the waiter queue */
+    ExInitializeFastMutex(&MsgWaiterQueue->WaiterQ.mLock);
+    InitializeListHead(&MsgWaiterQueue->WaiterQ.mList);
+    MsgWaiterQueue->WaiterQ.mCount = 0;
+
+    /* We don't have a minimum waiter length */
+    MsgWaiterQueue->MinimumWaiterLength = (ULONG)-1;
+
+    /* Init the semaphore and event used for counting and signaling available IRPs */
+    KeInitializeSemaphore(&MsgWaiterQueue->Semaphore, 0, MAXLONG);
+    KeInitializeEvent(&MsgWaiterQueue->Event, NotificationEvent, FALSE);
+
+    return STATUS_SUCCESS;
+}
+
+static
+PPORT_CCB
+CreatePortCCB(_In_ PFLT_PORT_OBJECT PortObject)
+{
+    PPORT_CCB PortCCB;
+
+    /* Allocate a CCB struct to hold the client port object info */
+    PortCCB = ExAllocatePoolWithTag(NonPagedPool, sizeof(PPORT_CCB), FM_TAG_CCB);
+    if (PortCCB)
+    {
+        /* Initialize the structure */
+        PortCCB->Port = PortObject;
+        PortCCB->ReplyWaiterList.mCount = 0;
+        ExInitializeFastMutex(&PortCCB->ReplyWaiterList.mLock);
+        KeInitializeEvent(&PortCCB->ReplyWaiterList.mLock.Event, SynchronizationEvent, 0);
+    }
+
+    return PortCCB;
+}
\ No newline at end of file
index 3ef1424..8643c48 100644 (file)
@@ -29,8 +29,8 @@ FltBuildDefaultSecurityDescriptor(
     _In_ ACCESS_MASK DesiredAccess
 )
 {
     _In_ ACCESS_MASK DesiredAccess
 )
 {
-    UNREFERENCED_PARAMETER(SecurityDescriptor);
     UNREFERENCED_PARAMETER(DesiredAccess);
     UNREFERENCED_PARAMETER(DesiredAccess);
+    *SecurityDescriptor = NULL;
     return 0;
 }
 
     return 0;
 }
 
diff --git a/drivers/filters/fltmgr/Registry.c b/drivers/filters/fltmgr/Registry.c
new file mode 100644 (file)
index 0000000..2470d70
--- /dev/null
@@ -0,0 +1,142 @@
+/*
+ * PROJECT:         Filesystem Filter Manager
+ * LICENSE:         GPL - See COPYING in the top level directory
+ * FILE:            drivers/filters/fltmgr/Misc.c
+ * PURPOSE:         Uncataloged functions
+ * PROGRAMMERS:     Ged Murphy (gedmurphy@reactos.org)
+ */
+
+/* INCLUDES ******************************************************************/
+
+#include "fltmgr.h"
+#include "fltmgrint.h"
+
+#define NDEBUG
+#include <debug.h>
+
+
+/* DATA *********************************************************************/
+
+#define REG_SERVICES_KEY    L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\"
+#define REG_PATH_LENGTH     512
+
+
+/* INTERNAL FUNCTIONS ******************************************************/
+
+
+NTSTATUS
+FltpOpenFilterServicesKey(
+    _In_ PFLT_FILTER Filter,
+    _In_ ACCESS_MASK DesiredAccess,
+    _In_opt_ PUNICODE_STRING SubKey,
+    _Out_ PHANDLE Handle)
+{
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    UNICODE_STRING ServicesKey;
+    UNICODE_STRING Path;
+    WCHAR Buffer[REG_PATH_LENGTH];
+
+    /* Setup a local buffer to hold the services key path */
+    Path.Length = 0;
+    Path.MaximumLength = REG_PATH_LENGTH;
+    Path.Buffer = Buffer;
+
+    /* Build up the serices key name */
+    RtlInitUnicodeString(&ServicesKey, REG_SERVICES_KEY);
+    RtlCopyUnicodeString(&Path, &ServicesKey);
+    RtlAppendUnicodeStringToString(&Path, &Filter->Name);
+
+    if (SubKey)
+    {
+        /* Tag on any child key */
+        RtlAppendUnicodeToString(&Path, L"\\");
+        RtlAppendUnicodeStringToString(&Path, SubKey);
+    }
+
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &Path,
+                               OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE,
+                               NULL,
+                               NULL);
+
+    /* Open and return the key handle param*/
+    return ZwOpenKey(Handle, DesiredAccess, &ObjectAttributes);
+}
+
+NTSTATUS
+FltpReadRegistryValue(_In_ HANDLE KeyHandle,
+                      _In_ PUNICODE_STRING ValueName,
+                      _In_opt_ ULONG Type,
+                      _Out_writes_bytes_(BufferSize) PVOID Buffer,
+                      _In_ ULONG BufferSize,
+                      _Out_opt_ PULONG BytesRequired)
+{
+    PKEY_VALUE_PARTIAL_INFORMATION Value = NULL;
+    ULONG ValueLength = 0;
+    NTSTATUS Status;
+
+    PAGED_CODE();
+
+    /* Get the size of the buffer required to hold the string */
+    Status = ZwQueryValueKey(KeyHandle,
+                             ValueName,
+                             KeyValuePartialInformation,
+                             NULL,
+                             0,
+                             &ValueLength);
+    if (Status != STATUS_BUFFER_TOO_SMALL && Status != STATUS_BUFFER_OVERFLOW)
+    {
+        return Status;
+    }
+
+    /* Allocate the buffer */
+    Value = (PKEY_VALUE_PARTIAL_INFORMATION)ExAllocatePoolWithTag(PagedPool,
+                                                                  ValueLength,
+                                                                  FM_TAG_TEMP_REGISTRY);
+    if (Value == NULL)
+    {
+        Status = STATUS_INSUFFICIENT_RESOURCES;
+        goto Quit;
+    }
+
+    /* Now read in the value */
+    Status = ZwQueryValueKey(KeyHandle,
+                             ValueName,
+                             KeyValuePartialInformation,
+                             Value,
+                             ValueLength,
+                             &ValueLength);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Make sure we got the type expected */
+    if (Value->Type != Type)
+    {
+        Status = STATUS_INVALID_PARAMETER;
+        goto Quit;
+    }
+
+    if (BytesRequired)
+    {
+        *BytesRequired = Value->DataLength;
+    }
+
+    /* Make sure the caller buffer is big enough to hold the data */
+    if (!BufferSize || BufferSize < Value->DataLength)
+    {
+        Status = STATUS_BUFFER_TOO_SMALL;
+        goto Quit;
+    }
+
+    /* Copy the data into the caller buffer */
+    RtlCopyMemory(Buffer, Value->Data, Value->DataLength);
+
+Quit:
+
+    if (Value)
+        ExFreePoolWithTag(Value, FM_TAG_TEMP_REGISTRY);
+
+    return Status;
+}
diff --git a/drivers/filters/fltmgr/Registry.h b/drivers/filters/fltmgr/Registry.h
new file mode 100644 (file)
index 0000000..b295dae
--- /dev/null
@@ -0,0 +1,20 @@
+#pragma once
+
+
+NTSTATUS
+FltpOpenFilterServicesKey(
+    _In_ PFLT_FILTER Filter,
+    _In_ ACCESS_MASK DesiredAccess,
+    _In_opt_ PUNICODE_STRING SubKey,
+    _Out_ PHANDLE Handle
+);
+
+NTSTATUS
+FltpReadRegistryValue(
+    _In_ HANDLE KeyHandle,
+    _In_ PUNICODE_STRING ValueName,
+    _In_opt_ ULONG Type,
+    _Out_writes_bytes_(BufferSize) PVOID Buffer,
+    _In_ ULONG BufferSize,
+    _Out_opt_ PULONG BytesRequired
+);
\ No newline at end of file
index 71e070b..7008c1d 100644 (file)
@@ -18,6 +18,8 @@
 #define FM_TAG_UNICODE_STRING   'suMF'
 #define FM_TAG_FILTER           'lfMF'
 #define FM_TAG_CONTEXT_REGISTA  'rcMF'
 #define FM_TAG_UNICODE_STRING   'suMF'
 #define FM_TAG_FILTER           'lfMF'
 #define FM_TAG_CONTEXT_REGISTA  'rcMF'
+#define FM_TAG_CCB              'bcMF'
+#define FM_TAG_TEMP_REGISTRY    'rtMF'
 
 #define MAX_DEVNAME_LENGTH  64
 
 
 #define MAX_DEVNAME_LENGTH  64
 
@@ -100,6 +102,9 @@ FltpReallocateUnicodeString(_In_ PUNICODE_STRING String,
 
 VOID
 FltpFreeUnicodeString(_In_ PUNICODE_STRING String);
 
 VOID
 FltpFreeUnicodeString(_In_ PUNICODE_STRING String);
+
+
+
 ////////////////////////////////////////////////
 
 
 ////////////////////////////////////////////////
 
 
@@ -252,7 +257,7 @@ FltGetUpperInstance
 FltGetVolumeContext
 FltGetVolumeFromDeviceObject
 FltGetVolumeFromFileObject
 FltGetVolumeContext
 FltGetVolumeFromDeviceObject
 FltGetVolumeFromFileObject
-FltGetVolumeFromInstance
+FltLoadFilter
 FltGetVolumeFromName
 FltGetVolumeGuidName
 FltGetVolumeInstanceFromName
 FltGetVolumeFromName
 FltGetVolumeGuidName
 FltGetVolumeInstanceFromName
index a060b70..64065be 100644 (file)
@@ -58,10 +58,45 @@ typedef struct _FLT_MUTEX_LIST_HEAD
 
 } FLT_MUTEX_LIST_HEAD, *PFLT_MUTEX_LIST_HEAD;
 
 
 } FLT_MUTEX_LIST_HEAD, *PFLT_MUTEX_LIST_HEAD;
 
+typedef struct _FLT_TYPE
+{
+    USHORT Signature;
+    USHORT Size;
+
+} FLT_TYPE, *PFLT_TYPE;
+
+// http://fsfilters.blogspot.co.uk/2010/02/filter-manager-concepts-part-1.html
+typedef struct _FLTP_FRAME
+{
+    FLT_TYPE Type;
+    LIST_ENTRY Links;
+    unsigned int FrameID;
+    ERESOURCE AltitudeLock;
+    UNICODE_STRING AltitudeIntervalLow;
+    UNICODE_STRING AltitudeIntervalHigh;
+    char LargeIrpCtrlStackSize;
+    char SmallIrpCtrlStackSize;
+    FLT_RESOURCE_LIST_HEAD RegisteredFilters;
+    FLT_RESOURCE_LIST_HEAD AttachedVolumes;
+    LIST_ENTRY MountingVolumes;
+    FLT_MUTEX_LIST_HEAD AttachedFileSystems;
+    FLT_MUTEX_LIST_HEAD ZombiedFltObjectContexts;
+    ERESOURCE FilterUnloadLock;
+    FAST_MUTEX DeviceObjectAttachLock;
+    //FLT_PRCB *Prcb;
+    void *PrcbPoolToFree;
+    void *LookasidePoolToFree;
+    //FLTP_IRPCTRL_STACK_PROFILER IrpCtrlStackProfiler;
+    NPAGED_LOOKASIDE_LIST SmallIrpCtrlLookasideList;
+    NPAGED_LOOKASIDE_LIST LargeIrpCtrlLookasideList;
+    //STATIC_IRP_CONTROL GlobalSIC;
+
+} FLTP_FRAME, *PFLTP_FRAME;
+
 typedef struct _FLT_FILTER   // size = 0x120
 {
     FLT_OBJECT Base;
 typedef struct _FLT_FILTER   // size = 0x120
 {
     FLT_OBJECT Base;
-    PVOID Frame;  //FLTP_FRAME
+    PFLTP_FRAME Frame;
     UNICODE_STRING Name;
     UNICODE_STRING DefaultAltitude;
     FLT_FILTER_FLAGS Flags;
     UNICODE_STRING Name;
     UNICODE_STRING DefaultAltitude;
     FLT_FILTER_FLAGS Flags;
@@ -97,12 +132,7 @@ typedef enum _FLT_yINSTANCE_FLAGS
 
 } FLT_INSTANCE_FLAGS, *PFLT_INSTANCE_FLAGS;
 
 
 } FLT_INSTANCE_FLAGS, *PFLT_INSTANCE_FLAGS;
 
-typedef struct _FLT_TYPE
-{
-    USHORT Signature;
-    USHORT Size;
 
 
-} FLT_TYPE, *PFLT_TYPE;
 
 typedef struct _FLT_INSTANCE   // size = 0x144 (324)
 {
 
 typedef struct _FLT_INSTANCE   // size = 0x144 (324)
 {
@@ -121,34 +151,18 @@ typedef struct _FLT_INSTANCE   // size = 0x144 (324)
 
 } FLT_INSTANCE, *PFLT_INSTANCE;
 
 
 } FLT_INSTANCE, *PFLT_INSTANCE;
 
-// http://fsfilters.blogspot.co.uk/2010/02/filter-manager-concepts-part-1.html
-typedef struct _FLTP_FRAME
+
+typedef struct _TREE_ROOT
 {
 {
-    FLT_TYPE Type;
-    LIST_ENTRY Links;
-    unsigned int FrameID;
-    ERESOURCE AltitudeLock;
-    UNICODE_STRING AltitudeIntervalLow;
-    UNICODE_STRING AltitudeIntervalHigh;
-    char LargeIrpCtrlStackSize;
-    char SmallIrpCtrlStackSize;
-    FLT_RESOURCE_LIST_HEAD RegisteredFilters;
-    FLT_RESOURCE_LIST_HEAD AttachedVolumes;
-    LIST_ENTRY MountingVolumes;
-    FLT_MUTEX_LIST_HEAD AttachedFileSystems;
-    FLT_MUTEX_LIST_HEAD ZombiedFltObjectContexts;
-    ERESOURCE FilterUnloadLock;
-    FAST_MUTEX DeviceObjectAttachLock;
-    //FLT_PRCB *Prcb;
-    void *PrcbPoolToFree;
-    void *LookasidePoolToFree;
-    //FLTP_IRPCTRL_STACK_PROFILER IrpCtrlStackProfiler;
-    NPAGED_LOOKASIDE_LIST SmallIrpCtrlLookasideList;
-    NPAGED_LOOKASIDE_LIST LargeIrpCtrlLookasideList;
-    //STATIC_IRP_CONTROL GlobalSIC;
+    RTL_SPLAY_LINKS *Tree;
 
 
-} FLTP_FRAME, *PFLTP_FRAME;
+} TREE_ROOT, *PTREE_ROOT;
+
+typedef struct _CONTEXT_LIST_CTRL
+{
+    TREE_ROOT List;
 
 
+} CONTEXT_LIST_CTRL, *PCONTEXT_LIST_CTRL;
 
 // http://fsfilters.blogspot.co.uk/2010/02/filter-manager-concepts-part-6.html
 typedef struct _STREAM_LIST_CTRL // size = 0xC8 (200)
 
 // http://fsfilters.blogspot.co.uk/2010/02/filter-manager-concepts-part-6.html
 typedef struct _STREAM_LIST_CTRL // size = 0xC8 (200)
@@ -156,16 +170,16 @@ typedef struct _STREAM_LIST_CTRL // size = 0xC8 (200)
     FLT_TYPE Type;
     FSRTL_PER_STREAM_CONTEXT ContextCtrl;
     LIST_ENTRY VolumeLink;
     FLT_TYPE Type;
     FSRTL_PER_STREAM_CONTEXT ContextCtrl;
     LIST_ENTRY VolumeLink;
-    //STREAM_LIST_CTRL_FLAGS Flags;
+    ULONG Flags; //STREAM_LIST_CTRL_FLAGS Flags;
     int UseCount;
     ERESOURCE ContextLock;
     int UseCount;
     ERESOURCE ContextLock;
-    //CONTEXT_LIST_CTRL StreamContexts;
-    //CONTEXT_LIST_CTRL StreamHandleContexts;
+    CONTEXT_LIST_CTRL StreamContexts;
+    CONTEXT_LIST_CTRL StreamHandleContexts;
     ERESOURCE NameCacheLock;
     LARGE_INTEGER LastRenameCompleted;
     ERESOURCE NameCacheLock;
     LARGE_INTEGER LastRenameCompleted;
-    //NAME_CACHE_LIST_CTRL NormalizedNameCache;
-   // NAME_CACHE_LIST_CTRL ShortNameCache;
-   // NAME_CACHE_LIST_CTRL OpenedNameCache;
+    ULONG NormalizedNameCache; //NAME_CACHE_LIST_CTRL NormalizedNameCache;
+    ULONG ShortNameCache; // NAME_CACHE_LIST_CTRL ShortNameCache;
+    ULONG OpenedNameCache; // NAME_CACHE_LIST_CTRL OpenedNameCache;
     int AllNameContextsTemporary;
 
 } STREAM_LIST_CTRL, *PSTREAM_LIST_CTRL;
     int AllNameContextsTemporary;
 
 } STREAM_LIST_CTRL, *PSTREAM_LIST_CTRL;
@@ -186,6 +200,17 @@ typedef struct _FLT_SERVER_PORT_OBJECT
 } FLT_SERVER_PORT_OBJECT, *PFLT_SERVER_PORT_OBJECT;
 
 
 } FLT_SERVER_PORT_OBJECT, *PFLT_SERVER_PORT_OBJECT;
 
 
+typedef struct _FLT_MESSAGE_WAITER_QUEUE
+{
+    IO_CSQ Csq;
+    FLT_MUTEX_LIST_HEAD WaiterQ;
+    ULONG MinimumWaiterLength;
+    KSEMAPHORE Semaphore;
+    KEVENT Event;
+
+} FLT_MESSAGE_WAITER_QUEUE, *PFLT_MESSAGE_WAITER_QUEUE;
+
+
 typedef struct _FLT_PORT_OBJECT
 {
     LIST_ENTRY FilterLink;
 typedef struct _FLT_PORT_OBJECT
 {
     LIST_ENTRY FilterLink;
@@ -193,7 +218,7 @@ typedef struct _FLT_PORT_OBJECT
     PVOID Cookie;
     EX_RUNDOWN_REF MsgNotifRundownRef;
     FAST_MUTEX Lock;
     PVOID Cookie;
     EX_RUNDOWN_REF MsgNotifRundownRef;
     FAST_MUTEX Lock;
-    PVOID MsgQ; // FLT_MESSAGE_WAITER_QUEUE MsgQ;
+    FLT_MESSAGE_WAITER_QUEUE MsgQ;
     ULONGLONG MessageId;
     KEVENT DisconnectEvent;
     BOOLEAN Disconnected;
     ULONGLONG MessageId;
     KEVENT DisconnectEvent;
     BOOLEAN Disconnected;
@@ -232,18 +257,6 @@ typedef struct _CALLBACK_CTRL
 
 } CALLBACK_CTRL, *PCALLBACK_CTRL;
 
 
 } CALLBACK_CTRL, *PCALLBACK_CTRL;
 
-typedef struct _TREE_ROOT
-{
-    RTL_SPLAY_LINKS *Tree;
-
-} TREE_ROOT, *PTREE_ROOT;
-
-
-typedef struct _CONTEXT_LIST_CTRL
-{
-    TREE_ROOT List;
-
-} CONTEXT_LIST_CTRL, *PCONTEXT_LIST_CTRL;
 
 typedef struct _NAME_CACHE_LIST_CTRL_STATS
 {
 
 typedef struct _NAME_CACHE_LIST_CTRL_STATS
 {
@@ -311,6 +324,58 @@ typedef struct _FLT_VOLUME
 } FLT_VOLUME, *PFLT_VOLUME;
 
 
 } FLT_VOLUME, *PFLT_VOLUME;
 
 
+typedef struct _MANAGER_CCB
+{
+    PFLTP_FRAME Frame;
+    unsigned int Iterator;
+
+} MANAGER_CCB, *PMANAGER_CCB;
+
+typedef struct _FILTER_CCB
+{
+    PFLT_FILTER Filter;
+    unsigned int Iterator;
+
+} FILTER_CCB, *PFILTER_CCB;
+
+typedef struct _INSTANCE_CCB
+{
+    PFLT_INSTANCE Instance;
+
+} INSTANCE_CCB, *PINSTANCE_CCB;
+
+typedef struct _VOLUME_CCB
+{
+    UNICODE_STRING Volume;
+    unsigned int Iterator;
+
+} VOLUME_CCB, *PVOLUME_CCB;
+
+typedef struct _PORT_CCB
+{
+    PFLT_PORT_OBJECT Port;
+    FLT_MUTEX_LIST_HEAD ReplyWaiterList;
+
+} PORT_CCB, *PPORT_CCB;
+
+
+typedef union _CCB_TYPE
+{
+    MANAGER_CCB Manager;
+    FILTER_CCB Filter;
+    INSTANCE_CCB Instance;
+    VOLUME_CCB Volume;
+    PORT_CCB Port;
+
+} CCB_TYPE, *PCCB_TYPE;
+
+
+typedef struct _FLT_CCB
+{
+    FLT_TYPE Type;
+    CCB_TYPE Data;
+
+} FLT_CCB, *PFLT_CCB;
 
 VOID
 FltpExInitializeRundownProtection(
 
 VOID
 FltpExInitializeRundownProtection(
@@ -387,6 +452,21 @@ FltpDispatchHandler(
     _Inout_ PIRP Irp
 );
 
     _Inout_ PIRP Irp
 );
 
+NTSTATUS
+FltpMsgCreate(
+    _In_ PDEVICE_OBJECT DeviceObject,
+    _Inout_ PIRP Irp
+);
 
 
+NTSTATUS
+FltpMsgDispatch(
+    _In_ PDEVICE_OBJECT DeviceObject,
+    _Inout_ PIRP Irp
+);
+
+NTSTATUS
+FltpSetupCommunicationObjects(
+    _In_ PDRIVER_OBJECT DriverObject
+);
 
 #endif /* _FLTMGR_INTERNAL_H */
 
 #endif /* _FLTMGR_INTERNAL_H */
index 64b088d..c66ed9a 100644 (file)
@@ -127,6 +127,10 @@ list(APPEND KMTEST_SOURCE
     kmtest/testlist.c
 
     example/Example_user.c
     kmtest/testlist.c
 
     example/Example_user.c
+
+    fltmgr/fltmgr_load/fltmgr_user.c
+    fltmgr/fltmgr_register/fltmgr_reg_user.c
+
     hidparse/HidP_user.c
     kernel32/FileAttributes_user.c
     kernel32/FindFile_user.c
     hidparse/HidP_user.c
     kernel32/FileAttributes_user.c
     kernel32/FindFile_user.c
index fd7a55e..f0fdac8 100644 (file)
@@ -1,3 +1,4 @@
 
 add_subdirectory(fltmgr_load)
 add_subdirectory(fltmgr_create)
 
 add_subdirectory(fltmgr_load)
 add_subdirectory(fltmgr_create)
+add_subdirectory(fltmgr_register)
index 7fa022d..f93e516 100644 (file)
@@ -5,10 +5,10 @@ list(APPEND FLTMGR_TEST_DRV_SOURCE
     ../../kmtest_drv/kmtest_fsminifilter.c
     fltmgr_load.c)
 
     ../../kmtest_drv/kmtest_fsminifilter.c
     fltmgr_load.c)
 
-add_library(fltmgr_load SHARED ${FLTMGR_TEST_DRV_SOURCE})
-set_module_type(fltmgr_load kernelmodedriver)
-target_link_libraries(fltmgr_load kmtest_printf ${PSEH_LIB})
-add_importlibs(fltmgr_load fltmgr ntoskrnl hal)
-add_target_compile_definitions(fltmgr_load KMT_STANDALONE_DRIVER KMT_FILTER_DRIVER NTDDI_VERSION=NTDDI_WS03SP1)
+add_library(FltMgrLoad_drv SHARED ${FLTMGR_TEST_DRV_SOURCE})
+set_module_type(FltMgrLoad_drv kernelmodedriver)
+target_link_libraries(FltMgrLoad_drv kmtest_printf ${PSEH_LIB})
+add_importlibs(FltMgrLoad_drv fltmgr ntoskrnl hal)
+add_target_compile_definitions(FltMgrLoad_drv KMT_STANDALONE_DRIVER KMT_FILTER_DRIVER NTDDI_VERSION=NTDDI_WS03SP1)
 #add_pch(example_drv ../include/kmt_test.h)
 #add_pch(example_drv ../include/kmt_test.h)
-add_rostests_file(TARGET fltmgr_load)
+add_rostests_file(TARGET FltMgrLoad_drv)
index 7b9f523..e56d7eb 100644 (file)
@@ -89,7 +89,7 @@ TestEntry(
     ok_irql(PASSIVE_LEVEL);
     TestDriverObject = DriverObject;
 
     ok_irql(PASSIVE_LEVEL);
     TestDriverObject = DriverObject;
 
-    *DeviceName = L"fltmgr_load";
+    *DeviceName = L"FltMgrLoad";
 
     trace("Hi, this is the filter manager load test driver\n");
 
 
     trace("Hi, this is the filter manager load test driver\n");
 
diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_user.c b/modules/rostests/kmtests/fltmgr/fltmgr_load/fltmgr_user.c
new file mode 100644 (file)
index 0000000..727e9c0
--- /dev/null
@@ -0,0 +1,28 @@
+/*
+ * PROJECT:         ReactOS kernel-mode tests - Filter Manager
+ * LICENSE:         GPLv2+ - See COPYING in the top level directory
+ * PURPOSE:         Tests for checking filters load and connect correctly
+ * PROGRAMMER:      Ged Murphy <gedmurphy@reactos.org>
+ */
+
+#include <kmt_test.h>
+
+
+START_TEST(FltMgrLoad)
+{
+    static WCHAR FilterName[] = L"FltMgrLoad";
+    SC_HANDLE hService;
+    HANDLE hPort;
+
+    trace("Message from user-mode\n");
+
+    ok(KmtFltCreateService(FilterName, L"FltMgrLoad test driver", &hService) == ERROR_SUCCESS, "\n");
+    ok(KmtFltLoadDriver(FALSE, FALSE, FALSE, &hPort) == ERROR_PRIVILEGE_NOT_HELD, "\n");
+    ok(KmtFltLoadDriver(TRUE, FALSE, FALSE, &hPort) == ERROR_SUCCESS, "\n");
+
+    ok(KmtFltConnectComms(&hPort) == ERROR_SUCCESS, "\n");
+
+    ok(KmtFltDisconnectComms(hPort) == ERROR_SUCCESS, "\n");
+    ok(KmtFltUnloadDriver(hPort, FALSE) == ERROR_SUCCESS, "\n");
+    KmtFltDeleteService(NULL, &hService);
+}
diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_register/CMakeLists.txt b/modules/rostests/kmtests/fltmgr/fltmgr_register/CMakeLists.txt
new file mode 100644 (file)
index 0000000..2857a32
--- /dev/null
@@ -0,0 +1,15 @@
+
+include_directories(../../include)
+include_directories(${REACTOS_SOURCE_DIR}/drivers/filters/fltmgr)
+
+list(APPEND FLTMGR_TEST_DRV_SOURCE
+    ../../kmtest_drv/kmtest_fsminifilter.c
+    fltmgr_register.c)
+
+add_library(fltmgrreg_drv SHARED ${FLTMGR_TEST_DRV_SOURCE})
+set_module_type(fltmgrreg_drv kernelmodedriver)
+target_link_libraries(fltmgrreg_drv kmtest_printf ${PSEH_LIB})
+add_importlibs(fltmgrreg_drv fltmgr ntoskrnl hal)
+add_target_compile_definitions(fltmgrreg_drv KMT_STANDALONE_DRIVER KMT_FILTER_DRIVER NTDDI_VERSION=NTDDI_WS03SP1)
+#add_pch(example_drv ../include/kmt_test.h)
+add_rostests_file(TARGET fltmgrreg_drv)
diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_reg_user.c b/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_reg_user.c
new file mode 100644 (file)
index 0000000..da7d072
--- /dev/null
@@ -0,0 +1,23 @@
+/*
+ * PROJECT:         ReactOS kernel-mode tests - Filter Manager
+ * LICENSE:         GPLv2+ - See COPYING in the top level directory
+ * PURPOSE:         Tests for checking filter registration
+ * PROGRAMMER:      Ged Murphy <gedmurphy@reactos.org>
+ */
+
+#include <kmt_test.h>
+
+
+START_TEST(FltMgrReg)
+{
+    static WCHAR FilterName[] = L"FltMgrReg";
+    SC_HANDLE hService;
+    HANDLE hPort;
+
+    ok(KmtFltCreateService(FilterName, L"FltMgrLoad test driver", &hService) == ERROR_SUCCESS, "Failed to create the reg entry\n");
+    ok(KmtFltAddAltitude(L"123456") == ERROR_SUCCESS, "\n");
+    ok(KmtFltLoadDriver(TRUE, FALSE, FALSE, &hPort) == ERROR_SUCCESS, "Failed to load the driver\n");
+    //__debugbreak();
+    ok(KmtFltUnloadDriver(hPort, FALSE) == ERROR_SUCCESS, "Failed to unload the driver\n");
+    ok(KmtFltDeleteService(NULL, &hService) == ERROR_SUCCESS, "Failed to delete the driver\n");
+}
diff --git a/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_register.c b/modules/rostests/kmtests/fltmgr/fltmgr_register/fltmgr_register.c
new file mode 100644 (file)
index 0000000..9703dfa
--- /dev/null
@@ -0,0 +1,248 @@
+/*
+ * PROJECT:         ReactOS kernel-mode tests - Filter Manager
+ * LICENSE:         GPLv2+ - See COPYING in the top level directory
+ * PURPOSE:         Tests for checking filter registration
+ * PROGRAMMER:      Ged Murphy <gedmurphy@reactos.org>
+ */
+
+// This tests needs to be run via a standalone driver because FltRegisterFilter
+// uses the DriverObject in its internal structures, and we don't want it to be
+// linked to a device object from the test suite itself.
+
+#include <kmt_test.h>
+#include <fltkernel.h>
+#include <fltmgrint.h>
+
+//#define NDEBUG
+#include <debug.h>
+
+#define RESET_REGISTRATION(basic) \
+    do { \
+        RtlZeroMemory(&FilterRegistration, sizeof(FLT_REGISTRATION)); \
+        if (basic) { \
+            FilterRegistration.Size = sizeof(FLT_REGISTRATION); \
+            FilterRegistration.Version = FLT_REGISTRATION_VERSION; \
+        } \
+    } while (0)
+
+#define RESET_UNLOAD(DO) DO->DriverUnload = NULL;
+
+
+NTSTATUS
+FLTAPI
+TestRegFilterUnload(
+    _In_ FLT_FILTER_UNLOAD_FLAGS Flags
+);
+
+/* Globals */
+static PDRIVER_OBJECT TestDriverObject;
+static FLT_REGISTRATION FilterRegistration;
+static PFLT_FILTER TestFilter = NULL;
+
+
+
+
+BOOLEAN
+TestFltRegisterFilter(_In_ PDRIVER_OBJECT DriverObject)
+{
+    UNICODE_STRING Altitude;
+    UNICODE_STRING Name;
+    PFLT_FILTER Filter = NULL;
+    PFLT_FILTER Temp = NULL;
+    NTSTATUS Status;
+
+    RESET_REGISTRATION(FALSE);
+#if 0
+    KmtStartSeh()
+        Status = FltRegisterFilter(NULL, &FilterRegistration, &Filter);
+    KmtEndSeh(STATUS_INVALID_PARAMETER);
+
+    KmtStartSeh()
+        Status = FltRegisterFilter(DriverObject, NULL, &Filter);
+    KmtEndSeh(STATUS_INVALID_PARAMETER);
+
+    KmtStartSeh()
+        Status = FltRegisterFilter(DriverObject, &FilterRegistration, NULL);
+    KmtEndSeh(STATUS_INVALID_PARAMETER)
+#endif
+
+    RESET_REGISTRATION(TRUE);
+    FilterRegistration.Version = 0x0100;
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter);
+    ok_eq_hex(Status, STATUS_INVALID_PARAMETER);
+
+    RESET_REGISTRATION(TRUE);
+    FilterRegistration.Version = 0x0300;
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter);
+    ok_eq_hex(Status, STATUS_INVALID_PARAMETER);
+
+    RESET_REGISTRATION(TRUE);
+    FilterRegistration.Version = 0x0200;
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter);
+    ok_eq_hex(Status, STATUS_SUCCESS);
+    FltUnregisterFilter(Filter);
+
+
+    /* Test invalid sizes. MSDN says this is required, but it doesn't appear to be */
+    RESET_REGISTRATION(TRUE);
+    FilterRegistration.Size = 0;
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter);
+    ok_eq_hex(Status, STATUS_SUCCESS);
+    FltUnregisterFilter(Filter);
+
+    RESET_REGISTRATION(TRUE);
+    FilterRegistration.Size = 0xFFFF;
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter);
+    ok_eq_hex(Status, STATUS_SUCCESS);
+    FltUnregisterFilter(Filter);
+
+
+    /* Now make a valid registration */
+    RESET_REGISTRATION(TRUE);
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Filter);
+    ok_eq_hex(Status, STATUS_SUCCESS);
+
+    /* Try to register again */
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &Temp);
+    ok_eq_hex(Status, STATUS_FLT_INSTANCE_ALTITUDE_COLLISION);
+
+
+    ok_eq_hex(Filter->Base.Flags, FLT_OBFL_TYPE_FILTER);
+    
+    /* Check we have the right filter name */
+    RtlInitUnicodeString(&Name, L"Kmtest-FltMgrReg");
+    ok_eq_long(RtlCompareUnicodeString(&Filter->Name, &Name, FALSE), 0);
+
+    /* And the altitude is corect */
+    RtlInitUnicodeString(&Altitude, L"123456");
+    ok_eq_long(RtlCompareUnicodeString(&Filter->DefaultAltitude, &Altitude, FALSE), 0);
+
+    //
+    // FIXME: More checks
+    //
+
+    /* Cleanup the valid registration */
+    FltUnregisterFilter(Filter);
+
+    /*
+     * The last thing we'll do before we exit is to properly register with the filter manager 
+     * and set an unload routine. This'll let us test the FltUnregisterFilter routine
+     */
+    RESET_REGISTRATION(TRUE);
+
+    /* Set a fake unload routine we'll use to test */
+    DriverObject->DriverUnload = (PDRIVER_UNLOAD)0x1234FFFF;
+
+    FilterRegistration.FilterUnloadCallback = TestRegFilterUnload;
+    Status = FltRegisterFilter(DriverObject, &FilterRegistration, &TestFilter);
+    ok_eq_hex(Status, STATUS_SUCCESS);
+    
+    /* Test all the unlod routines */
+    ok_eq_pointer(TestFilter->FilterUnload, TestRegFilterUnload);
+    ok_eq_pointer(TestFilter->OldDriverUnload, (PFLT_FILTER_UNLOAD_CALLBACK)0x1234FFFF);
+
+    // This should equal the fltmgr's private unload routine, but there's no easy way of testing it...
+    //ok_eq_pointer(DriverObject->DriverUnload, FltpMiniFilterDriverUnload);
+
+    /* Make sure our test address is never actually called */
+    TestFilter->OldDriverUnload = (PFLT_FILTER_UNLOAD_CALLBACK)NULL;
+
+    return TRUE;
+}
+
+
+NTSTATUS
+FLTAPI
+TestRegFilterUnload(
+    _In_ FLT_FILTER_UNLOAD_FLAGS Flags)
+{
+    //__debugbreak();
+
+    ok_irql(PASSIVE_LEVEL);
+    ok(TestFilter != NULL, "Buffer is NULL\n");
+
+    //
+    // FIXME: Add tests
+    //
+
+    FltUnregisterFilter(TestFilter);
+
+    //
+    // FIXME: Add tests
+    //
+
+    return STATUS_SUCCESS;
+}
+
+
+
+
+
+
+
+
+
+
+/*
+ * KMT Callback routines
+ */
+
+NTSTATUS
+TestEntry(
+    IN PDRIVER_OBJECT DriverObject,
+    IN PCUNICODE_STRING RegistryPath,
+    OUT PCWSTR *DeviceName,
+    IN OUT INT *Flags)
+{
+    NTSTATUS Status = STATUS_SUCCESS;
+
+    PAGED_CODE();
+
+    UNREFERENCED_PARAMETER(RegistryPath);
+
+    DPRINT("FltMgrReg Entry!\n");
+    trace("Entered FltMgrReg tests\n");
+
+    /* We'll do the work ourselves in this test */
+    *Flags = TESTENTRY_NO_ALL;
+
+    ok_irql(PASSIVE_LEVEL);
+    TestDriverObject = DriverObject;
+
+
+    /* Run the tests */
+    (VOID)TestFltRegisterFilter(DriverObject);
+
+    return Status;
+}
+
+VOID
+TestFilterUnload(
+    IN ULONG Flags)
+{
+    PAGED_CODE();
+    ok_irql(PASSIVE_LEVEL);
+}
+
+NTSTATUS
+TestInstanceSetup(
+    _In_ PCFLT_RELATED_OBJECTS FltObjects,
+    _In_ FLT_INSTANCE_SETUP_FLAGS Flags,
+    _In_ DEVICE_TYPE VolumeDeviceType,
+    _In_ FLT_FILESYSTEM_TYPE VolumeFilesystemType,
+    _In_ PUNICODE_STRING VolumeName,
+    _In_ ULONG SectorSize,
+    _In_ ULONG ReportedSectorSize
+)
+{
+    return STATUS_FLT_DO_NOT_ATTACH;
+}
+
+VOID
+TestQueryTeardown(
+    _In_ PCFLT_RELATED_OBJECTS FltObjects,
+    _In_ FLT_INSTANCE_QUERY_TEARDOWN_FLAGS Flags)
+{
+    UNREFERENCED_PARAMETER(FltObjects);
+    UNREFERENCED_PARAMETER(Flags);
+}
index 636f099..3e29d2c 100644 (file)
@@ -44,6 +44,7 @@
 #include <strsafe.h>
 #include <fltuser.h>
 
 #include <strsafe.h>
 #include <fltuser.h>
 
+
 #ifdef KMT_EMULATE_KERNEL
 #define ok_irql(i)
 #define KIRQL int
 #ifdef KMT_EMULATE_KERNEL
 #define ok_irql(i)
 #define KIRQL int
index 8fcc93a..1c5ddee 100644 (file)
@@ -126,9 +126,12 @@ NTSTATUS KmtFilterRegisterCallbacks(_In_ CONST FLT_OPERATION_REGISTRATION *Opera
 
 typedef enum
 {
 
 typedef enum
 {
-    TESTENTRY_NO_REGISTER_FILTER = 1,
-    TESTENTRY_NO_CREATE_COMMS_PORT = 2,
-    TESTENTRY_NO_START_FILTERING = 4,
+    TESTENTRY_NO_REGISTER_FILTER    = 0x01,
+    TESTENTRY_NO_CREATE_COMMS_PORT  = 0x02,
+    TESTENTRY_NO_START_FILTERING    = 0x04,
+    TESTENTRY_NO_INSTANCE_SETUP     = 0x08,
+    TESTENTRY_NO_QUERY_TEARDOWN     = 0x10,
+    TESTENTRY_NO_ALL                = 0xFF
 } KMT_MINIFILTER_FLAGS;
 
 VOID TestFilterUnload(_In_ ULONG Flags);
 } KMT_MINIFILTER_FLAGS;
 
 VOID TestFilterUnload(_In_ ULONG Flags);
@@ -174,8 +177,14 @@ DWORD KmtSendWStringToDriver(IN DWORD ControlCode, IN PCWSTR String);
 DWORD KmtSendUlongToDriver(IN DWORD ControlCode, IN DWORD Value);
 DWORD KmtSendBufferToDriver(IN DWORD ControlCode, IN OUT PVOID Buffer OPTIONAL, IN DWORD InLength, IN OUT PDWORD OutLength);
 
 DWORD KmtSendUlongToDriver(IN DWORD ControlCode, IN DWORD Value);
 DWORD KmtSendBufferToDriver(IN DWORD ControlCode, IN OUT PVOID Buffer OPTIONAL, IN DWORD InLength, IN OUT PDWORD OutLength);
 
-DWORD KmtFltLoadDriver(_In_z_ PCWSTR ServiceName, _In_ BOOLEAN RestartIfRunning, _In_ BOOLEAN ConnectComms, _Out_ HANDLE *hPort);
+
+DWORD KmtFltCreateService(_In_z_ PCWSTR ServiceName, _In_z_ PCWSTR DisplayName, _Out_ SC_HANDLE *ServiceHandle);
+DWORD KmtFltDeleteService(_In_opt_z_ PCWSTR ServiceName, _Inout_ SC_HANDLE *ServiceHandle);
+DWORD KmtFltAddAltitude(_In_z_ LPWSTR Altitude);
+DWORD KmtFltLoadDriver(_In_ BOOLEAN EnableDriverLoadPrivlege, _In_ BOOLEAN RestartIfRunning, _In_ BOOLEAN ConnectComms, _Out_ HANDLE *hPort);
 DWORD KmtFltUnloadDriver(_In_ HANDLE *hPort, _In_ BOOLEAN DisonnectComms);
 DWORD KmtFltUnloadDriver(_In_ HANDLE *hPort, _In_ BOOLEAN DisonnectComms);
+DWORD KmtFltConnectComms(_Out_ HANDLE *hPort);
+DWORD KmtFltDisconnectComms(_In_ HANDLE hPort);
 DWORD KmtFltRunKernelTest(_In_ HANDLE hPort, _In_z_ PCSTR TestName);
 DWORD KmtFltSendToDriver(_In_ HANDLE hPort, _In_ DWORD Message);
 DWORD KmtFltSendStringToDriver(_In_ HANDLE hPort, _In_ DWORD Message, _In_ PCSTR String);
 DWORD KmtFltRunKernelTest(_In_ HANDLE hPort, _In_z_ PCSTR TestName);
 DWORD KmtFltSendToDriver(_In_ HANDLE hPort, _In_ DWORD Message);
 DWORD KmtFltSendStringToDriver(_In_ HANDLE hPort, _In_ DWORD Message, _In_ PCSTR String);
index cb59574..042d174 100644 (file)
 
 #define SERVICE_ACCESS (SERVICE_START | SERVICE_STOP | DELETE)
 
 
 #define SERVICE_ACCESS (SERVICE_START | SERVICE_STOP | DELETE)
 
-/*
- * We need to call the internal function in the service.c file
- */
-DWORD
-KmtpCreateService(
-    IN PCWSTR ServiceName,
-    IN PCWSTR ServicePath,
-    IN PCWSTR DisplayName OPTIONAL,
-    IN DWORD ServiceType,
-    OUT SC_HANDLE *ServiceHandle);
-
-static SC_HANDLE ScmHandle;
 
 
 
 
-/**
- * @name KmtFltCreateService
- *
- * Create the specified driver service and return a handle to it
- *
- * @param ServiceName
- *        Name of the service to create
- * @param ServicePath
- *        File name of the driver, relative to the current directory
- * @param DisplayName
- *        Service display name
- * @param ServiceHandle
- *        Pointer to a variable to receive the handle to the service
- *
- * @return Win32 error code
- */
-DWORD
-KmtFltCreateService(
-    _In_z_ PCWSTR ServiceName,
-    _In_z_ PCWSTR ServicePath,
-    _In_z_ PCWSTR DisplayName OPTIONAL,
-    _Out_ SC_HANDLE *ServiceHandle)
-{
-    return KmtpCreateService(ServiceName,
-                             ServicePath,
-                             DisplayName,
-                             SERVICE_FILE_SYSTEM_DRIVER,
-                             ServiceHandle);
-}
-
 /**
  * @name KmtFltLoad
  *
 /**
  * @name KmtFltLoad
  *
@@ -82,7 +40,7 @@ KmtFltLoad(
 
     return Error;
 }
 
     return Error;
 }
-
+#if 0
 /**
  * @name KmtFltCreateAndStartService
  *
 /**
  * @name KmtFltCreateAndStartService
  *
@@ -143,7 +101,7 @@ cleanup:
     assert(Error);
     return Error;
 }
     assert(Error);
     return Error;
 }
-
+#endif
 
 /**
  * @name KmtFltConnect
 
 /**
  * @name KmtFltConnect
@@ -163,7 +121,7 @@ KmtFltConnect(
     _Out_ HANDLE *hPort)
 {
     HRESULT hResult;
     _Out_ HANDLE *hPort)
 {
     HRESULT hResult;
-    DWORD Error = ERROR_SUCCESS;
+    DWORD Error;
 
     assert(ServiceName);
     assert(hPort);
 
     assert(ServiceName);
     assert(hPort);
@@ -191,7 +149,7 @@ KmtFltConnect(
  */
 DWORD
 KmtFltDisconnect(
  */
 DWORD
 KmtFltDisconnect(
-    _Out_ HANDLE *hPort)
+    _In_ HANDLE hPort)
 {
     DWORD Error = ERROR_SUCCESS;
 
 {
     DWORD Error = ERROR_SUCCESS;
 
@@ -388,41 +346,3 @@ KmtFltUnload(
 
     return Error;
 }
 
     return Error;
 }
-
-/**
- * @name KmtFltDeleteService
- *
- * Delete the specified filter driver
- *
- * @param ServiceName
- *        If *ServiceHandle is NULL, name of the service to delete
- * @param ServiceHandle
- *        Pointer to a variable containing the service handle.
- *        Will be set to NULL on success
- *
- * @return Win32 error code
- */
-DWORD
-KmtFltDeleteService(
-    _In_z_ PCWSTR ServiceName OPTIONAL,
-    _Inout_ SC_HANDLE *ServiceHandle)
-{
-    return KmtDeleteService(ServiceName, ServiceHandle);
-}
-
-/**
- * @name KmtFltCloseService
- *
- * Close the specified driver service handle
- *
- * @param ServiceHandle
- *        Pointer to a variable containing the service handle.
- *        Will be set to NULL on success
- *
- * @return Win32 error code
- */
-DWORD KmtFltCloseService(
-    _Inout_ SC_HANDLE *ServiceHandle)
-{
-    return KmtCloseService(ServiceHandle);
-}
index 772e98a..4127cca 100644 (file)
@@ -7,12 +7,28 @@
 
 #include <kmt_test.h>
 
 
 #include <kmt_test.h>
 
+#define KMT_FLT_USER_MODE
 #include "kmtest.h"
 #include <kmt_public.h>
 
 #include <assert.h>
 #include <debug.h>
 
 #include "kmtest.h"
 #include <kmt_public.h>
 
 #include <assert.h>
 #include <debug.h>
 
+/*
+ * We need to call the internal function in the service.c file
+ */
+DWORD
+KmtpCreateService(
+    IN PCWSTR ServiceName,
+    IN PCWSTR ServicePath,
+    IN PCWSTR DisplayName OPTIONAL,
+    IN DWORD ServiceType,
+    OUT SC_HANDLE *ServiceHandle);
+
+DWORD EnablePrivilegeInCurrentProcess(
+    _In_z_ LPWSTR lpPrivName,
+    _In_ BOOL bEnable);
+
 // move to a shared location
 typedef struct _KMTFLT_MESSAGE_HEADER
 {
 // move to a shared location
 typedef struct _KMTFLT_MESSAGE_HEADER
 {
@@ -26,33 +42,30 @@ extern HANDLE KmtestHandle;
 static WCHAR TestServiceName[MAX_PATH];
 
 
 static WCHAR TestServiceName[MAX_PATH];
 
 
+
 /**
 /**
- * @name KmtFltLoadDriver
+ * @name KmtFltCreateService
  *
  *
- * Load the specified filter driver
- * This routine will create the service entry if it doesn't already exist
+ * Create the specified driver service and return a handle to it
  *
  * @param ServiceName
  *
  * @param ServiceName
- *        Name of the driver service (Kmtest- prefix will be added automatically)
- * @param RestartIfRunning
- *        TRUE to stop and restart the service if it is already running
- * @param ConnectComms
- *        TRUE to create a comms connection to the specified filter
- * @param hPort
- *        Handle to the filter's comms port
+ *        Name of the service to create
+ * @param ServicePath
+ *        File name of the driver, relative to the current directory
+ * @param DisplayName
+ *        Service display name
+ * @param ServiceHandle
+ *        Pointer to a variable to receive the handle to the service
  *
  * @return Win32 error code
  */
 DWORD
  *
  * @return Win32 error code
  */
 DWORD
-KmtFltLoadDriver(
+KmtFltCreateService(
     _In_z_ PCWSTR ServiceName,
     _In_z_ PCWSTR ServiceName,
-    _In_ BOOLEAN RestartIfRunning,
-    _In_ BOOLEAN ConnectComms,
-    _Out_ HANDLE *hPort)
+    _In_z_ PCWSTR DisplayName,
+    _Out_ SC_HANDLE *ServiceHandle)
 {
 {
-    DWORD Error = ERROR_SUCCESS;
     WCHAR ServicePath[MAX_PATH];
     WCHAR ServicePath[MAX_PATH];
-    SC_HANDLE TestServiceHandle;
 
     StringCbCopy(ServicePath, sizeof ServicePath, ServiceName);
     StringCbCat(ServicePath, sizeof ServicePath, L"_drv.sys");
 
     StringCbCopy(ServicePath, sizeof ServicePath, ServiceName);
     StringCbCat(ServicePath, sizeof ServicePath, L"_drv.sys");
@@ -60,11 +73,81 @@ KmtFltLoadDriver(
     StringCbCopy(TestServiceName, sizeof TestServiceName, L"Kmtest-");
     StringCbCat(TestServiceName, sizeof TestServiceName, ServiceName);
 
     StringCbCopy(TestServiceName, sizeof TestServiceName, L"Kmtest-");
     StringCbCat(TestServiceName, sizeof TestServiceName, ServiceName);
 
-    Error = KmtFltCreateAndStartService(TestServiceName, ServicePath, NULL, &TestServiceHandle, TRUE);
+    return KmtpCreateService(TestServiceName,
+                             ServicePath,
+                             DisplayName,
+                             SERVICE_FILE_SYSTEM_DRIVER,
+                             ServiceHandle);
+}
+
+/**
+ * @name KmtFltDeleteService
+ *
+ * Delete the specified filter driver
+ *
+ * @param ServiceName
+ *        If *ServiceHandle is NULL, name of the service to delete
+ * @param ServiceHandle
+ *        Pointer to a variable containing the service handle.
+ *        Will be set to NULL on success
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltDeleteService(
+    _In_opt_z_ PCWSTR ServiceName,
+    _Inout_ SC_HANDLE *ServiceHandle)
+{
+    return KmtDeleteService(ServiceName, ServiceHandle);
+}
 
 
-    if (Error == ERROR_SUCCESS && ConnectComms)
+/**
+ * @name KmtFltLoadDriver
+ *
+ * Delete the specified filter driver
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltLoadDriver(
+    _In_ BOOLEAN EnableDriverLoadPrivlege,
+    _In_ BOOLEAN RestartIfRunning,
+    _In_ BOOLEAN ConnectComms,
+    _Out_ HANDLE *hPort
+)
+{
+    DWORD Error;
+
+    if (EnableDriverLoadPrivlege)
+    {
+        Error = EnablePrivilegeInCurrentProcess(SE_LOAD_DRIVER_NAME , TRUE);
+        if (Error)
+        {
+            return Error;
+        }
+    }
+
+    Error = KmtFltLoad(TestServiceName);
+    if ((Error == ERROR_SERVICE_ALREADY_RUNNING) && RestartIfRunning)
+    {
+        Error = KmtFltUnload(TestServiceName);
+        if (Error)
+        {
+            // TODO
+            __debugbreak();
+        }
+
+        Error = KmtFltLoad(TestServiceName);
+    }
+
+    if (Error)
     {
     {
-        Error = KmtFltConnect(ServiceName, hPort);
+        return Error;
+    }
+
+    if (ConnectComms)
+    {
+        Error = KmtFltConnectComms(hPort);
     }
 
     return Error;
     }
 
     return Error;
@@ -77,7 +160,7 @@ KmtFltLoadDriver(
  *
  * @param hPort
  *        Handle to the filter's comms port
  *
  * @param hPort
  *        Handle to the filter's comms port
- * @param ConnectComms
+ * @param DisonnectComms
  *        TRUE to disconnect the comms connection before unloading
  *
  * @return Win32 error code
  *        TRUE to disconnect the comms connection before unloading
  *
  * @return Win32 error code
@@ -110,6 +193,57 @@ KmtFltUnloadDriver(
     return Error;
 }
 
     return Error;
 }
 
+/**
+ * @name KmtFltConnectComms
+ *
+ * Create a comms connection to the specified filter
+ *
+ * @param hPort
+ *        Handle to the filter's comms port
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltConnectComms(
+    _Out_ HANDLE *hPort)
+{
+    return KmtFltConnect(TestServiceName, hPort);
+}
+
+/**
+ * @name KmtFltDisconnectComms
+ *
+ * Disconenct from the comms port
+ *
+ * @param hPort
+ *        Handle to the filter's comms port
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltDisconnectComms(
+    _In_ HANDLE hPort)
+{
+    return KmtFltDisconnect(hPort);
+}
+
+
+/**
+* @name KmtFltCloseService
+*
+* Close the specified driver service handle
+*
+* @param ServiceHandle
+*        Pointer to a variable containing the service handle.
+*        Will be set to NULL on success
+*
+* @return Win32 error code
+*/
+DWORD KmtFltCloseService(
+    _Inout_ SC_HANDLE *ServiceHandle)
+{
+    return KmtCloseService(ServiceHandle);
+}
 
 /**
 * @name KmtFltRunKernelTest
 
 /**
 * @name KmtFltRunKernelTest
@@ -141,7 +275,7 @@ KmtFltRunKernelTest(
 * @param Message
 *        The message to send to the filter
 *
 * @param Message
 *        The message to send to the filter
 *
-* @return Win32 error code as returned by DeviceIoControl
+* @return Win32 error code
 */
 DWORD
 KmtFltSendToDriver(
 */
 DWORD
 KmtFltSendToDriver(
@@ -165,7 +299,7 @@ KmtFltSendToDriver(
  * @param String
  *        An ANSI string to send to the filter
  *
  * @param String
  *        An ANSI string to send to the filter
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendStringToDriver(
  */
 DWORD
 KmtFltSendStringToDriver(
@@ -190,7 +324,7 @@ KmtFltSendStringToDriver(
  * @param String
  *        An wide string to send to the filter
  *
  * @param String
  *        An wide string to send to the filter
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendWStringToDriver(
  */
 DWORD
 KmtFltSendWStringToDriver(
@@ -213,7 +347,7 @@ KmtFltSendWStringToDriver(
  * @param Value
  *        An 32bit valueng to send to the filter
  *
  * @param Value
  *        An 32bit valueng to send to the filter
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendUlongToDriver(
  */
 DWORD
 KmtFltSendUlongToDriver(
@@ -244,7 +378,7 @@ KmtFltSendUlongToDriver(
  * @param BytesReturned
  *        Number of bytes written in the reply buffer
  *
  * @param BytesReturned
  *        Number of bytes written in the reply buffer
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendBufferToDriver(
  */
 DWORD
 KmtFltSendBufferToDriver(
@@ -299,3 +433,154 @@ KmtFltSendBufferToDriver(
 
     return Error;
 }
 
     return Error;
 }
+
+/**
+* @name KmtFltAddAltitude
+*
+* Sets up the mini-filter altitude data in the registry
+*
+* @param hPort
+*        The altitude string to set
+*
+* @return Win32 error code
+*/
+DWORD
+KmtFltAddAltitude(
+    _In_z_ LPWSTR Altitude)
+{
+    WCHAR DefaultInstance[128];
+    WCHAR KeyPath[256];
+    HKEY hKey = NULL;
+    HKEY hSubKey = NULL;
+    DWORD Zero = 0;
+    LONG Error;
+
+    StringCbCopy(KeyPath, sizeof KeyPath, L"SYSTEM\\CurrentControlSet\\Services\\");
+    StringCbCat(KeyPath, sizeof KeyPath, TestServiceName);
+    StringCbCat(KeyPath, sizeof KeyPath, L"\\Instances\\");
+
+    Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE,
+                           KeyPath,
+                           0,
+                           NULL,
+                           REG_OPTION_NON_VOLATILE,
+                           KEY_CREATE_SUB_KEY | KEY_SET_VALUE,
+                           NULL,
+                           &hKey,
+                           NULL);
+    if (Error != ERROR_SUCCESS)
+    {
+        return Error;
+    }
+
+    StringCbCopy(DefaultInstance, sizeof DefaultInstance, TestServiceName);
+    StringCbCat(DefaultInstance, sizeof DefaultInstance, L" Instance");
+
+    Error = RegSetValueExW(hKey,
+                           L"DefaultInstance",
+                           0,
+                           REG_SZ,
+                           (LPBYTE)DefaultInstance,
+                           (wcslen(DefaultInstance) + 1) * sizeof(WCHAR));
+    if (Error != ERROR_SUCCESS)
+    {
+        goto Quit;
+    }
+
+    Error = RegCreateKeyW(hKey, DefaultInstance, &hSubKey);
+    if (Error != ERROR_SUCCESS)
+    {
+        goto Quit;
+    }
+
+    Error = RegSetValueExW(hSubKey,
+                           L"Altitude",
+                           0,
+                           REG_SZ,
+                           (LPBYTE)Altitude,
+                           (wcslen(Altitude) + 1) * sizeof(WCHAR));
+    if (Error != ERROR_SUCCESS)
+    {
+        goto Quit;
+    }
+
+    Error = RegSetValueExW(hSubKey,
+                           L"Flags",
+                           0,
+                           REG_DWORD,
+                           (LPBYTE)&Zero,
+                           sizeof(DWORD));
+
+Quit:
+    if (hSubKey)
+    {
+        RegCloseKey(hSubKey);
+    }
+    if (hKey)
+    {
+        RegCloseKey(hKey);
+    }
+
+    return Error;
+
+}
+
+/*
+* Private functions, not meant for use in kmtests
+*/
+
+DWORD EnablePrivilege(
+    _In_ HANDLE hToken,
+    _In_z_ LPWSTR lpPrivName,
+    _In_ BOOL bEnable)
+{
+    TOKEN_PRIVILEGES TokenPrivileges;
+    LUID luid;
+    BOOL bSuccess;
+    DWORD dwError = ERROR_SUCCESS;
+
+    /* Get the luid for this privilege */
+    if (!LookupPrivilegeValueW(NULL, lpPrivName, &luid))
+        return GetLastError();
+
+    /* Setup the struct with the priv info */
+    TokenPrivileges.PrivilegeCount = 1;
+    TokenPrivileges.Privileges[0].Luid = luid;
+    TokenPrivileges.Privileges[0].Attributes = bEnable ? SE_PRIVILEGE_ENABLED : 0;
+
+    /* Enable the privilege info in the token */
+    bSuccess = AdjustTokenPrivileges(hToken,
+                                     FALSE,
+                                     &TokenPrivileges,
+                                     sizeof(TOKEN_PRIVILEGES),
+                                     NULL,
+                                     NULL);
+    if (bSuccess == FALSE) dwError = GetLastError();
+
+    /* return status */
+    return dwError;
+}
+
+DWORD EnablePrivilegeInCurrentProcess(
+    _In_z_ LPWSTR lpPrivName,
+    _In_ BOOL bEnable)
+{
+    HANDLE hToken;
+    BOOL bSuccess;
+    DWORD dwError = ERROR_SUCCESS;
+
+    /* Get a handle to our token */
+    bSuccess = OpenProcessToken(GetCurrentProcess(),
+                                TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
+                                &hToken);
+    if (bSuccess == FALSE) return GetLastError();
+
+    /* Enable the privilege in the agent token */
+    dwError = EnablePrivilege(hToken, lpPrivName, bEnable);
+
+    /* We're done with this now */
+    CloseHandle(hToken);
+
+    /* return status */
+    return dwError;
+}
index 7329d18..f39c418 100644 (file)
@@ -60,16 +60,9 @@ KmtDeleteService(
 
 DWORD KmtCloseService(
     IN OUT SC_HANDLE *ServiceHandle);
 
 DWORD KmtCloseService(
     IN OUT SC_HANDLE *ServiceHandle);
-       
 
 
-/* FS Filter management functions */   
 
 
-DWORD
-KmtFltCreateService(
-    _In_z_ PCWSTR ServiceName,
-    _In_z_ PCWSTR ServicePath,
-    _In_z_ PCWSTR DisplayName OPTIONAL,
-    _Out_ SC_HANDLE *ServiceHandle);
+#ifdef KMT_FLT_USER_MODE
 
 DWORD
 KmtFltLoad(
 
 DWORD
 KmtFltLoad(
@@ -132,4 +125,6 @@ KmtFltDeleteService(
 DWORD KmtFltCloseService(
     _Inout_ SC_HANDLE *ServiceHandle);
 
 DWORD KmtFltCloseService(
     _Inout_ SC_HANDLE *ServiceHandle);
 
+#endif /* KMT_FILTER_DRIVER */
+
 #endif /* !defined _KMTESTS_H_ */
 #endif /* !defined _KMTESTS_H_ */
index 562691d..01cb31c 100644 (file)
@@ -11,6 +11,8 @@ KMT_TESTFUNC Test_CcCopyRead;
 KMT_TESTFUNC Test_Example;
 KMT_TESTFUNC Test_FileAttributes;
 KMT_TESTFUNC Test_FindFile;
 KMT_TESTFUNC Test_Example;
 KMT_TESTFUNC Test_FileAttributes;
 KMT_TESTFUNC Test_FindFile;
+KMT_TESTFUNC Test_FltMgrLoad;
+KMT_TESTFUNC Test_FltMgrReg;
 KMT_TESTFUNC Test_HidPDescription;
 KMT_TESTFUNC Test_IoCreateFile;
 KMT_TESTFUNC Test_IoDeviceObject;
 KMT_TESTFUNC Test_HidPDescription;
 KMT_TESTFUNC Test_IoCreateFile;
 KMT_TESTFUNC Test_IoDeviceObject;
@@ -37,6 +39,8 @@ const KMT_TEST TestList[] =
     { "-Example",                     Test_Example },
     { "FileAttributes",               Test_FileAttributes },
     { "FindFile",                     Test_FindFile },
     { "-Example",                     Test_Example },
     { "FileAttributes",               Test_FileAttributes },
     { "FindFile",                     Test_FindFile },
+    { "FltMgrLoad",                   Test_FltMgrLoad },
+    { "FltMgrReg",                    Test_FltMgrReg },
     { "HidPDescription",              Test_HidPDescription },
     { "IoCreateFile",                 Test_IoCreateFile },
     { "IoDeviceObject",               Test_IoDeviceObject },
     { "HidPDescription",              Test_HidPDescription },
     { "IoCreateFile",                 Test_IoCreateFile },
     { "IoDeviceObject",               Test_IoDeviceObject },
index 2de7f3b..3bb686d 100644 (file)
@@ -2,7 +2,8 @@
  * PROJECT:         ReactOS kernel-mode tests - Filter Manager
  * LICENSE:         GPLv2+ - See COPYING in the top level directory
  * PURPOSE:         FS Mini-filter wrapper to host the filter manager tests
  * PROJECT:         ReactOS kernel-mode tests - Filter Manager
  * LICENSE:         GPLv2+ - See COPYING in the top level directory
  * PURPOSE:         FS Mini-filter wrapper to host the filter manager tests
- * PROGRAMMER:      Ged Murphy <ged.murphy@reactos.org>
+ * PROGRAMMER:      Thomas Faber <thomas.faber@reactos.org>
+ *                  Ged Murphy <ged.murphy@reactos.org>
  */
 
 #include <ntifs.h>
  */
 
 #include <ntifs.h>
@@ -35,6 +36,7 @@ DRIVER_INITIALIZE DriverEntry;
 
 /* Globals */
 static PDRIVER_OBJECT TestDriverObject;
 
 /* Globals */
 static PDRIVER_OBJECT TestDriverObject;
+static PDEVICE_OBJECT KmtestDeviceObject;
 static FILTER_DATA FilterData;
 static PFLT_OPERATION_REGISTRATION Callbacks = NULL;
 static PFLT_CONTEXT_REGISTRATION Contexts = NULL;
 static FILTER_DATA FilterData;
 static PFLT_OPERATION_REGISTRATION Callbacks = NULL;
 static PFLT_CONTEXT_REGISTRATION Contexts = NULL;
@@ -133,6 +135,9 @@ DriverEntry(
     PSECURITY_DESCRIPTOR SecurityDescriptor;
     UNICODE_STRING DeviceName;
     WCHAR DeviceNameBuffer[128] = L"\\Device\\Kmtest-";
     PSECURITY_DESCRIPTOR SecurityDescriptor;
     UNICODE_STRING DeviceName;
     WCHAR DeviceNameBuffer[128] = L"\\Device\\Kmtest-";
+    UNICODE_STRING KmtestDeviceName;
+    PFILE_OBJECT KmtestFileObject;
+    PKMT_DEVICE_EXTENSION KmtestDeviceExtension;
     PCWSTR DeviceNameSuffix;
     INT Flags = 0;
     PKPRCB Prcb;
     PCWSTR DeviceNameSuffix;
     INT Flags = 0;
     PKPRCB Prcb;
@@ -148,6 +153,32 @@ DriverEntry(
     KmtIsMultiProcessorBuild = (Prcb->BuildType & PRCB_BUILD_UNIPROCESSOR) == 0;
     TestDriverObject = DriverObject;
 
     KmtIsMultiProcessorBuild = (Prcb->BuildType & PRCB_BUILD_UNIPROCESSOR) == 0;
     TestDriverObject = DriverObject;
 
+    /* get the Kmtest device, so that we get a ResultBuffer pointer */
+    RtlInitUnicodeString(&KmtestDeviceName, KMTEST_DEVICE_DRIVER_PATH);
+    Status = IoGetDeviceObjectPointer(&KmtestDeviceName, FILE_ALL_ACCESS, &KmtestFileObject, &KmtestDeviceObject);
+
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to get Kmtest device object pointer\n");
+        goto cleanup;
+    }
+
+    Status = ObReferenceObjectByPointer(KmtestDeviceObject, FILE_ALL_ACCESS, NULL, KernelMode);
+
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to reference Kmtest device object\n");
+        goto cleanup;
+    }
+
+    ObDereferenceObject(KmtestFileObject);
+    KmtestFileObject = NULL;
+    KmtestDeviceExtension = KmtestDeviceObject->DeviceExtension;
+    ResultBuffer = KmtestDeviceExtension->ResultBuffer;
+    DPRINT("KmtestDeviceObject: %p\n", (PVOID)KmtestDeviceObject);
+    DPRINT("KmtestDeviceExtension: %p\n", (PVOID)KmtestDeviceExtension);
+    DPRINT("Setting ResultBuffer: %p\n", (PVOID)ResultBuffer);
+
 
     /* call TestEntry */
     RtlInitUnicodeString(&DeviceName, DeviceNameBuffer);
 
     /* call TestEntry */
     RtlInitUnicodeString(&DeviceName, DeviceNameBuffer);
@@ -243,6 +274,7 @@ FilterUnload(
 {
     PAGED_CODE();
     UNREFERENCED_PARAMETER(Flags);
 {
     PAGED_CODE();
     UNREFERENCED_PARAMETER(Flags);
+    //__debugbreak();
 
     DPRINT("DriverUnload\n");
 
 
     DPRINT("DriverUnload\n");
 
@@ -312,50 +344,54 @@ FilterInstanceSetup(
     UNREFERENCED_PARAMETER(FltObjects);
     UNREFERENCED_PARAMETER(Flags);
 
     UNREFERENCED_PARAMETER(FltObjects);
     UNREFERENCED_PARAMETER(Flags);
 
-    RtlInitUnicodeString(&VolumeName, NULL);
+    if (!(Flags & TESTENTRY_NO_INSTANCE_SETUP))
+    {
+        RtlInitUnicodeString(&VolumeName, NULL);
+
 #if 0 // FltGetVolumeProperties is not yet implemented
     /* Get the properties of this volume */
 #if 0 // FltGetVolumeProperties is not yet implemented
     /* Get the properties of this volume */
-    Status = FltGetVolumeProperties(Volume,
-                                    VolumeProperties,
-                                    sizeof(VolPropBuffer),
-                                    &LengthReturned);
-    if (NT_SUCCESS(Status))
-    {
-        FLT_ASSERT((VolumeProperties->SectorSize == 0) || (VolumeProperties->SectorSize >= MIN_SECTOR_SIZE));
-        SectorSize = max(VolumeProperties->SectorSize, MIN_SECTOR_SIZE);
-        ReportedSectorSize = VolumeProperties->SectorSize;
-    }
-    else
-    {
-        DPRINT1("Failed to get the volume properties : 0x%X", Status);
-        return Status;
-    }
+        Status = FltGetVolumeProperties(Volume,
+                                        VolumeProperties,
+                                        sizeof(VolPropBuffer),
+                                        &LengthReturned);
+        if (NT_SUCCESS(Status))
+        {
+            FLT_ASSERT((VolumeProperties->SectorSize == 0) || (VolumeProperties->SectorSize >= MIN_SECTOR_SIZE));
+            SectorSize = max(VolumeProperties->SectorSize, MIN_SECTOR_SIZE);
+            ReportedSectorSize = VolumeProperties->SectorSize;
+        }
+        else
+        {
+            DPRINT1("Failed to get the volume properties : 0x%X", Status);
+            return Status;
+        }
 #endif
 #endif
-    /*  Get the storage device object we want a name for */
-    Status = FltGetDiskDeviceObject(FltObjects->Volume, &DeviceObject);
-    if (NT_SUCCESS(Status))
-    {
-        /* Get the dos device name */
-        Status = IoVolumeDeviceToDosName(DeviceObject, &VolumeName);
+        /*  Get the storage device object we want a name for */
+        Status = FltGetDiskDeviceObject(FltObjects->Volume, &DeviceObject);
         if (NT_SUCCESS(Status))
         {
         if (NT_SUCCESS(Status))
         {
-            DPRINT("VolumeDeviceType %lu, VolumeFilesystemType %lu, Real SectSize=0x%04x, Reported SectSize=0x%04x, Name=\"%wZ\"",
-                   VolumeDeviceType,
-                   VolumeFilesystemType,
-                   SectorSize,
-                   ReportedSectorSize,
-                   &VolumeName);
-
-            Status = TestInstanceSetup(FltObjects,
-                                       Flags,
-                                       VolumeDeviceType,
-                                       VolumeFilesystemType,
-                                       &VolumeName,
-                                       SectorSize,
-                                       ReportedSectorSize);
-
-            /* The buffer was allocated by the IoMgr */
-            ExFreePool(VolumeName.Buffer);
+            /* Get the dos device name */
+            Status = IoVolumeDeviceToDosName(DeviceObject, &VolumeName);
+            if (NT_SUCCESS(Status))
+            {
+                DPRINT("VolumeDeviceType %lu, VolumeFilesystemType %lu, Real SectSize=0x%04x, Reported SectSize=0x%04x, Name=\"%wZ\"",
+                       VolumeDeviceType,
+                       VolumeFilesystemType,
+                       SectorSize,
+                       ReportedSectorSize,
+                       &VolumeName);
+
+                Status = TestInstanceSetup(FltObjects,
+                                           Flags,
+                                           VolumeDeviceType,
+                                           VolumeFilesystemType,
+                                           &VolumeName,
+                                           SectorSize,
+                                           ReportedSectorSize);
+
+                /* The buffer was allocated by the IoMgr */
+                ExFreePool(VolumeName.Buffer);
+            }
         }
     }
 
         }
     }
 
@@ -384,7 +420,10 @@ FilterQueryTeardown(
 {
     PAGED_CODE();
 
 {
     PAGED_CODE();
 
-    TestQueryTeardown(FltObjects, Flags);
+    if (!(Flags & TESTENTRY_NO_QUERY_TEARDOWN))
+    {
+        TestQueryTeardown(FltObjects, Flags);
+    }
 
     /* We always allow a volume to detach */
     return STATUS_SUCCESS;
 
     /* We always allow a volume to detach */
     return STATUS_SUCCESS;