[MOUNTMGR] Fix out of bounds write
[reactos.git] / drivers / filters / mountmgr / mountmgr.c
index 229a00a..b79bb0f 100644 (file)
  *                   Alex Ionescu (alex.ionescu@reactos.org)
  */
 
-/* INCLUDES *****************************************************************/
-
 #include "mntmgr.h"
 
 #define NDEBUG
 #include <debug.h>
 
+#if defined(ALLOC_PRAGMA)
+#pragma alloc_text(INIT, MountmgrReadNoAutoMount)
+#pragma alloc_text(INIT, DriverEntry)
+#endif
+
 /* FIXME */
 GUID MountedDevicesGuid = {0x53F5630D, 0xB6BF, 0x11D0, {0x94, 0xF2, 0x00, 0xA0, 0xC9, 0x1E, 0xFB, 0x8B}};
 
@@ -39,17 +42,7 @@ KEVENT UnloadEvent;
 LONG Unloading;
 
 static const WCHAR Cunc[] = L"\\??\\C:";
-
-/*
- * TODO:
- * - MountMgrQueryDosVolumePath
- * - MountMgrQueryDosVolumePaths
- * - MountMgrQueryVolumePaths
- * - MountMgrValidateBackPointer
- * - MountMgrVolumeMountPointCreated
- * - MountMgrVolumeMountPointDeleted
- * - ReconcileThisDatabaseWithMasterWorker
- */
+#define Cunc_LETTER_POSITION 4
 
 /*
  * @implemented
@@ -97,28 +90,18 @@ HasDriveLetter(IN PDEVICE_INFORMATION DeviceInformation)
     PLIST_ENTRY NextEntry;
     PSYMLINK_INFORMATION SymlinkInfo;
 
-    /* To have a drive letter, a device must have symbolic links */
-    if (IsListEmpty(&(DeviceInformation->SymbolicLinksListHead)))
-    {
-        return FALSE;
-    }
-
-    /* Browse all the links untill a drive letter is found */
-    NextEntry = &(DeviceInformation->SymbolicLinksListHead);
-    do
+    /* Browse all the symlinks to check if there is at least a drive letter */
+    for (NextEntry = DeviceInformation->SymbolicLinksListHead.Flink;
+         NextEntry != &DeviceInformation->SymbolicLinksListHead;
+         NextEntry = NextEntry->Flink)
     {
         SymlinkInfo = CONTAINING_RECORD(NextEntry, SYMLINK_INFORMATION, SymbolicLinksListEntry);
 
-        if (SymlinkInfo->Online)
+        if (IsDriveLetter(&SymlinkInfo->Name) && SymlinkInfo->Online)
         {
-            if (IsDriveLetter(&(SymlinkInfo->Name)))
-            {
-                return TRUE;
-            }
+            return TRUE;
         }
-
-        NextEntry = NextEntry->Flink;
-    } while (NextEntry != &(DeviceInformation->SymbolicLinksListHead));
+    }
 
     return FALSE;
 }
@@ -135,8 +118,8 @@ CreateNewDriveLetterName(OUT PUNICODE_STRING DriveLetter,
     NTSTATUS Status = STATUS_UNSUCCESSFUL;
 
     /* Allocate a big enough buffer to contain the symbolic link */
-    DriveLetter->MaximumLength = sizeof(DosDevices.Buffer) + 3 * sizeof(WCHAR);
-    DriveLetter->Buffer = AllocatePool(sizeof(DosDevices.Buffer) + 3 * sizeof(WCHAR));
+    DriveLetter->MaximumLength = DosDevices.Length + 3 * sizeof(WCHAR);
+    DriveLetter->Buffer = AllocatePool(DriveLetter->MaximumLength);
     if (!DriveLetter->Buffer)
     {
         return STATUS_INSUFFICIENT_RESOURCES;
@@ -146,9 +129,9 @@ CreateNewDriveLetterName(OUT PUNICODE_STRING DriveLetter,
     RtlCopyUnicodeString(DriveLetter, &DosDevices);
 
     /* Update string to reflect real contents */
-    DriveLetter->Length = sizeof(DosDevices.Buffer) + 2 * sizeof(WCHAR);
-    DriveLetter->Buffer[(sizeof(DosDevices.Buffer) + 2 * sizeof(WCHAR)) / sizeof (WCHAR)] = UNICODE_NULL;
-    DriveLetter->Buffer[(sizeof(DosDevices.Buffer) + sizeof(WCHAR)) / sizeof (WCHAR)] = L':';
+    DriveLetter->Length = DosDevices.Length + 2 * sizeof(WCHAR);
+    DriveLetter->Buffer[DosDevices.Length / sizeof(WCHAR) + 2] = UNICODE_NULL;
+    DriveLetter->Buffer[DosDevices.Length / sizeof(WCHAR) + 1] = L':';
 
     /* If caller wants a no drive entry */
     if (Letter == (UCHAR)-1)
@@ -161,7 +144,7 @@ CreateNewDriveLetterName(OUT PUNICODE_STRING DriveLetter,
     else if (Letter)
     {
         /* Use the letter given by the caller */
-        DriveLetter->Buffer[sizeof(DosDevices.Buffer) / sizeof(WCHAR)] = (WCHAR)Letter;
+        DriveLetter->Buffer[DosDevices.Length / sizeof(WCHAR)] = (WCHAR)Letter;
         Status = GlobalCreateSymbolicLink(DriveLetter, DeviceName);
         if (NT_SUCCESS(Status))
         {
@@ -169,34 +152,39 @@ CreateNewDriveLetterName(OUT PUNICODE_STRING DriveLetter,
         }
     }
 
-    /* If caller didn't provide a letter, let's find one for him.
-     * If device is a floppy, start with letter A
-     */
+    /* If caller didn't provide a letter, let's find one for him */
+
     if (RtlPrefixUnicodeString(&DeviceFloppy, DeviceName, TRUE))
     {
+        /* If the device is a floppy, start with letter A */
         Letter = 'A';
     }
+    else if (RtlPrefixUnicodeString(&DeviceCdRom, DeviceName, TRUE))
+    {
+        /* If the device is a CD-ROM, start with letter D */
+        Letter = 'D';
+    }
     else
     {
-        /* Otherwise, if device is a cd rom, then, start with D.
-         * Finally, if a disk, use C
-         */
-        Letter = RtlPrefixUnicodeString(&DeviceCdRom, DeviceName, TRUE) + 'C';
+        /* Finally, if it's a disk, use C */
+        Letter = 'C';
     }
 
     /* Try to affect a letter (up to Z, ofc) until it's possible */
     for (; Letter <= 'Z'; Letter++)
     {
-        DriveLetter->Buffer[sizeof(DosDevices.Buffer) / sizeof(WCHAR)] = (WCHAR)Letter;
+        DriveLetter->Buffer[DosDevices.Length / sizeof(WCHAR)] = (WCHAR)Letter;
         Status = GlobalCreateSymbolicLink(DriveLetter, DeviceName);
         if (NT_SUCCESS(Status))
         {
+            DPRINT("Assigned drive %c: to %wZ\n", Letter, DeviceName);
             return Status;
         }
     }
 
     /* We failed to allocate a letter */
     FreePool(DriveLetter->Buffer);
+    DPRINT("Failed to create a drive letter for %wZ\n", DeviceName);
     return Status;
 }
 
@@ -216,12 +204,12 @@ QueryDeviceInformation(IN PUNICODE_STRING SymbolicName,
     PIRP Irp;
     USHORT Size;
     KEVENT Event;
-    NTSTATUS Status;
     BOOLEAN IsRemovable;
     PMOUNTDEV_NAME Name;
     PMOUNTDEV_UNIQUE_ID Id;
     PFILE_OBJECT FileObject;
     PIO_STACK_LOCATION Stack;
+    NTSTATUS Status, IntStatus;
     PDEVICE_OBJECT DeviceObject;
     IO_STATUS_BLOCK IoStatusBlock;
     PARTITION_INFORMATION_EX PartitionInfo;
@@ -285,7 +273,7 @@ QueryDeviceInformation(IN PUNICODE_STRING SymbolicName,
             if (Status == STATUS_PENDING)
             {
                 KeWaitForSingleObject(&Event, Executive, KernelMode, FALSE, NULL);
-                Status =  IoStatusBlock.Status;
+                Status = IoStatusBlock.Status;
             }
 
             /* In case of failure, don't fail, that's no vital */
@@ -332,7 +320,7 @@ QueryDeviceInformation(IN PUNICODE_STRING SymbolicName,
             if (Status == STATUS_PENDING)
             {
                 KeWaitForSingleObject(&Event, Executive, KernelMode, FALSE, NULL);
-                Status =  IoStatusBlock.Status;
+                Status = IoStatusBlock.Status;
             }
 
             /* Once again here, failure isn't major */
@@ -371,7 +359,7 @@ QueryDeviceInformation(IN PUNICODE_STRING SymbolicName,
                 if (Status == STATUS_PENDING)
                 {
                     KeWaitForSingleObject(&Event, Executive, KernelMode, FALSE, NULL);
-                    Status =  IoStatusBlock.Status;
+                    Status = IoStatusBlock.Status;
                 }
 
                 if (!NT_SUCCESS(Status))
@@ -473,32 +461,33 @@ QueryDeviceInformation(IN PUNICODE_STRING SymbolicName,
             }
         }
 
-        /* Here we can't fail and assume default value */
-        if (!NT_SUCCESS(Status))
-        {
-            FreePool(Name);
-            ObDereferenceObject(DeviceObject);
-            ObDereferenceObject(FileObject);
-            return Status;
-        }
-
-        /* Copy back found name to the caller */
-        DeviceName->Length = Name->NameLength;
-        DeviceName->MaximumLength = Name->NameLength + sizeof(WCHAR);
-        DeviceName->Buffer = AllocatePool(DeviceName->MaximumLength);
-        if (!DeviceName->Buffer)
+        if (NT_SUCCESS(Status))
         {
-            FreePool(Name);
-            ObDereferenceObject(DeviceObject);
-            ObDereferenceObject(FileObject);
-            return STATUS_INSUFFICIENT_RESOURCES;
+            /* Copy back found name to the caller */
+            DeviceName->Length = Name->NameLength;
+            DeviceName->MaximumLength = Name->NameLength + sizeof(WCHAR);
+            DeviceName->Buffer = AllocatePool(DeviceName->MaximumLength);
+            if (!DeviceName->Buffer)
+            {
+                Status = STATUS_INSUFFICIENT_RESOURCES;
+            }
+            else
+            {
+                RtlCopyMemory(DeviceName->Buffer, Name->Name, Name->NameLength);
+                DeviceName->Buffer[Name->NameLength / sizeof(WCHAR)] = UNICODE_NULL;
+            }
         }
 
-        RtlCopyMemory(DeviceName->Buffer, Name->Name, Name->NameLength);
-        DeviceName->Buffer[Name->NameLength / sizeof(WCHAR)] = UNICODE_NULL;
         FreePool(Name);
     }
 
+    if (!NT_SUCCESS(Status))
+    {
+        ObDereferenceObject(DeviceObject);
+        ObDereferenceObject(FileObject);
+        return Status;
+    }
+
     /* If caller wants device unique ID */
     if (UniqueId)
     {
@@ -558,7 +547,7 @@ QueryDeviceInformation(IN PUNICODE_STRING SymbolicName,
 
             /* Query unique ID */
             KeInitializeEvent(&Event, NotificationEvent, FALSE);
-            Irp = IoBuildDeviceIoControlRequest(IOCTL_MOUNTDEV_QUERY_DEVICE_NAME,
+            Irp = IoBuildDeviceIoControlRequest(IOCTL_MOUNTDEV_QUERY_UNIQUE_ID,
                                                 DeviceObject,
                                                 NULL,
                                                 0,
@@ -631,14 +620,14 @@ QueryDeviceInformation(IN PUNICODE_STRING SymbolicName,
         Stack = IoGetNextIrpStackLocation(Irp);
         Stack->FileObject = FileObject;
 
-        Status = IoCallDriver(DeviceObject, Irp);
-        if (Status == STATUS_PENDING)
+        IntStatus = IoCallDriver(DeviceObject, Irp);
+        if (IntStatus == STATUS_PENDING)
         {
             KeWaitForSingleObject(&Event, Executive, KernelMode, FALSE, NULL);
-            Status = IoStatusBlock.Status;
+            IntStatus = IoStatusBlock.Status;
         }
 
-        *HasGuid = NT_SUCCESS(Status);
+        *HasGuid = NT_SUCCESS(IntStatus);
     }
 
     ObDereferenceObject(DeviceObject);
@@ -701,7 +690,7 @@ FindDeviceInfo(IN PDEVICE_EXTENSION DeviceExtension,
         FreePool(DeviceName.Buffer);
     }
 
-    /* Return found intormation */
+    /* Return found information */
     if (NextEntry == &(DeviceExtension->DeviceListHead))
     {
         return STATUS_OBJECT_NAME_NOT_FOUND;
@@ -816,6 +805,8 @@ MountMgrUnload(IN struct _DRIVER_OBJECT *DriverObject)
     PDEVICE_INFORMATION DeviceInformation;
     PSAVED_LINK_INFORMATION SavedLinkInformation;
 
+    UNREFERENCED_PARAMETER(DriverObject);
+
     /* Don't get notification any longer */
     IoUnregisterShutdownNotification(gdeviceObject);
 
@@ -832,7 +823,7 @@ MountMgrUnload(IN struct _DRIVER_OBJECT *DriverObject)
     KeInitializeEvent(&UnloadEvent, NotificationEvent, FALSE);
 
     /* Wait for workers to finish */
-    if (InterlockedIncrement(&DeviceExtension->WorkerReferences))
+    if (InterlockedIncrement(&DeviceExtension->WorkerReferences) > 0)
     {
         KeReleaseSemaphore(&(DeviceExtension->WorkerSemaphore),
                            IO_NO_INCREMENT, 1, FALSE);
@@ -873,7 +864,7 @@ MountMgrUnload(IN struct _DRIVER_OBJECT *DriverObject)
         NextEntry = RemoveHeadList(&(DeviceExtension->UniqueIdWorkerItemListHead));
         WorkItem = CONTAINING_RECORD(NextEntry, UNIQUE_ID_WORK_ITEM, UniqueIdWorkerItemListEntry);
 
-        KeResetEvent(&UnloadEvent);
+        KeClearEvent(&UnloadEvent);
         WorkItem->Event = &UnloadEvent;
 
         KeReleaseSemaphore(&(DeviceExtension->DeviceLock), IO_NO_INCREMENT,
@@ -908,6 +899,7 @@ MountMgrUnload(IN struct _DRIVER_OBJECT *DriverObject)
 /*
  * @implemented
  */
+INIT_FUNCTION
 BOOLEAN
 MountmgrReadNoAutoMount(IN PUNICODE_STRING RegistryPath)
 {
@@ -944,7 +936,7 @@ MountmgrReadNoAutoMount(IN PUNICODE_STRING RegistryPath)
 NTSTATUS
 MountMgrMountedDeviceArrival(IN PDEVICE_EXTENSION DeviceExtension,
                              IN PUNICODE_STRING SymbolicName,
-                             IN BOOLEAN FromVolume)
+                             IN BOOLEAN ManuallyRegistered)
 {
     WCHAR Letter;
     GUID StableGuid;
@@ -958,7 +950,7 @@ MountMgrMountedDeviceArrival(IN PDEVICE_EXTENSION DeviceExtension,
     PMOUNTDEV_UNIQUE_ID UniqueId, NewUniqueId;
     PSAVED_LINK_INFORMATION SavedLinkInformation;
     PDEVICE_INFORMATION DeviceInformation, CurrentDevice;
-    WCHAR CSymLinkBuffer[MAX_PATH], LinkTargetBuffer[MAX_PATH];
+    WCHAR CSymLinkBuffer[RTL_NUMBER_OF(Cunc)], LinkTargetBuffer[MAX_PATH];
     UNICODE_STRING TargetDeviceName, SuggestedLinkName, DeviceName, VolumeName, DriveLetter, LinkTarget, CSymLink;
     BOOLEAN HasGuid, HasGptDriveLetter, Valid, UseOnlyIfThereAreNoOtherLinks, IsDrvLetter, IsOff, IsVolumeName, LinkError;
 
@@ -986,7 +978,7 @@ MountMgrMountedDeviceArrival(IN PDEVICE_EXTENSION DeviceExtension,
     /* Copy symbolic name */
     RtlCopyMemory(DeviceInformation->SymbolicName.Buffer, SymbolicName->Buffer, SymbolicName->Length);
     DeviceInformation->SymbolicName.Buffer[DeviceInformation->SymbolicName.Length / sizeof(WCHAR)] = UNICODE_NULL;
-    DeviceInformation->Volume = FromVolume;
+    DeviceInformation->ManuallyRegistered = ManuallyRegistered;
     DeviceInformation->DeviceExtension = DeviceExtension;
 
     /* Query as much data as possible about device */
@@ -1068,7 +1060,7 @@ MountMgrMountedDeviceArrival(IN PDEVICE_EXTENSION DeviceExtension,
     {
         CurrentDevice = CONTAINING_RECORD(NextEntry, DEVICE_INFORMATION, DeviceListEntry);
 
-        if (RtlEqualUnicodeString(&(DeviceInformation->DeviceName), &TargetDeviceName, TRUE))
+        if (RtlEqualUnicodeString(&(CurrentDevice->DeviceName), &TargetDeviceName, TRUE))
         {
             break;
         }
@@ -1115,11 +1107,11 @@ MountMgrMountedDeviceArrival(IN PDEVICE_EXTENSION DeviceExtension,
         /* Start checking all letters that could have been associated */
         for (Letter = L'D'; Letter <= L'Z'; Letter++)
         {
-            CSymLink.Buffer[LETTER_POSITION] = Letter;
+            CSymLink.Buffer[Cunc_LETTER_POSITION] = Letter;
 
             InitializeObjectAttributes(&ObjectAttributes,
                                        &CSymLink,
-                                       OBJ_CASE_INSENSITIVE,
+                                       OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE,
                                        NULL,
                                        NULL);
 
@@ -1327,9 +1319,9 @@ MountMgrMountedDeviceArrival(IN PDEVICE_EXTENSION DeviceExtension,
         DeviceInformation->SuggestedDriveLetter = 0;
     }
     /* Else, it's time to set up one */
-    else if (!DeviceExtension->NoAutoMount && !DeviceInformation->Removable &&
-             DeviceExtension->AutomaticDriveLetter && HasGptDriveLetter &&
-             DeviceInformation->SuggestedDriveLetter &&
+    else if ((DeviceExtension->NoAutoMount || DeviceInformation->Removable) &&
+             DeviceExtension->AutomaticDriveLetter &&
+             (HasGptDriveLetter || DeviceInformation->SuggestedDriveLetter) &&
              !HasNoDriveLetterEntry(UniqueId))
     {
         /* Create a new drive letter */
@@ -1368,8 +1360,8 @@ MountMgrMountedDeviceArrival(IN PDEVICE_EXTENSION DeviceExtension,
         }
     }
 
-    /* If required, register for notifications about the device */
-    if (!FromVolume)
+    /* If that's a PnP device, register for notifications */
+    if (!ManuallyRegistered)
     {
         RegisterForTargetDeviceNotification(DeviceExtension, DeviceInformation);
     }
@@ -1612,7 +1604,7 @@ MountMgrMountedDeviceRemoval(IN PDEVICE_EXTENSION DeviceExtension,
         }
     }
 
-    /* Releave driver */
+    /* Release driver */
     KeReleaseSemaphore(&(DeviceExtension->DeviceLock), IO_NO_INCREMENT, 1, FALSE);
 }
 
@@ -1663,6 +1655,8 @@ MountMgrCreateClose(IN PDEVICE_OBJECT DeviceObject,
     PIO_STACK_LOCATION Stack;
     NTSTATUS Status = STATUS_SUCCESS;
 
+    UNREFERENCED_PARAMETER(DeviceObject);
+
     Stack = IoGetCurrentIrpStackLocation(Irp);
 
     /* Allow driver opening for communication
@@ -1688,6 +1682,8 @@ NTAPI
 MountMgrCancel(IN PDEVICE_OBJECT DeviceObject,
                IN PIRP Irp)
 {
+    UNREFERENCED_PARAMETER(DeviceObject);
+
     RemoveEntryList(&(Irp->Tail.Overlay.ListEntry));
 
     IoReleaseCancelSpinLock(Irp->CancelIrql);
@@ -1731,7 +1727,7 @@ MountMgrCleanup(IN PDEVICE_OBJECT DeviceObject,
     }
 
     /* Otherwise, cancel all the IRPs */
-    NextEntry = &(DeviceExtension->IrpListHead);
+    NextEntry = DeviceExtension->IrpListHead.Flink;
     do
     {
         ListIrp = CONTAINING_RECORD(NextEntry, IRP, Tail.Overlay.ListEntry);
@@ -1775,7 +1771,7 @@ MountMgrShutdown(IN PDEVICE_OBJECT DeviceObject,
     KeInitializeEvent(&UnloadEvent, NotificationEvent, FALSE);
 
     /* Wait for workers */
-    if (InterlockedIncrement(&(DeviceExtension->WorkerReferences)))
+    if (InterlockedIncrement(&(DeviceExtension->WorkerReferences)) > 0)
     {
         KeReleaseSemaphore(&(DeviceExtension->WorkerSemaphore),
                            IO_NO_INCREMENT,
@@ -1797,6 +1793,7 @@ MountMgrShutdown(IN PDEVICE_OBJECT DeviceObject,
 
 /* FUNCTIONS ****************************************************************/
 
+INIT_FUNCTION
 NTSTATUS
 NTAPI
 DriverEntry(IN PDRIVER_OBJECT DriverObject,
@@ -1870,7 +1867,7 @@ DriverEntry(IN PDRIVER_OBJECT DriverObject,
                                             &MountedDevicesGuid,
                                             DriverObject,
                                             MountMgrMountedDeviceNotification,
-                                            DeviceObject,
+                                            DeviceExtension,
                                             &(DeviceExtension->NotificationEntry));
 
     if (!NT_SUCCESS(Status))