[CMAKE]
[reactos.git] / drivers / ksfilter / ks / irp.c
index e50490a..fb36d6d 100644 (file)
@@ -101,7 +101,7 @@ KsDispatchSetSecurity(
     if (NT_SUCCESS(Status))
     {
         /* free old descriptor */
-        ExFreePool(Descriptor);
+        FreeItem(Descriptor);
 
        /* mark create item as changed */
        CreateItem->Flags |= KSCREATE_ITEM_SECURITYCHANGED;
@@ -661,6 +661,11 @@ KsProbeStreamIrp(
 
     if (Irp->RequestorMode == KernelMode || Irp->AssociatedIrp.SystemBuffer)
     {
+        if (Irp->RequestorMode == KernelMode)
+        {
+            /* no need to allocate stream header */
+            Irp->AssociatedIrp.SystemBuffer = Irp->UserBuffer;
+        }
 AllocMdl:
         /* check if alloc mdl flag is passed */
         if (!(ProbeFlags & KSPROBE_ALLOCATEMDL))
@@ -724,7 +729,16 @@ ProbeMdl:
                         Mdl = Irp->MdlAddress;
 
                         /* determine operation */
-                        Operation = (ProbeFlags & KSPROBE_STREAMWRITE) ? IoWriteAccess : IoReadAccess;
+                        if (!(ProbeFlags & KSPROBE_STREAMWRITE) || (ProbeFlags & KSPROBE_MODIFY))
+                        {
+                            /* operation is read / modify stream, need write access */
+                            Operation = IoWriteAccess;
+                        }
+                        else
+                        {
+                            /* operation is write to device, so we need read access */
+                            Operation = IoReadAccess;
+                        }
 
                         do
                         {
@@ -859,7 +873,8 @@ ProbeMdl:
                 if (StreamHeader->FrameExtent)
                 {
                     /* allocate an mdl */
-                    Mdl = IoAllocateMdl(StreamHeader->Data, StreamHeader->FrameExtent, Irp->MdlAddress != NULL, TRUE, Irp);
+                    ASSERT(Irp->MdlAddress == NULL);
+                    Mdl = IoAllocateMdl(StreamHeader->Data, StreamHeader->FrameExtent, FALSE, TRUE, Irp);
                     if (!Mdl)
                     {
                         /* not enough memory */
@@ -880,23 +895,19 @@ ProbeMdl:
 
         /* now probe the allocated mdl's */
         if (!NT_SUCCESS(Status))
-               {
+        {
             DPRINT("Status %x\n", Status);
             return Status;
-               }
+        }
         else
             goto ProbeMdl;
     }
 
-#if 0
-    // HACK for MS PORTCLS
-       HeaderSize = Length;
-#endif
     /* probe user mode buffers */
     if (Length && ( (!HeaderSize) || (Length % HeaderSize == 0) || ((ProbeFlags & KSPROBE_ALLOWFORMATCHANGE) && (Length == sizeof(KSSTREAM_HEADER))) ) )
     {
         /* allocate stream header buffer */
-        Irp->AssociatedIrp.SystemBuffer = ExAllocatePool(NonPagedPool, Length);
+        Irp->AssociatedIrp.SystemBuffer = AllocateItem(NonPagedPool, Length);
 
         if (!Irp->AssociatedIrp.SystemBuffer)
         {
@@ -904,6 +915,9 @@ ProbeMdl:
             return STATUS_INSUFFICIENT_RESOURCES;
         }
 
+        /* mark irp as buffered so that changes the stream headers are propagated back */
+        Irp->Flags = IRP_DEALLOCATE_BUFFER | IRP_BUFFERED_IO;
+
         _SEH2_TRY
         {
             if (ProbeFlags & KSPROBE_STREAMWRITE)
@@ -917,6 +931,9 @@ ProbeMdl:
             {
                 /* stream reads means writing */
                 ProbeForWrite(Irp->UserBuffer, Length, sizeof(UCHAR));
+
+                /* set input operation flags */
+                Irp->Flags |= IRP_INPUT_OPERATION;
             }
 
             /* copy stream buffer */
@@ -1619,19 +1636,32 @@ KsAddIrpToCancelableQueue(
     PIO_STACK_LOCATION IoStack;
     KIRQL OldLevel;
 
-    DPRINT("KsAddIrpToCancelableQueue QueueHead %p SpinLock %p Irp %p ListLocation %x DriverCancel %p\n", QueueHead, SpinLock, Irp, ListLocation, DriverCancel);
     /* check for required parameters */
     if (!QueueHead || !SpinLock || !Irp)
         return;
 
+    /* get current irp stack */
+    IoStack = IoGetCurrentIrpStackLocation(Irp);
+
+    DPRINT("KsAddIrpToCancelableQueue QueueHead %p SpinLock %p Irp %p ListLocation %x DriverCancel %p\n", QueueHead, SpinLock, Irp, ListLocation, DriverCancel);
+
+    // HACK for ms portcls
+    if (IoStack->MajorFunction == IRP_MJ_CREATE)
+    {
+        // complete the request
+        Irp->IoStatus.Status = STATUS_SUCCESS;
+        IoCompleteRequest(Irp, IO_NO_INCREMENT);
+
+        return;
+    }
+
+
     if (!DriverCancel)
     {
         /* default to KsCancelRoutine */
         DriverCancel = KsCancelRoutine;
     }
 
-    /* get current irp stack */
-    IoStack = IoGetCurrentIrpStackLocation(Irp);
 
     /* acquire spinlock */
     KeAcquireSpinLock(SpinLock, &OldLevel);
@@ -1719,10 +1749,26 @@ FindMatchingCreateItem(
 {
     PLIST_ENTRY Entry;
     PCREATE_ITEM_ENTRY CreateItemEntry;
+    UNICODE_STRING RefString;
+    LPWSTR pStr;
 
-    /* remove '\' slash */
-    Buffer++;
-    BufferSize -= sizeof(WCHAR);
+    /* get terminator */
+    pStr = wcschr(Buffer, L'\\');
+
+    /* sanity check */
+    ASSERT(pStr != NULL);
+
+    if (pStr == Buffer)
+    {
+        // skip slash
+        RtlInitUnicodeString(&RefString, ++pStr);
+    }
+    else
+    {
+        // request is for pin / node / allocator
+        RefString.Buffer = Buffer;
+        RefString.Length = BufferSize = RefString.MaximumLength = ((ULONG_PTR)pStr - (ULONG_PTR)Buffer);
+    }
 
     /* point to first entry */
     Entry = ListHead->Flink;
@@ -1751,9 +1797,9 @@ FindMatchingCreateItem(
 
         ASSERT(CreateItemEntry->CreateItem->ObjectClass.Buffer);
 
-        DPRINT("CreateItem %S Length %u Request %S %u\n", CreateItemEntry->CreateItem->ObjectClass.Buffer,
+        DPRINT("CreateItem %S Length %u Request %wZ %u\n", CreateItemEntry->CreateItem->ObjectClass.Buffer,
                                                            CreateItemEntry->CreateItem->ObjectClass.Length,
-                                                           Buffer,
+                                                           &RefString,
                                                            BufferSize);
 
         if (CreateItemEntry->CreateItem->ObjectClass.Length > BufferSize)
@@ -1764,7 +1810,7 @@ FindMatchingCreateItem(
         }
 
          /* now check if the object class is the same */
-        if (RtlCompareMemory(CreateItemEntry->CreateItem->ObjectClass.Buffer, Buffer, CreateItemEntry->CreateItem->ObjectClass.Length) == CreateItemEntry->CreateItem->ObjectClass.Length)
+        if (!RtlCompareUnicodeString(&CreateItemEntry->CreateItem->ObjectClass, &RefString, TRUE))
         {
             /* found matching create item */
             *OutCreateItem = CreateItemEntry;
@@ -1949,12 +1995,6 @@ KsSetMajorFunctionHandler(
     IN  ULONG MajorFunction)
 {
     DPRINT("KsSetMajorFunctionHandler Function %x\n", MajorFunction);
-#if 1
-    // HACK
-    // for MS PORTCLS
-    //
-    DriverObject->MajorFunction[IRP_MJ_CREATE] = KspCreate;
-#endif
 
     switch ( MajorFunction )
     {
@@ -2008,7 +2048,7 @@ KsDispatchIrp(
         if (IoStack->MajorFunction == IRP_MJ_CREATE)
         {
             /* check internal type */
-            if (DeviceHeader->lpVtblIKsDevice) /* FIXME improve check */
+            if (DeviceHeader->BasicHeader.OuterUnknown) /* FIXME improve check */
             {
                 /* AVStream client */
                 return IKsDevice_Create(DeviceObject, Irp);
@@ -2040,7 +2080,7 @@ KsDispatchIrp(
     if (IoStack->MajorFunction == IRP_MJ_POWER)
     {
         /* check internal type */
-        if (DeviceHeader->lpVtblIKsDevice) /* FIXME improve check */
+        if (DeviceHeader->BasicHeader.OuterUnknown) /* FIXME improve check */
         {
             /* AVStream client */
             return IKsDevice_Power(DeviceObject, Irp);
@@ -2054,7 +2094,7 @@ KsDispatchIrp(
     else if (IoStack->MajorFunction == IRP_MJ_PNP) /* dispatch pnp */
     {
         /* check internal type */
-        if (DeviceHeader->lpVtblIKsDevice) /* FIXME improve check */
+        if (DeviceHeader->BasicHeader.OuterUnknown) /* FIXME improve check */
         {
             /* AVStream client */
             return IKsDevice_Pnp(DeviceObject, Irp);