[KS]
[reactos.git] / reactos / drivers / ksfilter / ks / filterfactory.c
index 54e06b7..e60bfc4 100644 (file)
@@ -14,13 +14,14 @@ typedef struct
     KSBASIC_HEADER Header;
     KSFILTERFACTORY FilterFactory;
 
-    IKsFilterFactoryVtbl *lpVtbl;
     LONG ref;
     PKSIDEVICE_HEADER DeviceHeader;
     PFNKSFILTERFACTORYPOWER SleepCallback;
     PFNKSFILTERFACTORYPOWER WakeCallback;
 
     LIST_ENTRY SymbolicLinkList;
+    KMUTEX ControlMutex;
+
 }IKsFilterFactoryImpl;
 
 VOID
@@ -41,6 +42,7 @@ IKsFilterFactory_Create(
     IN PIRP Irp)
 {
     PKSOBJECT_CREATE_ITEM CreateItem;
+    IKsFilterFactoryImpl * Factory;
     IKsFilterFactory * iface;
     NTSTATUS Status;
 
@@ -53,7 +55,10 @@ IKsFilterFactory_Create(
     }
 
     /* get filter factory interface */
-    iface = (IKsFilterFactory*)CONTAINING_RECORD(CreateItem->Context, IKsFilterFactoryImpl, FilterFactory);
+    Factory = (IKsFilterFactoryImpl*)CONTAINING_RECORD(CreateItem->Context, IKsFilterFactoryImpl, FilterFactory);
+
+    /* get interface */
+    iface = (IKsFilterFactory*)&Factory->Header.OuterUnknown;
 
     /* create a filter instance */
     Status = KspCreateFilter(DeviceObject, Irp, iface);
@@ -78,15 +83,31 @@ IKsFilterFactory_fnQueryInterface(
     IN  REFIID refiid,
     OUT PVOID* Output)
 {
-    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, lpVtbl);
+    NTSTATUS Status;
+
+    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, Header.OuterUnknown);
 
     if (IsEqualGUIDAligned(refiid, &IID_IUnknown))
     {
-        *Output = &This->lpVtbl;
+        *Output = &This->Header.OuterUnknown;
         _InterlockedIncrement(&This->ref);
         return STATUS_SUCCESS;
     }
-    return STATUS_UNSUCCESSFUL;
+
+    if (This->Header.ClientAggregate)
+    {
+         /* using client aggregate */
+         Status = This->Header.ClientAggregate->lpVtbl->QueryInterface(This->Header.ClientAggregate, refiid, Output);
+
+         if (NT_SUCCESS(Status))
+         {
+             /* client aggregate supports interface */
+             return Status;
+         }
+    }
+
+    DPRINT("IKsFilterFactory_fnQueryInterface no interface\n");
+    return STATUS_NOT_SUPPORTED;
 }
 
 ULONG
@@ -94,7 +115,7 @@ NTAPI
 IKsFilterFactory_fnAddRef(
     IKsFilterFactory * iface)
 {
-    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, lpVtbl);
+    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, Header.OuterUnknown);
 
     return InterlockedIncrement(&This->ref);
 }
@@ -104,7 +125,7 @@ NTAPI
 IKsFilterFactory_fnRelease(
     IKsFilterFactory * iface)
 {
-    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, lpVtbl);
+    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, Header.OuterUnknown);
 
     InterlockedDecrement(&This->ref);
 
@@ -130,7 +151,7 @@ NTAPI
 IKsFilterFactory_fnGetStruct(
     IKsFilterFactory * iface)
 {
-    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, lpVtbl);
+    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, Header.OuterUnknown);
 
     return &This->FilterFactory;
 }
@@ -141,7 +162,7 @@ IKsFilterFactory_fnSetDeviceClassesState(
     IKsFilterFactory * iface,
     IN BOOLEAN Enable)
 {
-    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, lpVtbl);
+    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, Header.OuterUnknown);
 
     return KspSetDeviceInterfacesState(&This->SymbolicLinkList, Enable);
 }
@@ -207,7 +228,7 @@ IKsFilterFactory_fnInitialize(
     BOOL FreeString = FALSE;
     IKsDevice * KsDevice;
 
-    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, lpVtbl);
+    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(iface, IKsFilterFactoryImpl, Header.OuterUnknown);
 
     /* get device extension */
     DeviceExtension = (PDEVICE_EXTENSION)DeviceObject->DeviceExtension;
@@ -221,17 +242,16 @@ IKsFilterFactory_fnInitialize(
     This->Header.Parent.KsDevice = &DeviceExtension->DeviceHeader->KsDevice;
     This->DeviceHeader = DeviceExtension->DeviceHeader;
 
+    /* initialize filter factory control mutex */
+    This->Header.ControlMutex = &This->ControlMutex;
+    KeInitializeMutex(This->Header.ControlMutex, 0);
+
     /* unused fields */
-    KeInitializeMutex(&This->Header.ControlMutex, 0);
     InitializeListHead(&This->Header.EventList);
     KeInitializeSpinLock(&This->Header.EventListLock);
 
-
     InitializeListHead(&This->SymbolicLinkList);
 
-    /* initialize filter factory control mutex */
-    KeInitializeMutex(&This->Header.ControlMutex, 0);
-
     /* does the device use a reference string */
     if (RefString || !Descriptor->ReferenceGuid)
     {
@@ -253,6 +273,8 @@ IKsFilterFactory_fnInitialize(
         FreeString = TRUE;
     }
 
+    DPRINT("IKsFilterFactory_fnInitialize CategoriesCount %u ReferenceString '%S'\n", Descriptor->CategoriesCount,ReferenceString.Buffer);
+
     /* now register the device interface */
     Status = KspRegisterDeviceInterfaces(DeviceExtension->DeviceHeader->KsDevice.PhysicalDeviceObject,
                                          Descriptor->CategoriesCount,
@@ -299,7 +321,7 @@ IKsFilterFactory_fnInitialize(
         if (This->FilterFactory.Bag)
         {
             /* initialize object bag */
-            KsDevice = (IKsDevice*)&DeviceExtension->DeviceHeader->lpVtblIKsDevice;
+            KsDevice = (IKsDevice*)&DeviceExtension->DeviceHeader->BasicHeader.OuterUnknown;
             KsDevice->lpVtbl->InitializeObjectBag(KsDevice, (PKSIOBJECT_BAG)This->FilterFactory.Bag, NULL);
         }
     }
@@ -338,6 +360,8 @@ KspCreateFilterFactory(
     IKsFilterFactory * Filter;
     NTSTATUS Status;
 
+    DPRINT("KsCreateFilterFactory\n");
+
     /* Lets allocate a filterfactory */
     This = AllocateItem(NonPagedPool, sizeof(IKsFilterFactoryImpl));
     if (!This)
@@ -348,10 +372,10 @@ KspCreateFilterFactory(
 
     /* initialize struct */
     This->ref = 1;
-    This->lpVtbl = &vt_IKsFilterFactoryVtbl;
+    This->Header.OuterUnknown = (PUNKNOWN)&vt_IKsFilterFactoryVtbl;
 
     /* map to com object */
-    Filter = (IKsFilterFactory*)&This->lpVtbl;
+    Filter = (IKsFilterFactory*)&This->Header.OuterUnknown;
 
     /* initialize filter */
     Status = Filter->lpVtbl->Initialize(Filter, DeviceObject, Descriptor, RefString, SecurityDescriptor, CreateItemFlags, SleepCallback, WakeCallback, FilterFactory);
@@ -363,6 +387,10 @@ KspCreateFilterFactory(
     }
 
     /* return result */
+    DPRINT("KsCreateFilterFactory %x\n", Status);
+    /* sanity check */
+    ASSERT(Status == STATUS_SUCCESS);
+
     return Status;
 }
 
@@ -396,8 +424,10 @@ KsFilterFactorySetDeviceClassesState(
     IN PKSFILTERFACTORY  FilterFactory,
     IN BOOLEAN  NewState)
 {
-    IKsFilterFactory * Factory = (IKsFilterFactory*)CONTAINING_RECORD(FilterFactory, IKsFilterFactoryImpl, FilterFactory);
+    IKsFilterFactory * Factory;
+    IKsFilterFactoryImpl * This = (IKsFilterFactoryImpl*)CONTAINING_RECORD(FilterFactory, IKsFilterFactoryImpl, FilterFactory);
 
+    Factory = (IKsFilterFactory*)&This->Header.OuterUnknown;
     return Factory->lpVtbl->SetDeviceClassesState(Factory, NewState);
 }
 
@@ -454,18 +484,261 @@ KsFilterFactoryAddCreateItem(
     return KsAllocateObjectCreateItem((KSDEVICE_HEADER)Factory->DeviceHeader, &CreateItem, TRUE, IKsFilterFactory_ItemFreeCb);
 }
 
+ULONG
+KspCacheAddData(
+    PKSPCACHE_DESCRIPTOR Descriptor,
+    LPCVOID Data,
+    ULONG Length)
+{
+    ULONG Index;
+
+    for(Index = 0; Index < Descriptor->DataOffset; Index++)
+    {
+        if (RtlCompareMemory(Descriptor->DataCache, Data, Length) == Length)
+        {
+            if (Index + Length > Descriptor->DataOffset)
+            {
+                /* adjust used space */
+                Descriptor->DataOffset = Index + Length;
+                /* return absolute offset */
+                return Descriptor->DataLength + Index;
+            }
+        }
+    }
+
+    /* sanity check */
+    ASSERT(Descriptor->DataOffset + Length < Descriptor->DataLength);
+
+    /* copy to data blob */
+    RtlMoveMemory((Descriptor->DataCache + Descriptor->DataOffset), Data, Length);
+
+    /* backup offset */
+    Index = Descriptor->DataOffset;
+
+    /* adjust used space */
+    Descriptor->DataOffset += Length;
+
+    /* return absolute offset */
+    return Descriptor->DataLength + Index;
+}
+
 /*
     @implemented
 */
 KSDDKAPI
 NTSTATUS
 NTAPI
-KsFilterFactoryUpdateCacheData (
+KsFilterFactoryUpdateCacheData(
     IN PKSFILTERFACTORY  FilterFactory,
     IN const KSFILTER_DESCRIPTOR*  FilterDescriptor OPTIONAL)
 {
-    UNIMPLEMENTED
+    KSPCACHE_DESCRIPTOR Descriptor;
+    PKSPCACHE_FILTER_HEADER FilterHeader;
+    UNICODE_STRING FilterData = RTL_CONSTANT_STRING(L"FilterData");
+    PKSPCACHE_PIN_HEADER PinHeader;
+    ULONG Index, SubIndex;
+    PLIST_ENTRY Entry;
+    PSYMBOLIC_LINK_ENTRY SymEntry;
+    BOOLEAN Found;
+    HKEY hKey;
+    NTSTATUS Status = STATUS_SUCCESS;
+
+    IKsFilterFactoryImpl * Factory = (IKsFilterFactoryImpl*)CONTAINING_RECORD(FilterFactory, IKsFilterFactoryImpl, FilterFactory);
+
+    DPRINT("KsFilterFactoryUpdateCacheData %p\n", FilterDescriptor);
+
+    if (!FilterDescriptor)
+        FilterDescriptor = Factory->FilterFactory.FilterDescriptor;
+
+    ASSERT(FilterDescriptor);
+
+    /* initialize cache descriptor */
+    RtlZeroMemory(&Descriptor, sizeof(KSPCACHE_DESCRIPTOR));
+
+    /* calculate filter data size */
+    Descriptor.FilterLength = sizeof(KSPCACHE_FILTER_HEADER);
+
+    /* FIXME support variable size pin descriptors */
+    ASSERT(FilterDescriptor->PinDescriptorSize == sizeof(KSPIN_DESCRIPTOR_EX));
+
+    for(Index = 0; Index < FilterDescriptor->PinDescriptorsCount; Index++)
+    {
+        /* add filter descriptor */
+        Descriptor.FilterLength += sizeof(KSPCACHE_PIN_HEADER);
+
+        if (FilterDescriptor->PinDescriptors[Index].PinDescriptor.Category)
+        {
+            /* add extra ULONG for offset to category */
+            Descriptor.FilterLength += sizeof(ULONG);
+
+            /* add size for clsid */
+            Descriptor.DataLength += sizeof(CLSID);
+        }
+
+        /* add space for formats */
+        Descriptor.FilterLength += FilterDescriptor->PinDescriptors[Index].PinDescriptor.DataRangesCount * sizeof(KSPCACHE_DATARANGE);
+
+        /* add space for MajorFormat / MinorFormat */
+        Descriptor.DataLength += FilterDescriptor->PinDescriptors[Index].PinDescriptor.DataRangesCount * sizeof(CLSID) * 2;
+
+        /* add space for mediums */
+        Descriptor.FilterLength += FilterDescriptor->PinDescriptors[Index].PinDescriptor.MediumsCount * sizeof(ULONG);
+
+        /* add space for the data */
+        Descriptor.DataLength += FilterDescriptor->PinDescriptors[Index].PinDescriptor.MediumsCount * sizeof(KSPCACHE_MEDIUM);
+    }
+
+    /* now allocate the space */
+    Descriptor.FilterData = (PUCHAR)AllocateItem(NonPagedPool, Descriptor.DataLength + Descriptor.FilterLength);
+    if (!Descriptor.FilterData)
+    {
+        /* no memory */
+        return STATUS_INSUFFICIENT_RESOURCES;
+    }
 
-    return STATUS_NOT_IMPLEMENTED;
+    /* initialize data cache */
+    Descriptor.DataCache = (PUCHAR)((ULONG_PTR)Descriptor.FilterData + Descriptor.FilterLength);
+
+    /* setup filter header */
+    FilterHeader = (PKSPCACHE_FILTER_HEADER)Descriptor.FilterData;
+
+    FilterHeader->dwVersion = 2;
+    FilterHeader->dwMerit = MERIT_DO_NOT_USE;
+    FilterHeader->dwUnused = 0;
+    FilterHeader->dwPins = FilterDescriptor->PinDescriptorsCount;
+
+    Descriptor.FilterOffset = sizeof(KSPCACHE_FILTER_HEADER);
+
+    /* write pin headers */
+    for(Index = 0; Index < FilterDescriptor->PinDescriptorsCount; Index++)
+    {
+        /* get offset to pin */
+        PinHeader = (PKSPCACHE_PIN_HEADER)((ULONG_PTR)Descriptor.FilterData + Descriptor.FilterOffset);
+
+        /* write pin header */
+        PinHeader->Signature = 0x33697030 + Index;
+        PinHeader->Flags = 0;
+        PinHeader->Instances = FilterDescriptor->PinDescriptors[Index].InstancesPossible;
+        if (PinHeader->Instances > 1)
+            PinHeader->Flags |= REG_PINFLAG_B_MANY;
+
+
+        PinHeader->MediaTypes = FilterDescriptor->PinDescriptors[Index].PinDescriptor.DataRangesCount;
+        PinHeader->Mediums = FilterDescriptor->PinDescriptors[Index].PinDescriptor.MediumsCount;
+        PinHeader->Category = (FilterDescriptor->PinDescriptors[Index].PinDescriptor.Category ? TRUE : FALSE);
+
+        Descriptor.FilterOffset += sizeof(KSPCACHE_PIN_HEADER);
+
+        if (PinHeader->Category)
+        {
+            /* get category offset */
+            PULONG Category = (PULONG)(PinHeader + 1);
+
+            /* write category offset */
+            *Category = KspCacheAddData(&Descriptor, FilterDescriptor->PinDescriptors[Index].PinDescriptor.Category, sizeof(CLSID));
+
+            /* adjust offset */
+            Descriptor.FilterOffset += sizeof(ULONG);
+        }
+
+        /* add dataranges */
+        for(SubIndex = 0; SubIndex < FilterDescriptor->PinDescriptors[Index].PinDescriptor.DataRangesCount; SubIndex++)
+        {
+            /* get datarange offset */
+            PKSPCACHE_DATARANGE DataRange = (PKSPCACHE_DATARANGE)((ULONG_PTR)Descriptor.FilterData + Descriptor.FilterOffset);
+
+            /* initialize data range */
+            DataRange->Signature = 0x33797430 + SubIndex;
+            DataRange->dwUnused = 0;
+            DataRange->OffsetMajor = KspCacheAddData(&Descriptor, &FilterDescriptor->PinDescriptors[Index].PinDescriptor.DataRanges[SubIndex]->MajorFormat, sizeof(CLSID));
+            DataRange->OffsetMinor = KspCacheAddData(&Descriptor, &FilterDescriptor->PinDescriptors[Index].PinDescriptor.DataRanges[SubIndex]->SubFormat, sizeof(CLSID));
+
+            /* adjust offset */
+            Descriptor.FilterOffset += sizeof(KSPCACHE_DATARANGE);
+        }
+
+        /* add mediums */
+        for(SubIndex = 0; SubIndex < FilterDescriptor->PinDescriptors[Index].PinDescriptor.MediumsCount; SubIndex++)
+        {
+            KSPCACHE_MEDIUM Medium;
+            PULONG MediumOffset;
+
+            /* get pin medium offset */
+            MediumOffset = (PULONG)((ULONG_PTR)Descriptor.FilterData + Descriptor.FilterOffset);
+
+            /* copy medium guid */
+            RtlMoveMemory(&Medium.Medium, &FilterDescriptor->PinDescriptors[Index].PinDescriptor.Mediums[SubIndex].Set, sizeof(GUID));
+            Medium.dw1 = FilterDescriptor->PinDescriptors[Index].PinDescriptor.Mediums[SubIndex].Id; /* FIXME verify */
+            Medium.dw2 = 0;
+
+            *MediumOffset = KspCacheAddData(&Descriptor, &Medium, sizeof(KSPCACHE_MEDIUM));
+
+            /* adjust offset */
+            Descriptor.FilterOffset += sizeof(ULONG);
+        }
+    }
+
+    /* sanity checks */
+    ASSERT(Descriptor.FilterOffset == Descriptor.FilterLength);
+    ASSERT(Descriptor.DataOffset <= Descriptor.DataLength);
+
+
+    /* now go through all entries and update 'FilterData' key */
+    for(Index = 0; Index < FilterDescriptor->CategoriesCount; Index++)
+    {
+        /* get first entry */
+        Entry = Factory->SymbolicLinkList.Flink;
+
+        /* set status to not found */
+        Found = FALSE;
+        /* loop list until the the current category is found */
+        while(Entry != &Factory->SymbolicLinkList)
+        {
+            /* fetch symbolic link entry */
+            SymEntry = (PSYMBOLIC_LINK_ENTRY)CONTAINING_RECORD(Entry, SYMBOLIC_LINK_ENTRY, Entry);
+
+            if (IsEqualGUIDAligned(&SymEntry->DeviceInterfaceClass, &FilterDescriptor->Categories[Index]))
+            {
+                /* found category */
+                Found = TRUE;
+                break;
+            }
+
+            /* move to next entry */
+            Entry = Entry->Flink;
+        }
+
+        if (!Found)
+        {
+            /* filter category is not present */
+            Status = STATUS_INVALID_PARAMETER;
+            break;
+        }
+
+        /* now open device interface */
+        Status = IoOpenDeviceInterfaceRegistryKey(&SymEntry->SymbolicLink, KEY_WRITE, &hKey);
+        if (!NT_SUCCESS(Status))
+        {
+            /* failed to open interface key */
+            break;
+        }
+
+        /* update filterdata key */
+        Status = ZwSetValueKey(hKey, &FilterData, 0, REG_BINARY, Descriptor.FilterData, Descriptor.FilterLength + Descriptor.DataOffset);
+
+        /* close filterdata key */
+        ZwClose(hKey);
+
+        if (!NT_SUCCESS(Status))
+        {
+            /* failed to set key value */
+            break;
+        }
+    }
+    /* free filter data */
+    FreeItem(Descriptor.FilterData);
+
+    /* done */
+    return Status;
 }