[NTOSKRNL]
[reactos.git] / reactos / ntoskrnl / io / iomgr / deviface.c
index c2dac1d..59eddeb 100644 (file)
@@ -50,7 +50,159 @@ IoOpenDeviceInterfaceRegistryKey(IN PUNICODE_STRING SymbolicLinkName,
                                  IN ACCESS_MASK DesiredAccess,
                                  OUT PHANDLE DeviceInterfaceKey)
 {
-    return STATUS_NOT_IMPLEMENTED;
+    WCHAR StrBuff[MAX_PATH], PathBuff[MAX_PATH];
+    PWCHAR Guid, RefString;
+    UNICODE_STRING DevParamU = RTL_CONSTANT_STRING(L"\\Device Parameters");
+    UNICODE_STRING PrefixU = RTL_CONSTANT_STRING(L"\\??\\");
+    UNICODE_STRING KeyPath, KeyName;
+    UNICODE_STRING MatchableGuid;
+    UNICODE_STRING GuidString;
+    HANDLE GuidKey, hInterfaceKey;
+    ULONG Index = 0;
+    PKEY_BASIC_INFORMATION KeyInformation;
+    ULONG KeyInformationLength;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    NTSTATUS Status;
+    ULONG RequiredLength;
+
+    swprintf(StrBuff, L"##?#%s", &SymbolicLinkName->Buffer[PrefixU.Length / sizeof(WCHAR)]);
+
+    RefString = wcsstr(StrBuff, L"\\");
+    if (RefString)
+    {
+        RefString[0] = 0;
+    }
+
+    RtlInitUnicodeString(&MatchableGuid, StrBuff);
+
+    Guid = wcsstr(StrBuff, L"{");
+    if (!Guid)
+        return STATUS_OBJECT_NAME_NOT_FOUND;
+
+    KeyPath.Buffer = PathBuff;
+    KeyPath.Length = 0;
+    KeyPath.MaximumLength = MAX_PATH * sizeof(WCHAR);
+
+    GuidString.Buffer = Guid;
+    GuidString.Length = GuidString.MaximumLength = 38 * sizeof(WCHAR);
+
+    RtlAppendUnicodeToString(&KeyPath, BaseKeyString);
+    RtlAppendUnicodeStringToString(&KeyPath, &GuidString);
+
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &KeyPath,
+                               OBJ_CASE_INSENSITIVE,
+                               0,
+                               NULL);
+
+    Status = ZwOpenKey(&GuidKey, KEY_CREATE_SUB_KEY, &ObjectAttributes);
+    if (!NT_SUCCESS(Status))
+        return Status;
+
+    while (TRUE)
+    {
+        Status = ZwEnumerateKey(GuidKey,
+                                Index,
+                                KeyBasicInformation,
+                                NULL,
+                                0,
+                                &RequiredLength);
+        if (Status == STATUS_NO_MORE_ENTRIES)
+            break;
+        else if (Status == STATUS_BUFFER_TOO_SMALL)
+        {
+            KeyInformationLength = RequiredLength;
+            KeyInformation = ExAllocatePool(PagedPool, KeyInformationLength);
+            if (!KeyInformation)
+            {
+                ZwClose(GuidKey);
+                return STATUS_INSUFFICIENT_RESOURCES;
+            }
+
+            Status = ZwEnumerateKey(GuidKey,
+                                    Index,
+                                    KeyBasicInformation,
+                                    KeyInformation,
+                                    KeyInformationLength,
+                                    &RequiredLength);
+        }
+        else
+        {
+            ZwClose(GuidKey);
+            return STATUS_OBJECT_PATH_NOT_FOUND;
+        }
+        Index++;
+
+        if (!NT_SUCCESS(Status))
+        {
+            ZwClose(GuidKey);
+            return Status;
+        }
+
+        KeyName.Length = KeyName.MaximumLength = KeyInformation->NameLength;
+        KeyName.Buffer = KeyInformation->Name;
+
+        if (!RtlEqualUnicodeString(&KeyName, &MatchableGuid, TRUE))
+        {
+            ExFreePool(KeyInformation);
+            continue;
+        }
+
+        KeyPath.Length = 0;
+        RtlAppendUnicodeStringToString(&KeyPath, &KeyName);
+        RtlAppendUnicodeToString(&KeyPath, L"\\");
+
+        /* check for presence of a reference string */
+        if (RefString)
+        {
+            /* append reference string */
+            RefString[0] = L'#';
+            RtlInitUnicodeString(&KeyName, RefString);
+        }
+        else
+        {
+            /* no reference string */
+            RtlInitUnicodeString(&KeyName, L"#");
+        }
+        RtlAppendUnicodeStringToString(&KeyPath, &KeyName);
+
+        /* initialize reference string attributes */
+        InitializeObjectAttributes(&ObjectAttributes,
+                                   &KeyPath,
+                                   OBJ_CASE_INSENSITIVE,
+                                   GuidKey,
+                                   NULL);
+
+        /* now open device interface key */
+        Status = ZwOpenKey(&hInterfaceKey, KEY_CREATE_SUB_KEY, &ObjectAttributes);
+
+        if (NT_SUCCESS(Status))
+        {
+            /* check if it provides a DeviceParameters key */
+            InitializeObjectAttributes(&ObjectAttributes, &DevParamU, OBJ_CASE_INSENSITIVE, hInterfaceKey, NULL);
+
+             Status = ZwCreateKey(DeviceInterfaceKey, DesiredAccess, &ObjectAttributes, 0, NULL, REG_OPTION_NON_VOLATILE, NULL);
+
+             if (NT_SUCCESS(Status))
+             {
+                 /* DeviceParameters key present */
+                 ZwClose(hInterfaceKey);
+             }
+             else
+             {
+                 /* fall back to device interface */
+                 *DeviceInterfaceKey = hInterfaceKey;
+                 Status = STATUS_SUCCESS;
+             }
+        }
+
+        /* close class key */
+        ZwClose(GuidKey);
+        ExFreePool(KeyInformation);
+        return Status;
+    }
+
+    return STATUS_OBJECT_PATH_NOT_FOUND;
 }
 
 /*++
@@ -247,7 +399,7 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
     OBJECT_ATTRIBUTES ObjectAttributes;
     BOOLEAN FoundRightPDO = FALSE;
     ULONG i = 0, j, Size;
-    UNICODE_STRING ReturnBuffer = {0,};
+    UNICODE_STRING ReturnBuffer = { 0, 0, NULL };
     NTSTATUS Status;
 
     PAGED_CODE();
@@ -400,6 +552,7 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                 /* We have to check if the interface is enabled, by
                 * reading the Linked value in the Control subkey
                 */
+#if 0
                 InitializeObjectAttributes(
                     &ObjectAttributes,
                     &Control,
@@ -422,6 +575,7 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                     DPRINT("ZwOpenKey() failed with status 0x%08lx\n", Status);
                     goto cleanup;
                 }
+#endif
                 /* FIXME: Read the Linked value
                 * If it doesn't exist => ERROR
                 * If it is not a REG_DWORD or Size != sizeof(ULONG) => ERROR
@@ -496,7 +650,8 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                     goto cleanup;
                 }
                 RtlCopyMemory(NewBuffer, ReturnBuffer.Buffer, ReturnBuffer.Length);
-                ExFreePool(ReturnBuffer.Buffer);
+                if (ReturnBuffer.Buffer)
+                    ExFreePool(ReturnBuffer.Buffer);
                 ReturnBuffer.Buffer = NewBuffer;
             }
             DPRINT("Adding symbolic link %wZ\n", &KeyName);
@@ -506,12 +661,6 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                 DPRINT("RtlAppendUnicodeStringToString() failed with status 0x%08lx\n", Status);
                 goto cleanup;
             }
-            /* RtlAppendUnicodeStringToString added a NULL at the end of the
-            * destination string, but didn't increase the Length field.
-            * Do it for it.
-            */
-            ReturnBuffer.Length += sizeof(WCHAR);
-
 NextReferenceString:
             ExFreePool(ReferenceBi);
             ReferenceBi = NULL;
@@ -561,7 +710,7 @@ NextReferenceString:
     Status = STATUS_SUCCESS;
 
 cleanup:
-    if (!NT_SUCCESS(Status))
+    if (!NT_SUCCESS(Status) && ReturnBuffer.Buffer)
         ExFreePool(ReturnBuffer.Buffer);
     if (InterfaceKey != INVALID_HANDLE_VALUE)
         ZwClose(InterfaceKey);
@@ -571,9 +720,12 @@ cleanup:
         ZwClose(ReferenceKey);
     if (ControlKey != INVALID_HANDLE_VALUE)
         ZwClose(ControlKey);
-    ExFreePool(DeviceBi);
-    ExFreePool(ReferenceBi);
-    ExFreePool(bip);
+    if (DeviceBi)
+        ExFreePool(DeviceBi);
+    if (ReferenceBi)
+        ExFreePool(ReferenceBi);
+    if (bip)
+        ExFreePool(bip);
     return Status;
 }
 
@@ -879,19 +1031,14 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
     }
     RtlAppendUnicodeToString(SymbolicLinkName, L"#");
     RtlAppendUnicodeStringToString(SymbolicLinkName, &GuidString);
-    if (ReferenceString && ReferenceString->Length)
-    {
-        RtlAppendUnicodeToString(SymbolicLinkName, L"\\");
-        RtlAppendUnicodeStringToString(SymbolicLinkName, ReferenceString);
-    }
     SymbolicLinkName->Buffer[SymbolicLinkName->Length/sizeof(WCHAR)] = L'\0';
 
     /* Create symbolic link */
     DPRINT("IoRegisterDeviceInterface(): creating symbolic link %wZ -> %wZ\n", SymbolicLinkName, &PdoNameInfo->Name);
     Status = IoCreateSymbolicLink(SymbolicLinkName, &PdoNameInfo->Name);
-    if (!NT_SUCCESS(Status))
+    if (!NT_SUCCESS(Status) && ReferenceString == NULL)
     {
-        DPRINT("IoCreateSymbolicLink() failed with status 0x%08lx\n", Status);
+        DPRINT1("IoCreateSymbolicLink() failed with status 0x%08lx\n", Status);
         ZwClose(SubKey);
         ZwClose(InterfaceKey);
         ZwClose(ClassKey);
@@ -902,6 +1049,13 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
         return Status;
     }
 
+    if (ReferenceString && ReferenceString->Length)
+    {
+        RtlAppendUnicodeToString(SymbolicLinkName, L"\\");
+        RtlAppendUnicodeStringToString(SymbolicLinkName, ReferenceString);
+    }
+    SymbolicLinkName->Buffer[SymbolicLinkName->Length/sizeof(WCHAR)] = L'\0';
+
     /* Write symbolic link name in registry */
     SymbolicLinkName->Buffer[1] = '\\';
     Status = ZwSetValueKey(
@@ -913,10 +1067,13 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
         SymbolicLinkName->Length);
     if (!NT_SUCCESS(Status))
     {
-        DPRINT("ZwSetValueKey() failed with status 0x%08lx\n", Status);
+        DPRINT1("ZwSetValueKey() failed with status 0x%08lx\n", Status);
         ExFreePool(SymbolicLinkName->Buffer);
     }
-    SymbolicLinkName->Buffer[1] = '?';
+    else
+    {
+        SymbolicLinkName->Buffer[1] = '?';
+    }
 
     ZwClose(SubKey);
     ZwClose(InterfaceKey);
@@ -956,11 +1113,13 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
     PDEVICE_OBJECT PhysicalDeviceObject;
     PFILE_OBJECT FileObject;
     UNICODE_STRING GuidString;
+    UNICODE_STRING SymLink;
     PWCHAR StartPosition;
     PWCHAR EndPosition;
     NTSTATUS Status;
     LPCGUID EventGuid;
 
+
     if (SymbolicLinkName == NULL)
         return STATUS_INVALID_PARAMETER_1;
 
@@ -972,21 +1131,24 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
     EndPosition = wcschr(SymbolicLinkName->Buffer, L'}');
     if (!StartPosition ||!EndPosition || StartPosition > EndPosition)
     {
-        DPRINT("IoSetDeviceInterfaceState() returning STATUS_INVALID_PARAMETER_1\n");
+        DPRINT1("IoSetDeviceInterfaceState() returning STATUS_INVALID_PARAMETER_1\n");
         return STATUS_INVALID_PARAMETER_1;
     }
     GuidString.Buffer = StartPosition;
     GuidString.MaximumLength = GuidString.Length = (USHORT)((ULONG_PTR)(EndPosition + 1) - (ULONG_PTR)StartPosition);
 
+    SymLink.Buffer = SymbolicLinkName->Buffer;
+    SymLink.MaximumLength = SymLink.Length = (USHORT)((ULONG_PTR)(EndPosition + 1) - (ULONG_PTR)SymLink.Buffer);
+    DPRINT("IoSetDeviceInterfaceState('%wZ', %d)\n", SymbolicLinkName, Enable);
     /* Get pointer to the PDO */
     Status = IoGetDeviceObjectPointer(
-        SymbolicLinkName,
+        &SymLink,
         0, /* DesiredAccess */
         &FileObject,
         &PhysicalDeviceObject);
     if (!NT_SUCCESS(Status))
     {
-        DPRINT("IoGetDeviceObjectPointer() failed with status 0x%08lx\n", Status);
+        DPRINT1("IoGetDeviceObjectPointer() failed with status 0x%08lx\n", Status);
         return Status;
     }
 
@@ -999,7 +1161,7 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
         (PVOID)SymbolicLinkName);
 
     ObDereferenceObject(FileObject);
-
+    DPRINT("Status %x\n", Status);
     return STATUS_SUCCESS;
 }