[FLTMGR] Latest from my branch (#135)
[reactos.git] / drivers / filters / fltmgr / Messaging.c
index 97e6a94..6a40e0d 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "fltmgr.h"
 #include "fltmgrint.h"
+#include <fltmgr_shared.h>
 
 #define NDEBUG
 #include <debug.h>
@@ -29,6 +30,32 @@ FltpDisconnectPort(
     _In_ PFLT_PORT_OBJECT PortObject
 );
 
+static
+NTSTATUS
+CreateClientPort(
+    _In_ PFILE_OBJECT FileObject,
+    _Inout_ PIRP Irp
+);
+
+static
+NTSTATUS
+CloseClientPort(
+    _In_ PFILE_OBJECT FileObject,
+    _Inout_ PIRP Irp
+);
+
+static
+NTSTATUS
+InitializeMessageWaiterQueue(
+    _Inout_ PFLT_MESSAGE_WAITER_QUEUE MsgWaiterQueue
+);
+
+static
+PPORT_CCB
+CreatePortCCB(
+    _In_ PFLT_PORT_OBJECT PortObject
+);
+
 
 
 /* EXPORTED FUNCTIONS ******************************************************/
@@ -72,8 +99,8 @@ FltCreateCommunicationPort(_In_ PFLT_FILTER Filter,
         return Status;
     }
 
-    /* Create our new server port object */
-    Status = ObCreateObject(0,
+    /* Create the server port object for this filter */
+    Status = ObCreateObject(KernelMode,
                             ServerPortObjectType,
                             ObjectAttributes,
                             KernelMode,
@@ -191,6 +218,72 @@ FltSendMessage(_In_ PFLT_FILTER Filter,
 
 /* INTERNAL FUNCTIONS ******************************************************/
 
+
+NTSTATUS
+FltpMsgCreate(_In_ PDEVICE_OBJECT DeviceObject,
+              _Inout_ PIRP Irp)
+{
+    PIO_STACK_LOCATION StackPtr;
+    NTSTATUS Status;
+
+    /* Get the stack location */
+    StackPtr = IoGetCurrentIrpStackLocation(Irp);
+
+    FLT_ASSERT(StackPtr->MajorFunction == IRP_MJ_CREATE);
+
+    /* Check if this is a caller wanting to connect */
+    if (StackPtr->MajorFunction == IRP_MJ_CREATE)
+    {
+        /* Create the client port for this connection and exit */
+        Status = CreateClientPort(StackPtr->FileObject, Irp);
+    }
+    else
+    {
+        Status = STATUS_INVALID_PARAMETER;
+    }
+
+    if (Status != STATUS_PENDING)
+    {
+        Irp->IoStatus.Status = Status;
+        Irp->IoStatus.Information = 0;
+        IoCompleteRequest(Irp, 0);
+    }
+
+    return Status;
+}
+
+NTSTATUS
+FltpMsgDispatch(_In_ PDEVICE_OBJECT DeviceObject,
+                _Inout_ PIRP Irp)
+{
+    PIO_STACK_LOCATION StackPtr;
+    NTSTATUS Status;
+
+    /* Get the stack location */
+    StackPtr = IoGetCurrentIrpStackLocation(Irp);
+
+    /* Check if this is a caller wanting to connect */
+    if (StackPtr->MajorFunction == IRP_MJ_CLOSE)
+    {
+        /* Create the client port for this connection and exit */
+        Status = CloseClientPort(StackPtr->FileObject, Irp);
+    }
+    else
+    {
+        // We don't support anything else yet
+        Status = STATUS_NOT_IMPLEMENTED;
+    }
+
+    if (Status != STATUS_PENDING)
+    {
+        Irp->IoStatus.Status = Status;
+        Irp->IoStatus.Information = 0;
+        IoCompleteRequest(Irp, 0);
+    }
+
+    return Status;
+}
+
 VOID
 NTAPI
 FltpServerPortClose(_In_opt_ PEPROCESS Process,
@@ -369,5 +462,370 @@ Quit:
     return Status;
 }
 
+/* CSQ IRP CALLBACKS *******************************************************/
+
+
+NTSTATUS
+NTAPI
+FltpAddMessageWaiter(_In_ PIO_CSQ Csq,
+                     _In_ PIRP Irp,
+                     _In_ PVOID InsertContext)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Insert the IRP at the end of the queue */
+    InsertTailList(&MessageWaiterQueue->WaiterQ.mList,
+                   &Irp->Tail.Overlay.ListEntry);
+
+    /* return success */
+    return STATUS_SUCCESS;
+}
+
+VOID
+NTAPI
+FltpRemoveMessageWaiter(_In_ PIO_CSQ Csq,
+                        _In_ PIRP Irp)
+{
+    /* Remove the IRP from the queue */
+    RemoveEntryList(&Irp->Tail.Overlay.ListEntry);
+}
+
+PIRP
+NTAPI
+FltpGetNextMessageWaiter(_In_ PIO_CSQ Csq,
+                         _In_ PIRP Irp,
+                         _In_ PVOID PeekContext)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+    PIRP NextIrp = NULL;
+    PLIST_ENTRY NextEntry;
+    PIO_STACK_LOCATION IrpStack;
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Is the IRP valid? */
+    if (Irp == NULL)
+    {
+        /* Start peeking from the listhead */
+        NextEntry = MessageWaiterQueue->WaiterQ.mList.Flink;
+    }
+    else
+    {
+        /* Start peeking from that IRP onwards */
+        NextEntry = Irp->Tail.Overlay.ListEntry.Flink;
+    }
+
+    /* Loop through the queue */
+    while (NextEntry != &MessageWaiterQueue->WaiterQ.mList)
+    {
+        /* Store the next IRP in the list */
+        NextIrp = CONTAINING_RECORD(NextEntry, IRP, Tail.Overlay.ListEntry);
+
+        /* Did we supply a PeekContext on insert? */
+        if (!PeekContext)
+        {
+            /* We already have the next IRP */
+            break;
+        }
+        else
+        {
+            /* Get the stack of the next IRP */
+            IrpStack = IoGetCurrentIrpStackLocation(NextIrp);
+
+            /* Does the PeekContext match the object? */
+            if (IrpStack->FileObject == (PFILE_OBJECT)PeekContext)
+            {
+                /* We have a match */
+                break;
+            }
+
+            /* Move to the next IRP */
+            NextIrp = NULL;
+            NextEntry = NextEntry->Flink;
+        }
+    }
+
+    return NextIrp;
+}
+
+_Acquires_lock_(((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, FLT_MESSAGE_WAITER_QUEUE, Csq))->WaiterQ.mLock)
+_IRQL_saves_global_(Irql, ((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock)
+_IRQL_raises_(DISPATCH_LEVEL)
+VOID
+NTAPI
+FltpAcquireMessageWaiterLock(_In_ PIO_CSQ Csq,
+                             _Out_ PKIRQL Irql)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+
+    UNREFERENCED_PARAMETER(Irql);
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Acquire the IRP queue lock */
+    ExAcquireFastMutex(&MessageWaiterQueue->WaiterQ.mLock);
+}
+
+_Releases_lock_(((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock)
+_IRQL_restores_global_(Irql, ((PFLT_MESSAGE_WAITER_QUEUE)CONTAINING_RECORD(Csq, DEVICE_EXTENSION, IrpQueue))->WaiterQ.mLock)
+_IRQL_requires_(DISPATCH_LEVEL)
+VOID
+NTAPI
+FltpReleaseMessageWaiterLock(_In_ PIO_CSQ Csq,
+                             _In_ KIRQL Irql)
+{
+    PFLT_MESSAGE_WAITER_QUEUE MessageWaiterQueue;
+
+    UNREFERENCED_PARAMETER(Irql);
+
+    /* Get the start of the waiter queue struct */
+    MessageWaiterQueue = CONTAINING_RECORD(Csq,
+                                           FLT_MESSAGE_WAITER_QUEUE,
+                                           Csq);
+
+    /* Release the IRP queue lock */
+    ExReleaseFastMutex(&MessageWaiterQueue->WaiterQ.mLock);
+}
+
+VOID
+NTAPI
+FltpCancelMessageWaiter(_In_ PIO_CSQ Csq,
+                        _In_ PIRP Irp)
+{
+    /* Cancel the IRP */
+    Irp->IoStatus.Status = STATUS_CANCELLED;
+    Irp->IoStatus.Information = 0;
+    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+}
+
+
 /* PRIVATE FUNCTIONS ******************************************************/
 
+static
+NTSTATUS
+CreateClientPort(_In_ PFILE_OBJECT FileObject,
+                   _Inout_ PIRP Irp)
+{
+    PFLT_SERVER_PORT_OBJECT ServerPortObject = NULL;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    PFILTER_PORT_DATA FilterPortData;
+    PFLT_PORT_OBJECT ClientPortObject = NULL;
+    PFLT_PORT PortHandle = NULL;
+    PPORT_CCB PortCCB = NULL;
+    //ULONG BufferLength;
+    LONG NumConns;
+    NTSTATUS Status;
+
+    /* We received the buffer via FilterConnectCommunicationPort, cast it back to its original form */
+    FilterPortData = Irp->AssociatedIrp.SystemBuffer;
+
+    /* Get a reference to the server port the filter created */
+    Status = ObReferenceObjectByName(&FilterPortData->PortName,
+                                     0,
+                                     0,
+                                     FLT_PORT_ALL_ACCESS,
+                                     ServerPortObjectType,
+                                     ExGetPreviousMode(),
+                                     0,
+                                     (PVOID *)&ServerPortObject);
+    if (!NT_SUCCESS(Status))
+    {
+        return Status;
+    }
+
+    /* Increment the number of connections on the server port */
+    NumConns = InterlockedIncrement(&ServerPortObject->NumberOfConnections);
+    if (NumConns > ServerPortObject->MaxConnections)
+    {
+        Status = STATUS_CONNECTION_COUNT_LIMIT;
+        goto Quit;
+    }
+
+    /* Initialize a basic kernel handle request */
+    InitializeObjectAttributes(&ObjectAttributes,
+                               NULL,
+                               OBJ_KERNEL_HANDLE,
+                               NULL,
+                               NULL);
+
+    /* Now create the new client port object */
+    Status = ObCreateObject(KernelMode,
+                            ClientPortObjectType,
+                            &ObjectAttributes,
+                            KernelMode,
+                            NULL,
+                            sizeof(FLT_PORT_OBJECT),
+                            0,
+                            0,
+                            (PVOID *)&ClientPortObject);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Clear out the buffer */
+    RtlZeroMemory(ClientPortObject, sizeof(FLT_PORT_OBJECT));
+
+    /* Initialize the locks */
+    ExInitializeRundownProtection(&ClientPortObject->MsgNotifRundownRef);
+    ExInitializeFastMutex(&ClientPortObject->Lock);
+
+    /* Set the server port object this belongs to */
+    ClientPortObject->ServerPort = ServerPortObject;
+
+    /* Setup the message queue */
+    Status = InitializeMessageWaiterQueue(&ClientPortObject->MsgQ);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Create the CCB which we'll attach to the file object */
+    PortCCB = CreatePortCCB(ClientPortObject);
+    if (PortCCB == NULL)
+    {
+        Status = STATUS_INSUFFICIENT_RESOURCES;
+        goto Quit;
+    }
+
+    /* Now insert the new client port into the object manager*/
+    Status = ObInsertObject(ClientPortObject, 0, FLT_PORT_ALL_ACCESS, 1, 0, (PHANDLE)&PortHandle);
+    if (!NT_SUCCESS(Status))
+    {
+        goto Quit;
+    }
+
+    /* Add a reference to the filter to keep it alive while we do some work with it */
+    Status = FltObjectReference(ServerPortObject->Filter);
+    if (NT_SUCCESS(Status))
+    {
+        /* Invoke the callback to let the filter know we have a connection */
+        Status = ServerPortObject->ConnectNotify(PortHandle,
+                                                 ServerPortObject->Cookie,
+                                                 NULL, //ConnectionContext
+                                                 0, //SizeOfContext
+                                                 &ClientPortObject->Cookie);
+        if (NT_SUCCESS(Status))
+        {
+            /* Add the client port CCB to the file object */
+            FileObject->FsContext2 = PortCCB;
+
+            /* Lock the port list on the filter and add this new port object to the list */
+            ExAcquireFastMutex(&ServerPortObject->Filter->PortList.mLock);
+            InsertTailList(&ServerPortObject->Filter->PortList.mList, &ClientPortObject->FilterLink);
+            ExReleaseFastMutex(&ServerPortObject->Filter->PortList.mLock);
+        }
+
+        /* We're done with the filter object, decremement the count */
+        FltObjectDereference(ServerPortObject->Filter);
+    }
+
+
+Quit:
+    if (!NT_SUCCESS(Status))
+    {
+        if (ClientPortObject)
+        {
+            ObfDereferenceObject(ClientPortObject);
+        }
+
+        if (PortHandle)
+        {
+            ZwClose(PortHandle);
+        }
+        else if (ServerPortObject)
+        {
+            InterlockedDecrement(&ServerPortObject->NumberOfConnections);
+            ObfDereferenceObject(ServerPortObject);
+        }
+
+        if (PortCCB)
+        {
+            ExFreePoolWithTag(PortCCB, FM_TAG_CCB);
+        }
+    }
+
+    return Status;
+}
+
+static
+NTSTATUS
+CloseClientPort(_In_ PFILE_OBJECT FileObject,
+                _Inout_ PIRP Irp)
+{
+    PFLT_CCB Ccb;
+
+    Ccb = (PFLT_CCB)FileObject->FsContext2;
+
+    /* Remove the reference on the filter we added when we opened the port */
+    ObDereferenceObject(Ccb->Data.Port.Port);
+
+    // FIXME: Free the CCB
+
+    return STATUS_SUCCESS;
+}
+
+static
+NTSTATUS
+InitializeMessageWaiterQueue(_Inout_ PFLT_MESSAGE_WAITER_QUEUE MsgWaiterQueue)
+{
+    NTSTATUS Status;
+
+    /* Setup the IRP queue */
+    Status = IoCsqInitializeEx(&MsgWaiterQueue->Csq,
+                               FltpAddMessageWaiter,
+                               FltpRemoveMessageWaiter,
+                               FltpGetNextMessageWaiter,
+                               FltpAcquireMessageWaiterLock,
+                               FltpReleaseMessageWaiterLock,
+                               FltpCancelMessageWaiter);
+    if (!NT_SUCCESS(Status))
+    {
+        return Status;
+    }
+
+    /* Initialize the waiter queue */
+    ExInitializeFastMutex(&MsgWaiterQueue->WaiterQ.mLock);
+    InitializeListHead(&MsgWaiterQueue->WaiterQ.mList);
+    MsgWaiterQueue->WaiterQ.mCount = 0;
+
+    /* We don't have a minimum waiter length */
+    MsgWaiterQueue->MinimumWaiterLength = (ULONG)-1;
+
+    /* Init the semaphore and event used for counting and signaling available IRPs */
+    KeInitializeSemaphore(&MsgWaiterQueue->Semaphore, 0, MAXLONG);
+    KeInitializeEvent(&MsgWaiterQueue->Event, NotificationEvent, FALSE);
+
+    return STATUS_SUCCESS;
+}
+
+static
+PPORT_CCB
+CreatePortCCB(_In_ PFLT_PORT_OBJECT PortObject)
+{
+    PPORT_CCB PortCCB;
+
+    /* Allocate a CCB struct to hold the client port object info */
+    PortCCB = ExAllocatePoolWithTag(NonPagedPool, sizeof(PPORT_CCB), FM_TAG_CCB);
+    if (PortCCB)
+    {
+        /* Initialize the structure */
+        PortCCB->Port = PortObject;
+        PortCCB->ReplyWaiterList.mCount = 0;
+        ExInitializeFastMutex(&PortCCB->ReplyWaiterList.mLock);
+        KeInitializeEvent(&PortCCB->ReplyWaiterList.mLock.Event, SynchronizationEvent, 0);
+    }
+
+    return PortCCB;
+}
\ No newline at end of file