Synchronize with trunk revision 59636 (just before Alex's CreateProcess revamp).
[reactos.git] / drivers / ksfilter / ks / irp.c
index 32ba84d..0c493ee 100644 (file)
@@ -34,7 +34,7 @@ KsDispatchQuerySecurity(
     {
         /* no create item */
         Irp->IoStatus.Status = STATUS_NO_SECURITY_ON_OBJECT;
-        IoCompleteRequest(Irp, IO_NO_INCREMENT);
+        CompleteRequest(Irp, IO_NO_INCREMENT);
         return STATUS_NO_SECURITY_ON_OBJECT;
     }
 
@@ -50,7 +50,7 @@ KsDispatchQuerySecurity(
     Irp->IoStatus.Status = Status;
     Irp->IoStatus.Information = Length;
 
-    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+    CompleteRequest(Irp, IO_NO_INCREMENT);
     return Status;
 }
 
@@ -80,7 +80,7 @@ KsDispatchSetSecurity(
     {
         /* no create item */
         Irp->IoStatus.Status = STATUS_NO_SECURITY_ON_OBJECT;
-        IoCompleteRequest(Irp, IO_NO_INCREMENT);
+        CompleteRequest(Irp, IO_NO_INCREMENT);
         return STATUS_NO_SECURITY_ON_OBJECT;
     }
 
@@ -109,7 +109,7 @@ KsDispatchSetSecurity(
 
     /* store result */
     Irp->IoStatus.Status = Status;
-    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+    CompleteRequest(Irp, IO_NO_INCREMENT);
 
     return Status;
 }
@@ -652,7 +652,7 @@ KsProbeStreamIrp(
     PKSSTREAM_HEADER StreamHeader;
     PIO_STACK_LOCATION IoStack;
     ULONG Length;
-    BOOLEAN AllocateMdl = FALSE;
+    //BOOLEAN AllocateMdl = FALSE;
 
     /* get current irp stack */
     IoStack = IoGetCurrentIrpStackLocation(Irp);
@@ -661,6 +661,11 @@ KsProbeStreamIrp(
 
     if (Irp->RequestorMode == KernelMode || Irp->AssociatedIrp.SystemBuffer)
     {
+        if (Irp->RequestorMode == KernelMode)
+        {
+            /* no need to allocate stream header */
+            Irp->AssociatedIrp.SystemBuffer = Irp->UserBuffer;
+        }
 AllocMdl:
         /* check if alloc mdl flag is passed */
         if (!(ProbeFlags & KSPROBE_ALLOCATEMDL))
@@ -724,7 +729,16 @@ ProbeMdl:
                         Mdl = Irp->MdlAddress;
 
                         /* determine operation */
-                        Operation = (ProbeFlags & KSPROBE_STREAMWRITE) ? IoWriteAccess : IoReadAccess;
+                        if (!(ProbeFlags & KSPROBE_STREAMWRITE) || (ProbeFlags & KSPROBE_MODIFY))
+                        {
+                            /* operation is read / modify stream, need write access */
+                            Operation = IoWriteAccess;
+                        }
+                        else
+                        {
+                            /* operation is write to device, so we need read access */
+                            Operation = IoReadAccess;
+                        }
 
                         do
                         {
@@ -859,7 +873,8 @@ ProbeMdl:
                 if (StreamHeader->FrameExtent)
                 {
                     /* allocate an mdl */
-                    Mdl = IoAllocateMdl(StreamHeader->Data, StreamHeader->FrameExtent, Irp->MdlAddress != NULL, TRUE, Irp);
+                    ASSERT(Irp->MdlAddress == NULL);
+                    Mdl = IoAllocateMdl(StreamHeader->Data, StreamHeader->FrameExtent, FALSE, TRUE, Irp);
                     if (!Mdl)
                     {
                         /* not enough memory */
@@ -880,18 +895,14 @@ ProbeMdl:
 
         /* now probe the allocated mdl's */
         if (!NT_SUCCESS(Status))
-               {
+        {
             DPRINT("Status %x\n", Status);
             return Status;
-               }
+        }
         else
             goto ProbeMdl;
     }
 
-#if 0
-    // HACK for MS PORTCLS
-       HeaderSize = Length;
-#endif
     /* probe user mode buffers */
     if (Length && ( (!HeaderSize) || (Length % HeaderSize == 0) || ((ProbeFlags & KSPROBE_ALLOWFORMATCHANGE) && (Length == sizeof(KSSTREAM_HEADER))) ) )
     {
@@ -904,6 +915,9 @@ ProbeMdl:
             return STATUS_INSUFFICIENT_RESOURCES;
         }
 
+        /* mark irp as buffered so that changes the stream headers are propagated back */
+        Irp->Flags = IRP_DEALLOCATE_BUFFER | IRP_BUFFERED_IO;
+
         _SEH2_TRY
         {
             if (ProbeFlags & KSPROBE_STREAMWRITE)
@@ -917,6 +931,9 @@ ProbeMdl:
             {
                 /* stream reads means writing */
                 ProbeForWrite(Irp->UserBuffer, Length, sizeof(UCHAR));
+
+                /* set input operation flags */
+                Irp->Flags |= IRP_INPUT_OPERATION;
             }
 
             /* copy stream buffer */
@@ -1007,7 +1024,7 @@ ProbeMdl:
                             }
 
                             /* break out to probe for the irp */
-                            AllocateMdl = TRUE;
+                            //AllocateMdl = TRUE;
                             break;
                         }
                     }
@@ -1137,7 +1154,7 @@ KsDispatchInvalidDeviceRequest(
     IN  PIRP Irp)
 {
     Irp->IoStatus.Status = STATUS_INVALID_DEVICE_REQUEST;
-    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+    CompleteRequest(Irp, IO_NO_INCREMENT);
 
     return STATUS_INVALID_DEVICE_REQUEST;
 }
@@ -1158,9 +1175,10 @@ KsDefaultDeviceIoCompletion(
     /* get current irp stack */
     IoStack = IoGetCurrentIrpStackLocation(Irp);
 
-    if (IoStack->Parameters.DeviceIoControl.IoControlCode != IOCTL_KS_PROPERTY && 
+    if (IoStack->Parameters.DeviceIoControl.IoControlCode != IOCTL_KS_PROPERTY &&
         IoStack->Parameters.DeviceIoControl.IoControlCode != IOCTL_KS_METHOD &&
-        IoStack->Parameters.DeviceIoControl.IoControlCode != IOCTL_KS_PROPERTY)
+        IoStack->Parameters.DeviceIoControl.IoControlCode != IOCTL_KS_ENABLE_EVENT &&
+        IoStack->Parameters.DeviceIoControl.IoControlCode != IOCTL_KS_DISABLE_EVENT)
     {
         if (IoStack->Parameters.DeviceIoControl.IoControlCode == IOCTL_KS_RESET_STATE)
         {
@@ -1181,7 +1199,7 @@ KsDefaultDeviceIoCompletion(
 
     /* complete request */
     Irp->IoStatus.Status = Status;
-    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+    CompleteRequest(Irp, IO_NO_INCREMENT);
 
 
     return Status;
@@ -1360,7 +1378,7 @@ KsRemoveIrpFromCancelableQueue(
     PLIST_ENTRY CurEntry;
     KIRQL OldIrql;
 
-    //DPRINT("KsRemoveIrpFromCancelableQueue ListHead %p SpinLock %p ListLocation %x RemovalOperation %x\n", QueueHead, SpinLock, ListLocation, RemovalOperation);
+    DPRINT("KsRemoveIrpFromCancelableQueue ListHead %p SpinLock %p ListLocation %x RemovalOperation %x\n", QueueHead, SpinLock, ListLocation, RemovalOperation);
 
     /* check parameters */
     if (!QueueHead || !SpinLock)
@@ -1619,19 +1637,33 @@ KsAddIrpToCancelableQueue(
     PIO_STACK_LOCATION IoStack;
     KIRQL OldLevel;
 
-    DPRINT("KsAddIrpToCancelableQueue QueueHead %p SpinLock %p Irp %p ListLocation %x DriverCancel %p\n", QueueHead, SpinLock, Irp, ListLocation, DriverCancel);
     /* check for required parameters */
     if (!QueueHead || !SpinLock || !Irp)
         return;
 
+    /* get current irp stack */
+    IoStack = IoGetCurrentIrpStackLocation(Irp);
+
+    DPRINT("KsAddIrpToCancelableQueue QueueHead %p SpinLock %p Irp %p ListLocation %x DriverCancel %p\n", QueueHead, SpinLock, Irp, ListLocation, DriverCancel);
+
+    // HACK for ms portcls
+    if (IoStack->MajorFunction == IRP_MJ_CREATE)
+    {
+        // complete the request
+        DPRINT1("MS HACK\n");
+        Irp->IoStatus.Status = STATUS_SUCCESS;
+        CompleteRequest(Irp, IO_NO_INCREMENT);
+
+        return;
+    }
+
+
     if (!DriverCancel)
     {
         /* default to KsCancelRoutine */
         DriverCancel = KsCancelRoutine;
     }
 
-    /* get current irp stack */
-    IoStack = IoGetCurrentIrpStackLocation(Irp);
 
     /* acquire spinlock */
     KeAcquireSpinLock(SpinLock, &OldLevel);
@@ -1706,46 +1738,60 @@ KsCancelRoutine(
     {
         /* let's complete it */
         Irp->IoStatus.Status = STATUS_CANCELLED;
-        IoCompleteRequest(Irp, IO_NO_INCREMENT);
+        CompleteRequest(Irp, IO_NO_INCREMENT);
     }
 }
 
 NTSTATUS
 FindMatchingCreateItem(
     PLIST_ENTRY ListHead,
-    ULONG BufferSize,
-    LPWSTR Buffer,
+    PUNICODE_STRING String,
     OUT PCREATE_ITEM_ENTRY *OutCreateItem)
 {
     PLIST_ENTRY Entry;
     PCREATE_ITEM_ENTRY CreateItemEntry;
     UNICODE_STRING RefString;
+    LPWSTR pStr;
+    ULONG Count;
 
+    /* Copy the input string */
+    RefString = *String;
 
-#ifndef MS_KSUSER
-    /* remove '\' slash */
-    Buffer++;
-    BufferSize -= sizeof(WCHAR);
-#endif
-
-    if (!wcschr(Buffer, L'\\'))
+    /* Check if the string starts with a backslash */
+    if (String->Buffer[0] == L'\\')
     {
-        RtlInitUnicodeString(&RefString, Buffer);
+        /* Skip backslash */
+        RefString.Buffer++;
+        RefString.Length -= sizeof(WCHAR);
     }
     else
     {
-        RefString.Buffer = Buffer;
-        RefString.Length = RefString.MaximumLength = ((ULONG_PTR)wcschr(Buffer, L'\\') - (ULONG_PTR)Buffer);
+        /* get terminator */
+        pStr = String->Buffer;
+        Count = String->Length / sizeof(WCHAR);
+        while ((Count > 0) && (*pStr != L'\\'))
+        {
+            pStr++;
+            Count--;
+        }
+
+        /* sanity check */
+        ASSERT(Count != 0);
+
+        // request is for pin / node / allocator
+        RefString.Length = (USHORT)((PCHAR)pStr - (PCHAR)String->Buffer);
     }
 
     /* point to first entry */
     Entry = ListHead->Flink;
 
     /* loop all device items */
-    while(Entry != ListHead)
+    while (Entry != ListHead)
     {
         /* get create item entry */
-        CreateItemEntry = (PCREATE_ITEM_ENTRY)CONTAINING_RECORD(Entry, CREATE_ITEM_ENTRY, Entry);
+        CreateItemEntry = (PCREATE_ITEM_ENTRY)CONTAINING_RECORD(Entry,
+                                                                CREATE_ITEM_ENTRY,
+                                                                Entry);
 
         ASSERT(CreateItemEntry->CreateItem);
 
@@ -1763,14 +1809,13 @@ FindMatchingCreateItem(
             continue;
         }
 
-        ASSERT(CreateItemEntry->CreateItem->ObjectClass.Buffer);
-
-        DPRINT("CreateItem %S Length %u Request %wZ %u\n", CreateItemEntry->CreateItem->ObjectClass.Buffer,
-                                                           CreateItemEntry->CreateItem->ObjectClass.Length,
-                                                           &RefString,
-                                                           BufferSize);
+        DPRINT("CreateItem %S Length %u Request %wZ %u\n",
+               CreateItemEntry->CreateItem->ObjectClass.Buffer,
+               CreateItemEntry->CreateItem->ObjectClass.Length,
+               &RefString,
+               RefString.Length);
 
-        if (CreateItemEntry->CreateItem->ObjectClass.Length > BufferSize)
+        if (CreateItemEntry->CreateItem->ObjectClass.Length > RefString.Length)
         {
             /* create item doesnt match in length */
             Entry = Entry->Flink;
@@ -1778,7 +1823,9 @@ FindMatchingCreateItem(
         }
 
          /* now check if the object class is the same */
-        if (!RtlCompareUnicodeString(&CreateItemEntry->CreateItem->ObjectClass, &RefString, TRUE))
+        if (!RtlCompareUnicodeString(&CreateItemEntry->CreateItem->ObjectClass,
+                                     &RefString,
+                                     TRUE))
         {
             /* found matching create item */
             *OutCreateItem = CreateItemEntry;
@@ -1821,7 +1868,7 @@ KspCreate(
         Irp->IoStatus.Information = 0;
         /* set return status */
         Irp->IoStatus.Status = STATUS_SUCCESS;
-        IoCompleteRequest(Irp, IO_NO_INCREMENT);
+        CompleteRequest(Irp, IO_NO_INCREMENT);
         return STATUS_SUCCESS;
     }
 
@@ -1834,12 +1881,16 @@ KspCreate(
         ASSERT(ObjectHeader);
 
         /* find a matching a create item */
-        Status = FindMatchingCreateItem(&ObjectHeader->ItemList, IoStack->FileObject->FileName.Length, IoStack->FileObject->FileName.Buffer, &CreateItemEntry);
+        Status = FindMatchingCreateItem(&ObjectHeader->ItemList,
+                                        &IoStack->FileObject->FileName,
+                                        &CreateItemEntry);
     }
     else
     {
         /* request to create a filter */
-        Status = FindMatchingCreateItem(&DeviceHeader->ItemList, IoStack->FileObject->FileName.Length, IoStack->FileObject->FileName.Buffer, &CreateItemEntry);
+        Status = FindMatchingCreateItem(&DeviceHeader->ItemList,
+                                        &IoStack->FileObject->FileName,
+                                        &CreateItemEntry);
     }
 
     if (NT_SUCCESS(Status))
@@ -1861,7 +1912,7 @@ KspCreate(
     Irp->IoStatus.Information = 0;
     /* set return status */
     Irp->IoStatus.Status = STATUS_UNSUCCESSFUL;
-    IoCompleteRequest(Irp, IO_NO_INCREMENT);
+    CompleteRequest(Irp, IO_NO_INCREMENT);
     return STATUS_UNSUCCESSFUL;
 }
 
@@ -1872,9 +1923,9 @@ KspDispatchIrp(
     IN  PIRP Irp)
 {
     PIO_STACK_LOCATION IoStack;
-    PDEVICE_EXTENSION DeviceExtension;
+    //PDEVICE_EXTENSION DeviceExtension;
     PKSIOBJECT_HEADER ObjectHeader;
-    PKSIDEVICE_HEADER DeviceHeader;
+    //PKSIDEVICE_HEADER DeviceHeader;
     PDRIVER_DISPATCH Dispatch;
     NTSTATUS Status;
 
@@ -1882,9 +1933,9 @@ KspDispatchIrp(
     IoStack = IoGetCurrentIrpStackLocation(Irp);
 
     /* get device extension */
-    DeviceExtension = (PDEVICE_EXTENSION)DeviceObject->DeviceExtension;
+    //DeviceExtension = (PDEVICE_EXTENSION)DeviceObject->DeviceExtension;
     /* get device header */
-    DeviceHeader = DeviceExtension->DeviceHeader;
+    //DeviceHeader = DeviceExtension->DeviceHeader;
 
     ASSERT(IoStack->FileObject);
 
@@ -1897,7 +1948,7 @@ KspDispatchIrp(
         Irp->IoStatus.Status = STATUS_SUCCESS;
         Irp->IoStatus.Information = 0;
         /* complete and forget */
-        IoCompleteRequest(Irp, IO_NO_INCREMENT);
+        CompleteRequest(Irp, IO_NO_INCREMENT);
         return STATUS_SUCCESS;
     }
 
@@ -1932,6 +1983,7 @@ KspDispatchIrp(
             break;
         case IRP_MJ_PNP:
             Dispatch = KsDefaultDispatchPnp;
+            break;
         default:
             Dispatch = NULL;
     }
@@ -1963,12 +2015,6 @@ KsSetMajorFunctionHandler(
     IN  ULONG MajorFunction)
 {
     DPRINT("KsSetMajorFunctionHandler Function %x\n", MajorFunction);
-#if 1
-    // HACK
-    // for MS PORTCLS
-    //
-    DriverObject->MajorFunction[IRP_MJ_CREATE] = KspCreate;
-#endif
 
     switch ( MajorFunction )
     {
@@ -2006,10 +2052,11 @@ KsDispatchIrp(
     PKSIDEVICE_HEADER DeviceHeader;
     PDEVICE_EXTENSION DeviceExtension;
 
-    //DPRINT("KsDispatchIrp DeviceObject %p Irp %p\n", DeviceObject, Irp);
+    DPRINT("KsDispatchIrp DeviceObject %p Irp %p\n", DeviceObject, Irp);
 
     /* get device extension */
     DeviceExtension = (PDEVICE_EXTENSION)DeviceObject->DeviceExtension;
+
     /* get device header */
     DeviceHeader = DeviceExtension->DeviceHeader;