[CMAKE]
[reactos.git] / drivers / ksfilter / ks / irp.c
index 32ba84d..fb36d6d 100644 (file)
@@ -661,6 +661,11 @@ KsProbeStreamIrp(
 
     if (Irp->RequestorMode == KernelMode || Irp->AssociatedIrp.SystemBuffer)
     {
+        if (Irp->RequestorMode == KernelMode)
+        {
+            /* no need to allocate stream header */
+            Irp->AssociatedIrp.SystemBuffer = Irp->UserBuffer;
+        }
 AllocMdl:
         /* check if alloc mdl flag is passed */
         if (!(ProbeFlags & KSPROBE_ALLOCATEMDL))
@@ -724,7 +729,16 @@ ProbeMdl:
                         Mdl = Irp->MdlAddress;
 
                         /* determine operation */
-                        Operation = (ProbeFlags & KSPROBE_STREAMWRITE) ? IoWriteAccess : IoReadAccess;
+                        if (!(ProbeFlags & KSPROBE_STREAMWRITE) || (ProbeFlags & KSPROBE_MODIFY))
+                        {
+                            /* operation is read / modify stream, need write access */
+                            Operation = IoWriteAccess;
+                        }
+                        else
+                        {
+                            /* operation is write to device, so we need read access */
+                            Operation = IoReadAccess;
+                        }
 
                         do
                         {
@@ -859,7 +873,8 @@ ProbeMdl:
                 if (StreamHeader->FrameExtent)
                 {
                     /* allocate an mdl */
-                    Mdl = IoAllocateMdl(StreamHeader->Data, StreamHeader->FrameExtent, Irp->MdlAddress != NULL, TRUE, Irp);
+                    ASSERT(Irp->MdlAddress == NULL);
+                    Mdl = IoAllocateMdl(StreamHeader->Data, StreamHeader->FrameExtent, FALSE, TRUE, Irp);
                     if (!Mdl)
                     {
                         /* not enough memory */
@@ -880,18 +895,14 @@ ProbeMdl:
 
         /* now probe the allocated mdl's */
         if (!NT_SUCCESS(Status))
-               {
+        {
             DPRINT("Status %x\n", Status);
             return Status;
-               }
+        }
         else
             goto ProbeMdl;
     }
 
-#if 0
-    // HACK for MS PORTCLS
-       HeaderSize = Length;
-#endif
     /* probe user mode buffers */
     if (Length && ( (!HeaderSize) || (Length % HeaderSize == 0) || ((ProbeFlags & KSPROBE_ALLOWFORMATCHANGE) && (Length == sizeof(KSSTREAM_HEADER))) ) )
     {
@@ -904,6 +915,9 @@ ProbeMdl:
             return STATUS_INSUFFICIENT_RESOURCES;
         }
 
+        /* mark irp as buffered so that changes the stream headers are propagated back */
+        Irp->Flags = IRP_DEALLOCATE_BUFFER | IRP_BUFFERED_IO;
+
         _SEH2_TRY
         {
             if (ProbeFlags & KSPROBE_STREAMWRITE)
@@ -917,6 +931,9 @@ ProbeMdl:
             {
                 /* stream reads means writing */
                 ProbeForWrite(Irp->UserBuffer, Length, sizeof(UCHAR));
+
+                /* set input operation flags */
+                Irp->Flags |= IRP_INPUT_OPERATION;
             }
 
             /* copy stream buffer */
@@ -1360,7 +1377,7 @@ KsRemoveIrpFromCancelableQueue(
     PLIST_ENTRY CurEntry;
     KIRQL OldIrql;
 
-    //DPRINT("KsRemoveIrpFromCancelableQueue ListHead %p SpinLock %p ListLocation %x RemovalOperation %x\n", QueueHead, SpinLock, ListLocation, RemovalOperation);
+    DPRINT("KsRemoveIrpFromCancelableQueue ListHead %p SpinLock %p ListLocation %x RemovalOperation %x\n", QueueHead, SpinLock, ListLocation, RemovalOperation);
 
     /* check parameters */
     if (!QueueHead || !SpinLock)
@@ -1619,19 +1636,32 @@ KsAddIrpToCancelableQueue(
     PIO_STACK_LOCATION IoStack;
     KIRQL OldLevel;
 
-    DPRINT("KsAddIrpToCancelableQueue QueueHead %p SpinLock %p Irp %p ListLocation %x DriverCancel %p\n", QueueHead, SpinLock, Irp, ListLocation, DriverCancel);
     /* check for required parameters */
     if (!QueueHead || !SpinLock || !Irp)
         return;
 
+    /* get current irp stack */
+    IoStack = IoGetCurrentIrpStackLocation(Irp);
+
+    DPRINT("KsAddIrpToCancelableQueue QueueHead %p SpinLock %p Irp %p ListLocation %x DriverCancel %p\n", QueueHead, SpinLock, Irp, ListLocation, DriverCancel);
+
+    // HACK for ms portcls
+    if (IoStack->MajorFunction == IRP_MJ_CREATE)
+    {
+        // complete the request
+        Irp->IoStatus.Status = STATUS_SUCCESS;
+        IoCompleteRequest(Irp, IO_NO_INCREMENT);
+
+        return;
+    }
+
+
     if (!DriverCancel)
     {
         /* default to KsCancelRoutine */
         DriverCancel = KsCancelRoutine;
     }
 
-    /* get current irp stack */
-    IoStack = IoGetCurrentIrpStackLocation(Irp);
 
     /* acquire spinlock */
     KeAcquireSpinLock(SpinLock, &OldLevel);
@@ -1720,22 +1750,24 @@ FindMatchingCreateItem(
     PLIST_ENTRY Entry;
     PCREATE_ITEM_ENTRY CreateItemEntry;
     UNICODE_STRING RefString;
+    LPWSTR pStr;
 
+    /* get terminator */
+    pStr = wcschr(Buffer, L'\\');
 
-#ifndef MS_KSUSER
-    /* remove '\' slash */
-    Buffer++;
-    BufferSize -= sizeof(WCHAR);
-#endif
+    /* sanity check */
+    ASSERT(pStr != NULL);
 
-    if (!wcschr(Buffer, L'\\'))
+    if (pStr == Buffer)
     {
-        RtlInitUnicodeString(&RefString, Buffer);
+        // skip slash
+        RtlInitUnicodeString(&RefString, ++pStr);
     }
     else
     {
+        // request is for pin / node / allocator
         RefString.Buffer = Buffer;
-        RefString.Length = RefString.MaximumLength = ((ULONG_PTR)wcschr(Buffer, L'\\') - (ULONG_PTR)Buffer);
+        RefString.Length = BufferSize = RefString.MaximumLength = ((ULONG_PTR)pStr - (ULONG_PTR)Buffer);
     }
 
     /* point to first entry */
@@ -1963,12 +1995,6 @@ KsSetMajorFunctionHandler(
     IN  ULONG MajorFunction)
 {
     DPRINT("KsSetMajorFunctionHandler Function %x\n", MajorFunction);
-#if 1
-    // HACK
-    // for MS PORTCLS
-    //
-    DriverObject->MajorFunction[IRP_MJ_CREATE] = KspCreate;
-#endif
 
     switch ( MajorFunction )
     {
@@ -2006,7 +2032,7 @@ KsDispatchIrp(
     PKSIDEVICE_HEADER DeviceHeader;
     PDEVICE_EXTENSION DeviceExtension;
 
-    //DPRINT("KsDispatchIrp DeviceObject %p Irp %p\n", DeviceObject, Irp);
+    DPRINT("KsDispatchIrp DeviceObject %p Irp %p\n", DeviceObject, Irp);
 
     /* get device extension */
     DeviceExtension = (PDEVICE_EXTENSION)DeviceObject->DeviceExtension;