[NTOSKRNL]
[reactos.git] / ntoskrnl / io / iomgr / deviface.c
index dcc086e..9b40aab 100644 (file)
@@ -332,7 +332,7 @@ IopOpenInterfaceKey(IN CONST GUID *InterfaceClassGuid,
     }
 
     KeyName.Length = 0;
-    KeyName.MaximumLength = LocalMachine.Length + (wcslen(REGSTR_PATH_DEVICE_CLASSES) + 1) * sizeof(WCHAR) + GuidString.Length;
+    KeyName.MaximumLength = LocalMachine.Length + ((USHORT)wcslen(REGSTR_PATH_DEVICE_CLASSES) + 1) * sizeof(WCHAR) + GuidString.Length;
     KeyName.Buffer = ExAllocatePool(PagedPool, KeyName.MaximumLength);
     if (!KeyName.Buffer)
     {
@@ -731,17 +731,13 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
             }
             KeyName.Length = KeyName.MaximumLength = (USHORT)bip->DataLength - 4 * sizeof(WCHAR);
             KeyName.Buffer = &((PWSTR)bip->Data)[4];
-            if (KeyName.Length && KeyName.Buffer[KeyName.Length / sizeof(WCHAR)] == UNICODE_NULL)
-            {
-                /* Remove trailing NULL */
-                KeyName.Length -= sizeof(WCHAR);
-            }
 
             /* Add new symbolic link to symbolic link list */
             if (ReturnBuffer.Length + KeyName.Length + sizeof(WCHAR) > ReturnBuffer.MaximumLength)
             {
                 PWSTR NewBuffer;
-                ReturnBuffer.MaximumLength = max(ReturnBuffer.MaximumLength * 2, ReturnBuffer.Length + KeyName.Length + 2 * sizeof(WCHAR));
+                ReturnBuffer.MaximumLength = (USHORT)max(ReturnBuffer.MaximumLength * 2,
+                                                 ReturnBuffer.Length + KeyName.Length + 2 * sizeof(WCHAR));
                 NewBuffer = ExAllocatePool(PagedPool, ReturnBuffer.MaximumLength);
                 if (!NewBuffer)
                 {
@@ -761,11 +757,11 @@ IoGetDeviceInterfaces(IN CONST GUID *InterfaceClassGuid,
                 DPRINT("RtlAppendUnicodeStringToString() failed with status 0x%08lx\n", Status);
                 goto cleanup;
             }
-            /* RtlAppendUnicodeStringToString added a NULL at the end of the    
-            * destination string, but didn't increase the Length field.         
-            * Do it for it.     
-            */          
-           ReturnBuffer.Length += sizeof(WCHAR);
+            /* RtlAppendUnicodeStringToString added a NULL at the end of the
+             * destination string, but didn't increase the Length field.
+             * Do it for it.
+             */
+            ReturnBuffer.Length += sizeof(WCHAR);
 
 NextReferenceString:
             ExFreePool(ReferenceBi);
@@ -891,7 +887,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
     ULONG StartIndex;
     OBJECT_ATTRIBUTES ObjectAttributes;
     ULONG i;
-    NTSTATUS Status;
+    NTSTATUS Status, SymLinkStatus;
     PEXTENDED_DEVOBJ_EXTENSION DeviceObjectExtension;
 
     ASSERT_IRQL_EQUAL(PASSIVE_LEVEL);
@@ -951,7 +947,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
     /* Create base key name for this interface: HKLM\SYSTEM\CurrentControlSet\Control\DeviceClasses\{GUID} */
     ASSERT(((PEXTENDED_DEVOBJ_EXTENSION)PhysicalDeviceObject->DeviceObjectExtension)->DeviceNode);
     InstancePath = &((PEXTENDED_DEVOBJ_EXTENSION)PhysicalDeviceObject->DeviceObjectExtension)->DeviceNode->InstancePath;
-    BaseKeyName.Length = wcslen(BaseKeyString) * sizeof(WCHAR);
+    BaseKeyName.Length = (USHORT)wcslen(BaseKeyString) * sizeof(WCHAR);
     BaseKeyName.MaximumLength = BaseKeyName.Length
         + GuidString.Length;
     BaseKeyName.Buffer = ExAllocatePool(
@@ -1145,10 +1141,20 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
 
     /* Create symbolic link */
     DPRINT("IoRegisterDeviceInterface(): creating symbolic link %wZ -> %wZ\n", SymbolicLinkName, &PdoNameInfo->Name);
-    Status = IoCreateSymbolicLink(SymbolicLinkName, &PdoNameInfo->Name);
-    if (!NT_SUCCESS(Status) && ReferenceString == NULL)
+    SymLinkStatus = IoCreateSymbolicLink(SymbolicLinkName, &PdoNameInfo->Name);
+
+    /* If the symbolic link already exists, return an informational success status */
+    if (SymLinkStatus == STATUS_OBJECT_NAME_COLLISION)
+    {
+        /* HACK: Delete the existing symbolic link and update it to the new PDO name */
+        IoDeleteSymbolicLink(SymbolicLinkName);
+        IoCreateSymbolicLink(SymbolicLinkName, &PdoNameInfo->Name);
+        SymLinkStatus = STATUS_OBJECT_NAME_EXISTS;
+    }
+
+    if (!NT_SUCCESS(SymLinkStatus))
     {
-        DPRINT1("IoCreateSymbolicLink() failed with status 0x%08lx\n", Status);
+        DPRINT1("IoCreateSymbolicLink() failed with status 0x%08lx\n", SymLinkStatus);
         ZwClose(SubKey);
         ZwClose(InterfaceKey);
         ZwClose(ClassKey);
@@ -1156,7 +1162,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
         ExFreePool(InterfaceKeyName.Buffer);
         ExFreePool(BaseKeyName.Buffer);
         ExFreePool(SymbolicLinkName->Buffer);
-        return Status;
+        return SymLinkStatus;
     }
 
     if (ReferenceString && ReferenceString->Length)
@@ -1192,7 +1198,7 @@ IoRegisterDeviceInterface(IN PDEVICE_OBJECT PhysicalDeviceObject,
     ExFreePool(InterfaceKeyName.Buffer);
     ExFreePool(BaseKeyName.Buffer);
 
-    return Status;
+    return NT_SUCCESS(Status) ? SymLinkStatus : Status;
 }
 
 /*++
@@ -1232,6 +1238,7 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
     UNICODE_STRING KeyName;
     OBJECT_ATTRIBUTES ObjectAttributes;
     ULONG LinkedValue;
+    GUID DeviceGuid;
 
     if (SymbolicLinkName == NULL)
         return STATUS_INVALID_PARAMETER_1;
@@ -1310,12 +1317,19 @@ IoSetDeviceInterfaceState(IN PUNICODE_STRING SymbolicLinkName,
         return Status;
     }
 
+    Status = RtlGUIDFromString(&GuidString, &DeviceGuid);
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("RtlGUIDFromString() failed with status 0x%08lx\n", Status);
+        return Status;
+    }
+
     EventGuid = Enable ? &GUID_DEVICE_INTERFACE_ARRIVAL : &GUID_DEVICE_INTERFACE_REMOVAL;
     IopNotifyPlugPlayNotification(
         PhysicalDeviceObject,
         EventCategoryDeviceInterfaceChange,
         EventGuid,
-        &GuidString,
+        &DeviceGuid,
         (PVOID)SymbolicLinkName);
 
     ObDereferenceObject(FileObject);