[USB]
[reactos.git] / reactos / ntoskrnl / io / iomgr / deviface.c
index 391c681..9b40aab 100644 (file)
 
 static PWCHAR BaseKeyString = L"\\Registry\\Machine\\System\\CurrentControlSet\\Control\\DeviceClasses\\";
 
-/*++
- * @name IoOpenDeviceInterfaceRegistryKey
- * @unimplemented
- *
- * Provides a handle to the device's interface instance registry key.
- * Documented in WDK.
- *
- * @param SymbolicLinkName
- *        Pointer to a string which identifies the device interface instance
- *
- * @param DesiredAccess
- *        Desired ACCESS_MASK used to access the key (like KEY_READ,
- *        KEY_WRITE, etc)
- *
- * @param DeviceInterfaceKey
- *        If a call has been succesfull, a handle to the registry key
- *        will be stored there
- *
- * @return Three different NTSTATUS values in case of errors, and STATUS_SUCCESS
- *         otherwise (see WDK for details)
- *
- * @remarks Must be called at IRQL = PASSIVE_LEVEL in the context of a system thread
- *
- *--*/
+static
 NTSTATUS
-NTAPI
-IoOpenDeviceInterfaceRegistryKey(IN PUNICODE_STRING SymbolicLinkName,
-                                 IN ACCESS_MASK DesiredAccess,
-                                 OUT PHANDLE DeviceInterfaceKey)
+OpenRegistryHandlesFromSymbolicLink(IN PUNICODE_STRING SymbolicLinkName,
+                                    IN ACCESS_MASK DesiredAccess,
+                                    IN OPTIONAL PHANDLE GuidKey,
+                                    IN OPTIONAL PHANDLE DeviceKey,
+                                    IN OPTIONAL PHANDLE InstanceKey)
 {
-    WCHAR StrBuff[MAX_PATH], PathBuff[MAX_PATH];
-    PWCHAR Guid, RefString;
-    UNICODE_STRING DevParamU = RTL_CONSTANT_STRING(L"\\Device Parameters");
-    UNICODE_STRING PrefixU = RTL_CONSTANT_STRING(L"\\??\\");
-    UNICODE_STRING KeyPath, KeyName;
-    UNICODE_STRING MatchableGuid;
-    UNICODE_STRING GuidString;
-    HANDLE GuidKey, hInterfaceKey;
-    ULONG Index = 0;
-    PKEY_BASIC_INFORMATION KeyInformation;
-    ULONG KeyInformationLength;
     OBJECT_ATTRIBUTES ObjectAttributes;
+    WCHAR PathBuffer[MAX_PATH];
+    UNICODE_STRING BaseKeyU;
+    UNICODE_STRING GuidString, SubKeyName, ReferenceString;
+    PWCHAR StartPosition, EndPosition;
+    HANDLE ClassesKey;
+    PHANDLE GuidKeyRealP, DeviceKeyRealP, InstanceKeyRealP;
+    HANDLE GuidKeyReal, DeviceKeyReal, InstanceKeyReal;
     NTSTATUS Status;
-    ULONG RequiredLength;
 
-    swprintf(StrBuff, L"##?#%s", &SymbolicLinkName->Buffer[PrefixU.Length / sizeof(WCHAR)]);
+    SubKeyName.Buffer = NULL;
 
-    RefString = wcsstr(StrBuff, L"\\");
-    if (RefString)
-    {
-        RefString[0] = 0;
-    }
+    if (GuidKey != NULL)
+        GuidKeyRealP = GuidKey;
+    else
+        GuidKeyRealP = &GuidKeyReal;
 
-    RtlInitUnicodeString(&MatchableGuid, StrBuff);
+    if (DeviceKey != NULL)
+        DeviceKeyRealP = DeviceKey;
+    else
+        DeviceKeyRealP = &DeviceKeyReal;
 
-    Guid = wcsstr(StrBuff, L"{");
-    if (!Guid)
-        return STATUS_OBJECT_NAME_NOT_FOUND;
+    if (InstanceKey != NULL)
+        InstanceKeyRealP = InstanceKey;
+    else
+        InstanceKeyRealP = &InstanceKeyReal;
 
-    KeyPath.Buffer = PathBuff;
-    KeyPath.Length = 0;
-    KeyPath.MaximumLength = MAX_PATH * sizeof(WCHAR);
+    *GuidKeyRealP = INVALID_HANDLE_VALUE;
+    *DeviceKeyRealP = INVALID_HANDLE_VALUE;
+    *InstanceKeyRealP = INVALID_HANDLE_VALUE;
 
-    GuidString.Buffer = Guid;
-    GuidString.Length = GuidString.MaximumLength = 38 * sizeof(WCHAR);
+    BaseKeyU.Buffer = PathBuffer;
+    BaseKeyU.Length = 0;
+    BaseKeyU.MaximumLength = MAX_PATH * sizeof(WCHAR);
 
-    RtlAppendUnicodeToString(&KeyPath, BaseKeyString);
-    RtlAppendUnicodeStringToString(&KeyPath, &GuidString);
+    RtlAppendUnicodeToString(&BaseKeyU, BaseKeyString);
 
+    /* Open the DeviceClasses key */
     InitializeObjectAttributes(&ObjectAttributes,
-                               &KeyPath,
-                               OBJ_CASE_INSENSITIVE,
-                               0,
+                               &BaseKeyU,
+                               OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE,
+                               NULL,
                                NULL);
+    Status = ZwOpenKey(&ClassesKey,
+                       DesiredAccess | KEY_ENUMERATE_SUB_KEYS,
+                       &ObjectAttributes);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to open %wZ\n", &BaseKeyU);
+        goto cleanup;
+    }
+
+    StartPosition = wcschr(SymbolicLinkName->Buffer, L'{');
+    EndPosition = wcschr(SymbolicLinkName->Buffer, L'}');
+    if (!StartPosition || !EndPosition || StartPosition > EndPosition)
+    {
+        DPRINT1("Bad symbolic link: %wZ\n", SymbolicLinkName);
+        return STATUS_INVALID_PARAMETER_1;
+    }
+    GuidString.Buffer = StartPosition;
+    GuidString.MaximumLength = GuidString.Length = (USHORT)((ULONG_PTR)(EndPosition + 1) - (ULONG_PTR)StartPosition);
 
-    Status = ZwOpenKey(&GuidKey, KEY_CREATE_SUB_KEY, &ObjectAttributes);
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &GuidString,
+                               OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE,
+                               ClassesKey,
+                               NULL);
+    Status = ZwOpenKey(GuidKeyRealP,
+                       DesiredAccess | KEY_ENUMERATE_SUB_KEYS,
+                       &ObjectAttributes);
+    ZwClose(ClassesKey);
     if (!NT_SUCCESS(Status))
-        return Status;
+    {
+        DPRINT1("Failed to open %wZ%wZ (%x)\n", &BaseKeyU, &GuidString, Status);
+        goto cleanup;
+    }
 
-    while (TRUE)
+    SubKeyName.MaximumLength = SymbolicLinkName->Length + sizeof(WCHAR);
+    SubKeyName.Length = 0;
+    SubKeyName.Buffer = ExAllocatePool(PagedPool, SubKeyName.MaximumLength);
+    if (!SubKeyName.Buffer)
     {
-        Status = ZwEnumerateKey(GuidKey,
-                                Index,
-                                KeyBasicInformation,
-                                NULL,
-                                0,
-                                &RequiredLength);
-        if (Status == STATUS_NO_MORE_ENTRIES)
-            break;
-        else if (Status == STATUS_BUFFER_TOO_SMALL)
-        {
-            KeyInformationLength = RequiredLength;
-            KeyInformation = ExAllocatePool(PagedPool, KeyInformationLength);
-            if (!KeyInformation)
-            {
-                ZwClose(GuidKey);
-                return STATUS_INSUFFICIENT_RESOURCES;
-            }
+        Status = STATUS_INSUFFICIENT_RESOURCES;
+        goto cleanup;
+    }
 
-            Status = ZwEnumerateKey(GuidKey,
-                                    Index,
-                                    KeyBasicInformation,
-                                    KeyInformation,
-                                    KeyInformationLength,
-                                    &RequiredLength);
-        }
-        else
-        {
-            ZwClose(GuidKey);
-            return STATUS_OBJECT_PATH_NOT_FOUND;
-        }
-        Index++;
+    RtlAppendUnicodeStringToString(&SubKeyName,
+                                   SymbolicLinkName);
 
-        if (!NT_SUCCESS(Status))
-        {
-            ZwClose(GuidKey);
-            return Status;
-        }
+    SubKeyName.Buffer[SubKeyName.Length / sizeof(WCHAR)] = UNICODE_NULL;
 
-        KeyName.Length = KeyName.MaximumLength = KeyInformation->NameLength;
-        KeyName.Buffer = KeyInformation->Name;
+    SubKeyName.Buffer[0] = L'#';
+    SubKeyName.Buffer[1] = L'#';
+    SubKeyName.Buffer[2] = L'?';
+    SubKeyName.Buffer[3] = L'#';
 
-        if (!RtlEqualUnicodeString(&KeyName, &MatchableGuid, TRUE))
-        {
-            ExFreePool(KeyInformation);
-            continue;
-        }
+    ReferenceString.Buffer = wcsrchr(SubKeyName.Buffer, '\\');
+    if (ReferenceString.Buffer != NULL)
+    {
+        ReferenceString.Buffer[0] = L'#';
 
-        KeyPath.Length = 0;
-        RtlAppendUnicodeStringToString(&KeyPath, &KeyName);
-        RtlAppendUnicodeToString(&KeyPath, L"\\");
+        SubKeyName.Length = (USHORT)((ULONG_PTR)(ReferenceString.Buffer) - (ULONG_PTR)SubKeyName.Buffer);
+        ReferenceString.Length = SymbolicLinkName->Length - SubKeyName.Length;
+    }
+    else
+    {
+        RtlInitUnicodeString(&ReferenceString, L"#");
+    }
 
-        /* check for presence of a reference string */
-        if (RefString)
-        {
-            /* append reference string */
-            RefString[0] = L'#';
-            RtlInitUnicodeString(&KeyName, RefString);
-        }
-        else
-        {
-            /* no reference string */
-            RtlInitUnicodeString(&KeyName, L"#");
-        }
-        RtlAppendUnicodeStringToString(&KeyPath, &KeyName);
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &SubKeyName,
+                               OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE,
+                               *GuidKeyRealP,
+                               NULL);
+    Status = ZwOpenKey(DeviceKeyRealP,
+                       DesiredAccess | KEY_ENUMERATE_SUB_KEYS,
+                       &ObjectAttributes);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to open %wZ%wZ\\%wZ\n", &BaseKeyU, &GuidString, &SubKeyName);
+        goto cleanup;
+    }
+
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &ReferenceString,
+                               OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE,
+                               *DeviceKeyRealP,
+                               NULL);
+    Status = ZwOpenKey(InstanceKeyRealP,
+                       DesiredAccess,
+                       &ObjectAttributes);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to open %wZ%wZ\\%wZ%\\%wZ (%x)\n", &BaseKeyU, &GuidString, &SubKeyName, &ReferenceString, Status);
+        goto cleanup;
+    }
 
-        /* initialize reference string attributes */
-        InitializeObjectAttributes(&ObjectAttributes,
-                                   &KeyPath,
-                                   OBJ_CASE_INSENSITIVE,
-                                   GuidKey,
-                                   NULL);
+    Status = STATUS_SUCCESS;
 
-        /* now open device interface key */
-        Status = ZwOpenKey(&hInterfaceKey, KEY_CREATE_SUB_KEY, &ObjectAttributes);
+cleanup:
+    if (SubKeyName.Buffer != NULL)
+       ExFreePool(SubKeyName.Buffer);
 
-        if (NT_SUCCESS(Status))
-        {
-            /* check if it provides a DeviceParameters key */
-            InitializeObjectAttributes(&ObjectAttributes, &DevParamU, OBJ_CASE_INSENSITIVE, hInterfaceKey, NULL);
-
-             Status = ZwCreateKey(DeviceInterfaceKey, DesiredAccess, &ObjectAttributes, 0, NULL, REG_OPTION_NON_VOLATILE, NULL);
-
-             if (NT_SUCCESS(Status))
-             {
-                 /* DeviceParameters key present */
-                 ZwClose(hInterfaceKey);
-             }
-             else
-             {
-                 /* fall back to device interface */
-                 *DeviceInterfaceKey = hInterfaceKey;
-                 Status = STATUS_SUCCESS;
-             }
-        }
+    if (NT_SUCCESS(Status))
+    {
+       if (!GuidKey)
+          ZwClose(*GuidKeyRealP);
 
-        /* close class key */
-        ZwClose(GuidKey);
-        ExFreePool(KeyInformation);
-        return Status;
+       if (!DeviceKey)
+          ZwClose(*DeviceKeyRealP);
+
+       if (!InstanceKey)
+          ZwClose(*InstanceKeyRealP);
     }
+    else
+    {
+       if (*GuidKeyRealP != INVALID_HANDLE_VALUE)
+          ZwClose(*GuidKeyRealP);
+
+       if (*DeviceKeyRealP != INVALID_HANDLE_VALUE)
+          ZwClose(*DeviceKeyRealP);
 
-    return STATUS_OBJECT_PATH_NOT_FOUND;
+       if (*InstanceKeyRealP != INVALID_HANDLE_VALUE)
+          ZwClose(*InstanceKeyRealP);
+    }
+
+    return Status;
+}
+/*++
+ * @name IoOpenDeviceInterfaceRegistryKey
+ * @unimplemented
+ *
+ * Provides a handle to the device's interface instance registry key.
+ * Documented in WDK.
+ *
+ * @param SymbolicLinkName
+ *        Pointer to a string which identifies the device interface instance
+ *
+ * @param DesiredAccess
+ *        Desired ACCESS_MASK used to access the key (like KEY_READ,
+ *        KEY_WRITE, etc)
+ *
+ * @param DeviceInterfaceKey
+ *        If a call has been succesfull, a handle to the registry key
+ *        will be stored there
+ *
+ * @return Three different NTSTATUS values in case of errors, and STATUS_SUCCESS
+ *         otherwise (see WDK for details)
+ *
+ * @remarks Must be called at IRQL = PASSIVE_LEVEL in the context of a system thread
+ *
+ *--*/
+NTSTATUS
+NTAPI
+IoOpenDeviceInterfaceRegistryKey(IN PUNICODE_STRING SymbolicLinkName,
+                                 IN ACCESS_MASK DesiredAccess,
+                                 OUT PHANDLE DeviceInterfaceKey)
+{
+   HANDLE InstanceKey, DeviceParametersKey;
+   NTSTATUS Status;
+   OBJECT_ATTRIBUTES ObjectAttributes;
+   UNICODE_STRING DeviceParametersU = RTL_CONSTANT_STRING(L"Device Parameters");
+
+   Status = OpenRegistryHandlesFromSymbolicLink(SymbolicLinkName,
+                                                KEY_CREATE_SUB_KEY,
+                                                NULL,
+                                                NULL,
+                                                &InstanceKey);
+   if (!NT_SUCCESS(Status))
+       return Status;
+
+   InitializeObjectAttributes(&ObjectAttributes,
+                              &DeviceParametersU,
+                              OBJ_CASE_INSENSITIVE | OBJ_OPENIF,
+                              InstanceKey,
+                              NULL);
+   Status = ZwCreateKey(&DeviceParametersKey,
+                        DesiredAccess,
+                        &ObjectAttributes,
+                        0,
+                        NULL,
+                        REG_OPTION_NON_VOLATILE,
+                        NULL);
+   ZwClose(InstanceKey);
+
+   if (NT_SUCCESS(Status))
+       *DeviceInterfaceKey = DeviceParametersKey;
+
+   return Status;
 }
 
 /*++
@@ -278,7 +332,7 @@ IopOpenInterfaceKey(IN CONST GUID *InterfaceClassGuid,
     }
 
     KeyName.Length = 0;
-    KeyName.MaximumLength = LocalMachine.Length + (wcslen(REGSTR_PATH_DEVICE_CLASSES) + 1) * sizeof(WCHAR) + GuidString.Length;
+    KeyName.MaximumLength = LocalMachine.Length + ((USHORT)wcslen(REGSTR_PATH_DEVICE_CLASSES) + 1) * sizeof(WCHAR) + GuidString.Length;
     KeyName.Buffer = ExAllocatePool(PagedPool, KeyName.MaximumLength);
     if (!KeyName.Buffer)
     {
@@ -395,10 +449,11 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
     PKEY_BASIC_INFORMATION DeviceBi = NULL;
     PKEY_BASIC_INFORMATION ReferenceBi = NULL;
     PKEY_VALUE_PARTIAL_INFORMATION bip = NULL;
+    PKEY_VALUE_PARTIAL_INFORMATION PartialInfo;
     UNICODE_STRING KeyName;
     OBJECT_ATTRIBUTES ObjectAttributes;
     BOOLEAN FoundRightPDO = FALSE;
-    ULONG i = 0, j, Size;
+    ULONG i = 0, j, Size, NeededLength, ActualLength, LinkedValue;
     UNICODE_STRING ReturnBuffer = { 0, 0, NULL };
     NTSTATUS Status;
 
@@ -552,7 +607,6 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                 /* We have to check if the interface is enabled, by
                 * reading the Linked value in the Control subkey
                 */
-#if 0
                 InitializeObjectAttributes(
                     &ObjectAttributes,
                     &Control,
@@ -572,17 +626,63 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                 }
                 else if (!NT_SUCCESS(Status))
                 {
-                    DPRINT("ZwOpenKey() failed with status 0x%08lx\n", Status);
+                    DPRINT1("ZwOpenKey() failed with status 0x%08lx\n", Status);
+                    goto cleanup;
+                }
+
+                RtlInitUnicodeString(&KeyName, L"Linked");
+                Status = ZwQueryValueKey(ControlKey,
+                                         &KeyName,
+                                         KeyValuePartialInformation,
+                                         NULL,
+                                         0,
+                                         &NeededLength);
+                if (Status == STATUS_BUFFER_TOO_SMALL)
+                {
+                    ActualLength = NeededLength;
+                    PartialInfo = ExAllocatePool(NonPagedPool, ActualLength);
+                    if (!PartialInfo)
+                    {
+                        Status = STATUS_INSUFFICIENT_RESOURCES;
+                        goto cleanup;
+                    }
+
+                    Status = ZwQueryValueKey(ControlKey,
+                                             &KeyName,
+                                             KeyValuePartialInformation,
+                                             PartialInfo,
+                                             ActualLength,
+                                             &NeededLength);
+                    if (!NT_SUCCESS(Status))
+                    {
+                        DPRINT1("ZwQueryValueKey #2 failed (%x)\n", Status);
+                        ExFreePool(PartialInfo);
+                        goto cleanup;
+                    }
+
+                    if (PartialInfo->Type != REG_DWORD || PartialInfo->DataLength != sizeof(ULONG))
+                    {
+                        DPRINT1("Bad registry read\n");
+                        ExFreePool(PartialInfo);
+                        goto cleanup;
+                    }
+
+                    RtlCopyMemory(&LinkedValue,
+                                  PartialInfo->Data,
+                                  PartialInfo->DataLength);
+
+                    ExFreePool(PartialInfo);
+                    if (LinkedValue == 0)
+                    {
+                        /* This interface isn't active */
+                        goto NextReferenceString;
+                    }
+                }
+                else
+                {
+                    DPRINT1("ZwQueryValueKey #1 failed (%x)\n", Status);
                     goto cleanup;
                 }
-#endif
-                /* FIXME: Read the Linked value
-                * If it doesn't exist => ERROR
-                * If it is not a REG_DWORD or Size != sizeof(ULONG) => ERROR
-                * If its value is 0, go to NextReferenceString
-                * At the moment, do as if it is active...
-                */
-                DPRINT1("Checking if device is enabled is not implemented yet!\n");
             }
 
             /* Read the SymbolicLink string and add it into SymbolicLinkList */
@@ -631,17 +731,13 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
             }
             KeyName.Length = KeyName.MaximumLength = (USHORT)bip->DataLength - 4 * sizeof(WCHAR);
             KeyName.Buffer = &((PWSTR)bip->Data)[4];
-            if (KeyName.Length && KeyName.Buffer[KeyName.Length / sizeof(WCHAR)] == UNICODE_NULL)
-            {
-                /* Remove trailing NULL */
-                KeyName.Length -= sizeof(WCHAR);
-            }
 
             /* Add new symbolic link to symbolic link list */
             if (ReturnBuffer.Length + KeyName.Length + sizeof(WCHAR) > ReturnBuffer.MaximumLength)
             {
                 PWSTR NewBuffer;
-                ReturnBuffer.MaximumLength = max(ReturnBuffer.MaximumLength * 2, ReturnBuffer.Length + KeyName.Length + 2 * sizeof(WCHAR));
+                ReturnBuffer.MaximumLength = (USHORT)max(ReturnBuffer.MaximumLength * 2,
+                                                 ReturnBuffer.Length + KeyName.Length + 2 * sizeof(WCHAR));
                 NewBuffer = ExAllocatePool(PagedPool, ReturnBuffer.MaximumLength);
                 if (!NewBuffer)
                 {
@@ -662,15 +758,16 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                 goto cleanup;
             }
             /* RtlAppendUnicodeStringToString added a NULL at the end of the
-            * destination string, but didn't increase the Length field.
-            * Do it for it.
-            */
+             * destination string, but didn't increase the Length field.
+             * Do it for it.
+             */
             ReturnBuffer.Length += sizeof(WCHAR);
 
 NextReferenceString:
             ExFreePool(ReferenceBi);
             ReferenceBi = NULL;
-            ExFreePool(bip);
+            if (bip)
+                ExFreePool(bip);
             bip = NULL;
             if (ReferenceKey != INVALID_HANDLE_VALUE)
             {
@@ -707,8 +804,11 @@ NextReferenceString:
             Status = STATUS_INSUFFICIENT_RESOURCES;
             goto cleanup;
         }
-        RtlCopyMemory(NewBuffer, ReturnBuffer.Buffer, ReturnBuffer.Length);
-        ExFreePool(ReturnBuffer.Buffer);
+        if (ReturnBuffer.Buffer)
+        {
+            RtlCopyMemory(NewBuffer, ReturnBuffer.Buffer, ReturnBuffer.Length);
+            ExFreePool(ReturnBuffer.Buffer);
+        }
         ReturnBuffer.Buffer = NewBuffer;
     }
     ReturnBuffer.Buffer[ReturnBuffer.Length / sizeof(WCHAR)] = UNICODE_NULL;
@@ -787,7 +887,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
     ULONG StartIndex;
     OBJECT_ATTRIBUTES ObjectAttributes;
     ULONG i;
-    NTSTATUS Status;
+    NTSTATUS Status, SymLinkStatus;
     PEXTENDED_DEVOBJ_EXTENSION DeviceObjectExtension;
 
     ASSERT_IRQL_EQUAL(PASSIVE_LEVEL);
@@ -847,7 +947,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
     /* Create base key name for this interface: HKLM\SYSTEM\CurrentControlSet\Control\DeviceClasses\{GUID} */
     ASSERT(((PEXTENDED_DEVOBJ_EXTENSION)PhysicalDeviceObject->DeviceObjectExtension)->DeviceNode);
     InstancePath = &((PEXTENDED_DEVOBJ_EXTENSION)PhysicalDeviceObject->DeviceObjectExtension)->DeviceNode->InstancePath;
-    BaseKeyName.Length = wcslen(BaseKeyString) * sizeof(WCHAR);
+    BaseKeyName.Length = (USHORT)wcslen(BaseKeyString) * sizeof(WCHAR);
     BaseKeyName.MaximumLength = BaseKeyName.Length
         + GuidString.Length;
     BaseKeyName.Buffer = ExAllocatePool(
@@ -1041,10 +1141,20 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
 
     /* Create symbolic link */
     DPRINT("IoRegisterDeviceInterface(): creating symbolic link %wZ -> %wZ\n", SymbolicLinkName, &PdoNameInfo->Name);
-    Status = IoCreateSymbolicLink(SymbolicLinkName, &PdoNameInfo->Name);
-    if (!NT_SUCCESS(Status) && ReferenceString == NULL)
+    SymLinkStatus = IoCreateSymbolicLink(SymbolicLinkName, &PdoNameInfo->Name);
+
+    /* If the symbolic link already exists, return an informational success status */
+    if (SymLinkStatus == STATUS_OBJECT_NAME_COLLISION)
+    {
+        /* HACK: Delete the existing symbolic link and update it to the new PDO name */
+        IoDeleteSymbolicLink(SymbolicLinkName);
+        IoCreateSymbolicLink(SymbolicLinkName, &PdoNameInfo->Name);
+        SymLinkStatus = STATUS_OBJECT_NAME_EXISTS;
+    }
+
+    if (!NT_SUCCESS(SymLinkStatus))
     {
-        DPRINT1("IoCreateSymbolicLink() failed with status 0x%08lx\n", Status);
+        DPRINT1("IoCreateSymbolicLink() failed with status 0x%08lx\n", SymLinkStatus);
         ZwClose(SubKey);
         ZwClose(InterfaceKey);
         ZwClose(ClassKey);
@@ -1052,7 +1162,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
         ExFreePool(InterfaceKeyName.Buffer);
         ExFreePool(BaseKeyName.Buffer);
         ExFreePool(SymbolicLinkName->Buffer);
-        return Status;
+        return SymLinkStatus;
     }
 
     if (ReferenceString && ReferenceString->Length)
@@ -1088,7 +1198,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
     ExFreePool(InterfaceKeyName.Buffer);
     ExFreePool(BaseKeyName.Buffer);
 
-    return Status;
+    return NT_SUCCESS(Status) ? SymLinkStatus : Status;
 }
 
 /*++
@@ -1124,7 +1234,11 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
     PWCHAR EndPosition;
     NTSTATUS Status;
     LPCGUID EventGuid;
-
+    HANDLE InstanceHandle, ControlHandle;
+    UNICODE_STRING KeyName;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    ULONG LinkedValue;
+    GUID DeviceGuid;
 
     if (SymbolicLinkName == NULL)
         return STATUS_INVALID_PARAMETER_1;
@@ -1146,6 +1260,51 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
     SymLink.Buffer = SymbolicLinkName->Buffer;
     SymLink.MaximumLength = SymLink.Length = (USHORT)((ULONG_PTR)(EndPosition + 1) - (ULONG_PTR)SymLink.Buffer);
     DPRINT("IoSetDeviceInterfaceState('%wZ', %d)\n", SymbolicLinkName, Enable);
+
+    Status = OpenRegistryHandlesFromSymbolicLink(SymbolicLinkName,
+                                                 KEY_CREATE_SUB_KEY,
+                                                 NULL,
+                                                 NULL,
+                                                 &InstanceHandle);
+    if (!NT_SUCCESS(Status))
+        return Status;
+
+    RtlInitUnicodeString(&KeyName, L"Control");
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &KeyName,
+                               OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE,
+                               InstanceHandle,
+                               NULL);
+    Status = ZwCreateKey(&ControlHandle,
+                         KEY_SET_VALUE,
+                         &ObjectAttributes,
+                         0,
+                         NULL,
+                         REG_OPTION_VOLATILE,
+                         NULL);
+    ZwClose(InstanceHandle);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to create the Control subkey\n");
+        return Status;
+    }
+
+    LinkedValue = (Enable ? 1 : 0);
+
+    RtlInitUnicodeString(&KeyName, L"Linked");
+    Status = ZwSetValueKey(ControlHandle,
+                           &KeyName,
+                           0,
+                           REG_DWORD,
+                           &LinkedValue,
+                           sizeof(ULONG));
+    ZwClose(ControlHandle);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("Failed to write the Linked value\n");
+        return Status;
+    }
+
     /* Get pointer to the PDO */
     Status = IoGetDeviceObjectPointer(
         &SymLink,
@@ -1158,12 +1317,19 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
         return Status;
     }
 
+    Status = RtlGUIDFromString(&GuidString, &DeviceGuid);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("RtlGUIDFromString() failed with status 0x%08lx\n", Status);
+        return Status;
+    }
+
     EventGuid = Enable ? &GUID_DEVICE_INTERFACE_ARRIVAL : &GUID_DEVICE_INTERFACE_REMOVAL;
     IopNotifyPlugPlayNotification(
         PhysicalDeviceObject,
         EventCategoryDeviceInterfaceChange,
         EventGuid,
-        &GuidString,
+        &DeviceGuid,
         (PVOID)SymbolicLinkName);
 
     ObDereferenceObject(FileObject);