[MOUNTMGR] Don't kill Mm when a device has several symlinks
[reactos.git] / drivers / filters / mountmgr / point.c
index 3eab7c3..38260a5 100644 (file)
@@ -23,8 +23,6 @@
  * PROGRAMMER:       Pierre Schweitzer (pierre.schweitzer@reactos.org)
  */
 
-/* INCLUDES *****************************************************************/
-
 #include "mntmgr.h"
 
 #define NDEBUG
@@ -46,7 +44,7 @@ MountMgrCreatePointWorker(IN PDEVICE_EXTENSION DeviceExtension,
     PDEVICE_INFORMATION DeviceInformation = NULL, DeviceInfo;
 
     /* Get device name */
-    Status = QueryDeviceInformation(SymbolicLinkName,
+    Status = QueryDeviceInformation(DeviceName,
                                     &TargetDeviceName,
                                     NULL, NULL, NULL,
                                     NULL, NULL, NULL);
@@ -62,7 +60,7 @@ MountMgrCreatePointWorker(IN PDEVICE_EXTENSION DeviceExtension,
     {
         DeviceInformation = CONTAINING_RECORD(DeviceEntry, DEVICE_INFORMATION, DeviceListEntry);
 
-        if (RtlCompareUnicodeString(&TargetDeviceName, &(DeviceInformation->DeviceName), TRUE) == 0)
+        if (RtlEqualUnicodeString(&TargetDeviceName, &(DeviceInformation->DeviceName), TRUE))
         {
             break;
         }
@@ -223,7 +221,7 @@ MountMgrCreatePointWorker(IN PDEVICE_EXTENSION DeviceExtension,
     FreePool(SymLink.Buffer);
     MountMgrNotify(DeviceExtension);
 
-    if (!DeviceInformation->Volume)
+    if (!DeviceInformation->ManuallyRegistered)
     {
         MountMgrNotifyNameChange(DeviceExtension, DeviceName, FALSE);
     }
@@ -243,11 +241,12 @@ QueryPointsFromMemory(IN PDEVICE_EXTENSION DeviceExtension,
     NTSTATUS Status;
     PIO_STACK_LOCATION Stack;
     UNICODE_STRING DeviceName;
-    ULONG TotalSize, TotalSymLinks;
     PMOUNTMGR_MOUNT_POINTS MountPoints;
     PDEVICE_INFORMATION DeviceInformation;
     PLIST_ENTRY DeviceEntry, SymlinksEntry;
     PSYMLINK_INFORMATION SymlinkInformation;
+    USHORT UniqueIdLength, DeviceNameLength;
+    ULONG TotalSize, TotalSymLinks, UniqueIdOffset, DeviceNameOffset;
 
     /* If we got a symbolic link, query device */
     if (SymbolicName)
@@ -324,7 +323,7 @@ QueryPointsFromMemory(IN PDEVICE_EXTENSION DeviceExtension,
     {
         if (DeviceEntry == &(DeviceExtension->DeviceListHead))
         {
-            if (DeviceName.Buffer)
+            if (SymbolicName)
             {
                 FreePool(DeviceName.Buffer);
             }
@@ -334,21 +333,29 @@ QueryPointsFromMemory(IN PDEVICE_EXTENSION DeviceExtension,
     }
 
     /* Now, ensure output buffer can hold everything */
-    Stack = IoGetNextIrpStackLocation(Irp);
+    Stack = IoGetCurrentIrpStackLocation(Irp);
     MountPoints = (PMOUNTMGR_MOUNT_POINTS)Irp->AssociatedIrp.SystemBuffer;
 
     /* Ensure we set output to let user reallocate! */
-    MountPoints->Size = sizeof(MOUNTMGR_MOUNT_POINTS) + TotalSize;
+    MountPoints->Size = sizeof(MOUNTMGR_MOUNT_POINTS) + TotalSymLinks * sizeof(MOUNTMGR_MOUNT_POINT) + TotalSize;
     MountPoints->NumberOfMountPoints = TotalSymLinks;
+    Irp->IoStatus.Information = MountPoints->Size;
 
     if (MountPoints->Size > Stack->Parameters.DeviceIoControl.OutputBufferLength)
     {
+        Irp->IoStatus.Information = sizeof(MOUNTMGR_MOUNT_POINTS);
+
+        if (SymbolicName)
+        {
+            FreePool(DeviceName.Buffer);
+        }
+
         return STATUS_BUFFER_OVERFLOW;
     }
 
     /* Now, start putting mount points */
+    TotalSize = sizeof(MOUNTMGR_MOUNT_POINTS) + TotalSymLinks * sizeof(MOUNTMGR_MOUNT_POINT);
     TotalSymLinks = 0;
-    TotalSize = 0;
     for (DeviceEntry = DeviceExtension->DeviceListHead.Flink;
          DeviceEntry != &(DeviceExtension->DeviceListHead);
          DeviceEntry = DeviceEntry->Flink)
@@ -358,7 +365,7 @@ QueryPointsFromMemory(IN PDEVICE_EXTENSION DeviceExtension,
         /* Find back correct mount point */
         if (UniqueId)
         {
-            if (!UniqueId->UniqueIdLength != DeviceInformation->UniqueId->UniqueIdLength)
+            if (UniqueId->UniqueIdLength != DeviceInformation->UniqueId->UniqueIdLength)
             {
                 continue;
             }
@@ -378,6 +385,26 @@ QueryPointsFromMemory(IN PDEVICE_EXTENSION DeviceExtension,
             }
         }
 
+        /* Save our information about shared data */
+        UniqueIdOffset = TotalSize;
+        UniqueIdLength = DeviceInformation->UniqueId->UniqueIdLength;
+        DeviceNameOffset = TotalSize + UniqueIdLength;
+        DeviceNameLength = DeviceInformation->DeviceName.Length;
+
+        /* Initialize first symlink */
+        MountPoints->MountPoints[TotalSymLinks].UniqueIdOffset = UniqueIdOffset;
+        MountPoints->MountPoints[TotalSymLinks].UniqueIdLength = UniqueIdLength;
+        MountPoints->MountPoints[TotalSymLinks].DeviceNameOffset = DeviceNameOffset;
+        MountPoints->MountPoints[TotalSymLinks].DeviceNameLength = DeviceNameLength;
+
+        /* And copy data */
+        RtlCopyMemory((PWSTR)((ULONG_PTR)MountPoints + UniqueIdOffset),
+                      DeviceInformation->UniqueId->UniqueId, UniqueIdLength);
+        RtlCopyMemory((PWSTR)((ULONG_PTR)MountPoints + DeviceNameOffset),
+                      DeviceInformation->DeviceName.Buffer, DeviceNameLength);
+
+        TotalSize += DeviceInformation->UniqueId->UniqueIdLength + DeviceInformation->DeviceName.Length;
+
         /* Now we've got it, but all the data */
         for (SymlinksEntry = DeviceInformation->SymbolicLinksListHead.Flink;
              SymlinksEntry != &(DeviceInformation->SymbolicLinksListHead);
@@ -385,31 +412,33 @@ QueryPointsFromMemory(IN PDEVICE_EXTENSION DeviceExtension,
         {
             SymlinkInformation = CONTAINING_RECORD(SymlinksEntry, SYMLINK_INFORMATION, SymbolicLinksListEntry);
 
+            /* First, set shared data */
+
+            /* Only put UniqueID if online */
+            if (SymlinkInformation->Online)
+            {
+                MountPoints->MountPoints[TotalSymLinks].UniqueIdOffset = UniqueIdOffset;
+                MountPoints->MountPoints[TotalSymLinks].UniqueIdLength = UniqueIdLength;
+            }
+            else
+            {
+                MountPoints->MountPoints[TotalSymLinks].UniqueIdOffset = 0;
+                MountPoints->MountPoints[TotalSymLinks].UniqueIdLength = 0;
+            }
+
+            MountPoints->MountPoints[TotalSymLinks].DeviceNameOffset = DeviceNameOffset;
+            MountPoints->MountPoints[TotalSymLinks].DeviceNameLength = DeviceNameLength;
 
-            MountPoints->MountPoints[TotalSymLinks].SymbolicLinkNameOffset = sizeof(MOUNTMGR_MOUNT_POINTS) +
-                                                                             TotalSize;
+            /* And now, copy specific symlink info */
+            MountPoints->MountPoints[TotalSymLinks].SymbolicLinkNameOffset = TotalSize;
             MountPoints->MountPoints[TotalSymLinks].SymbolicLinkNameLength = SymlinkInformation->Name.Length;
-            MountPoints->MountPoints[TotalSymLinks].UniqueIdOffset = sizeof(MOUNTMGR_MOUNT_POINTS) +
-                                                                     SymlinkInformation->Name.Length +
-                                                                     TotalSize;
-            MountPoints->MountPoints[TotalSymLinks].UniqueIdLength = DeviceInformation->UniqueId->UniqueIdLength;
-            MountPoints->MountPoints[TotalSymLinks].DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINTS) +
-                                                                       SymlinkInformation->Name.Length +
-                                                                       DeviceInformation->UniqueId->UniqueIdLength +
-                                                                       TotalSize;
-            MountPoints->MountPoints[TotalSymLinks].DeviceNameLength = DeviceInformation->DeviceName.Length;
 
             RtlCopyMemory((PWSTR)((ULONG_PTR)MountPoints + MountPoints->MountPoints[TotalSymLinks].SymbolicLinkNameOffset),
                           SymlinkInformation->Name.Buffer, SymlinkInformation->Name.Length);
-            RtlCopyMemory((PWSTR)((ULONG_PTR)MountPoints + MountPoints->MountPoints[TotalSymLinks].UniqueIdOffset),
-                          DeviceInformation->UniqueId->UniqueId, DeviceInformation->UniqueId->UniqueIdLength);
-            RtlCopyMemory((PWSTR)((ULONG_PTR)MountPoints + MountPoints->MountPoints[TotalSymLinks].DeviceNameOffset),
-                  DeviceInformation->DeviceName.Buffer, DeviceInformation->DeviceName.Length);
 
             /* Update counters */
             TotalSymLinks++;
-            TotalSize += SymlinkInformation->Name.Length + DeviceInformation->UniqueId->UniqueIdLength +
-                         DeviceInformation->DeviceName.Length;
+            TotalSize += SymlinkInformation->Name.Length;
         }
 
         if (UniqueId || SymbolicName)
@@ -418,6 +447,11 @@ QueryPointsFromMemory(IN PDEVICE_EXTENSION DeviceExtension,
         }
     }
 
+    if (SymbolicName)
+    {
+        FreePool(DeviceName.Buffer);
+    }
+
     return STATUS_SUCCESS;
 }
 
@@ -451,7 +485,7 @@ QueryPointsFromSymbolicLinkName(IN PDEVICE_EXTENSION DeviceExtension,
         {
             DeviceInformation = CONTAINING_RECORD(DeviceEntry, DEVICE_INFORMATION, DeviceListEntry);
 
-            if (RtlEqualUnicodeString(&DeviceName, &(DeviceInformation->DeviceName), TRUE) == 0)
+            if (RtlEqualUnicodeString(&DeviceName, &(DeviceInformation->DeviceName), TRUE))
             {
                 break;
             }
@@ -467,11 +501,11 @@ QueryPointsFromSymbolicLinkName(IN PDEVICE_EXTENSION DeviceExtension,
         /* Check for the link */
         for (SymlinksEntry = DeviceInformation->SymbolicLinksListHead.Flink;
              SymlinksEntry != &(DeviceInformation->SymbolicLinksListHead);
-             SymlinksEntry = DeviceEntry->Flink)
+             SymlinksEntry = SymlinksEntry->Flink)
         {
             SymlinkInformation = CONTAINING_RECORD(SymlinksEntry, SYMLINK_INFORMATION, SymbolicLinksListEntry);
 
-            if (RtlEqualUnicodeString(SymbolicName, &SymlinkInformation->Name, TRUE) == 0)
+            if (RtlEqualUnicodeString(SymbolicName, &SymlinkInformation->Name, TRUE))
             {
                 break;
             }
@@ -499,7 +533,7 @@ QueryPointsFromSymbolicLinkName(IN PDEVICE_EXTENSION DeviceExtension,
             {
                 SymlinkInformation = CONTAINING_RECORD(SymlinksEntry, SYMLINK_INFORMATION, SymbolicLinksListEntry);
 
-                if (RtlEqualUnicodeString(SymbolicName, &SymlinkInformation->Name, TRUE) == 0)
+                if (RtlEqualUnicodeString(SymbolicName, &SymlinkInformation->Name, TRUE))
                 {
                     break;
                 }
@@ -519,7 +553,7 @@ QueryPointsFromSymbolicLinkName(IN PDEVICE_EXTENSION DeviceExtension,
     }
 
     /* Get output buffer */
-    Stack = IoGetNextIrpStackLocation(Irp);
+    Stack = IoGetCurrentIrpStackLocation(Irp);
     MountPoints = (PMOUNTMGR_MOUNT_POINTS)Irp->AssociatedIrp.SystemBuffer;
 
     /* Compute output length */
@@ -529,9 +563,12 @@ QueryPointsFromSymbolicLinkName(IN PDEVICE_EXTENSION DeviceExtension,
     /* Give length to allow reallocation */
     MountPoints->Size = sizeof(MOUNTMGR_MOUNT_POINTS) + TotalLength;
     MountPoints->NumberOfMountPoints = 1;
+    Irp->IoStatus.Information = sizeof(MOUNTMGR_MOUNT_POINTS) + TotalLength;
 
     if (MountPoints->Size > Stack->Parameters.DeviceIoControl.OutputBufferLength)
     {
+        Irp->IoStatus.Information = sizeof(MOUNTMGR_MOUNT_POINTS);
+
         return STATUS_BUFFER_OVERFLOW;
     }