[MOUNTMGR]
authorPierre Schweitzer <pierre@reactos.org>
Sat, 19 Sep 2015 20:37:09 +0000 (20:37 +0000)
committerPierre Schweitzer <pierre@reactos.org>
Sat, 19 Sep 2015 20:37:09 +0000 (20:37 +0000)
Implement ReconcileThisDatabaseWithMasterWorker() which was the last missing bit of our MountMgr :-)

svn path=/trunk/; revision=69292

reactos/drivers/filters/mountmgr/database.c
reactos/drivers/filters/mountmgr/mntmgr.h
reactos/drivers/filters/mountmgr/mountmgr.c

index 5342861..7dacb97 100644 (file)
@@ -394,11 +394,750 @@ ReleaseRemoteDatabaseSemaphore(IN PDEVICE_EXTENSION DeviceExtension)
     KeReleaseSemaphore(&(DeviceExtension->RemoteDatabaseLock), IO_NO_INCREMENT, 1, FALSE);
 }
 
+/*
+ * @implemented
+ */
+NTSTATUS
+NTAPI
+QueryUniqueIdQueryRoutine(IN PWSTR ValueName,
+                          IN ULONG ValueType,
+                          IN PVOID ValueData,
+                          IN ULONG ValueLength,
+                          IN PVOID Context,
+                          IN PVOID EntryContext)
+{
+    PMOUNTDEV_UNIQUE_ID IntUniqueId;
+    PMOUNTDEV_UNIQUE_ID * UniqueId;
+
+    UNREFERENCED_PARAMETER(ValueName);
+    UNREFERENCED_PARAMETER(ValueType);
+    UNREFERENCED_PARAMETER(EntryContext);
+
+    /* Sanity check */
+    if (ValueLength >= 0x10000)
+    {
+        return STATUS_SUCCESS;
+    }
+
+    /* Allocate the Unique ID */
+    IntUniqueId = AllocatePool(sizeof(UniqueId) + ValueLength);
+    if (IntUniqueId)
+    {
+        /* Copy data & return */
+        IntUniqueId->UniqueIdLength = (USHORT)ValueLength;
+        RtlCopyMemory(&(IntUniqueId->UniqueId), ValueData, ValueLength);
+
+        UniqueId = Context;
+        *UniqueId = IntUniqueId;
+    }
+
+    return STATUS_SUCCESS;
+}
+
+/*
+ * @implemented
+ */
+NTSTATUS
+QueryUniqueIdFromMaster(IN PDEVICE_EXTENSION DeviceExtension,
+                        IN PUNICODE_STRING SymbolicName,
+                        OUT PMOUNTDEV_UNIQUE_ID * UniqueId)
+{
+    NTSTATUS Status;
+    PDEVICE_INFORMATION DeviceInformation;
+    RTL_QUERY_REGISTRY_TABLE QueryTable[2];
+
+    /* Query the unique ID */
+    RtlZeroMemory(QueryTable, sizeof(QueryTable));
+    QueryTable[0].QueryRoutine = QueryUniqueIdQueryRoutine;
+    QueryTable[0].Name = SymbolicName->Buffer;
+
+    *UniqueId = NULL;
+    RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE,
+                           DatabasePath,
+                           QueryTable,
+                           UniqueId,
+                           NULL);
+    /* Unique ID found, no need to go farther */
+    if (*UniqueId)
+    {
+        return STATUS_SUCCESS;
+    }
+
+    /* Otherwise, find associate device information */
+    Status = FindDeviceInfo(DeviceExtension, SymbolicName, FALSE, &DeviceInformation);
+    if (!NT_SUCCESS(Status))
+    {
+        return Status;
+    }
+
+    *UniqueId = AllocatePool(DeviceInformation->UniqueId->UniqueIdLength + sizeof(MOUNTDEV_UNIQUE_ID));
+    if (!*UniqueId)
+    {
+        return STATUS_INSUFFICIENT_RESOURCES;
+    }
+
+    /* Return this unique ID (better than nothing) */
+    (*UniqueId)->UniqueIdLength = DeviceInformation->UniqueId->UniqueIdLength;
+    RtlCopyMemory(&((*UniqueId)->UniqueId), &(DeviceInformation->UniqueId->UniqueId), (*UniqueId)->UniqueIdLength);
+
+    return STATUS_SUCCESS;
+}
+
+/*
+ * @implemented
+ */
+NTSTATUS
+WriteUniqueIdToMaster(IN PDEVICE_EXTENSION DeviceExtension,
+                      IN PDATABASE_ENTRY DatabaseEntry)
+{
+    NTSTATUS Status;
+    PWCHAR SymbolicName;
+    PLIST_ENTRY NextEntry;
+    UNICODE_STRING SymbolicString;
+    PDEVICE_INFORMATION DeviceInformation;
+
+    /* Create symbolic name from database entry */
+    SymbolicName = AllocatePool(DatabaseEntry->SymbolicNameLength + sizeof(WCHAR));
+    if (!SymbolicName)
+    {
+        return STATUS_INSUFFICIENT_RESOURCES;
+    }
+
+    RtlCopyMemory(SymbolicName,
+                  (PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->SymbolicNameOffset),
+                  DatabaseEntry->SymbolicNameLength);
+    SymbolicName[DatabaseEntry->SymbolicNameLength / sizeof(WCHAR)] = UNICODE_NULL;
+
+    /* Associate the unique ID with the name from remote database */
+    Status = RtlWriteRegistryValue(RTL_REGISTRY_ABSOLUTE,
+                                   DatabasePath,
+                                   SymbolicName,
+                                   REG_BINARY,
+                                   (PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->UniqueIdOffset),
+                                   DatabaseEntry->UniqueIdLength);
+    FreePool(SymbolicName);
+
+    /* Reget symbolic name */
+    SymbolicString.Length = DatabaseEntry->SymbolicNameLength;
+    SymbolicString.MaximumLength = DatabaseEntry->SymbolicNameLength;
+    SymbolicString.Buffer = (PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->SymbolicNameOffset);
+
+    /* Find the device using this unique ID */
+    for (NextEntry = DeviceExtension->DeviceListHead.Flink;
+         NextEntry != &(DeviceExtension->DeviceListHead);
+         NextEntry = NextEntry->Flink)
+    {
+        DeviceInformation = CONTAINING_RECORD(NextEntry,
+                                              DEVICE_INFORMATION,
+                                              DeviceListEntry);
+
+        if (DeviceInformation->UniqueId->UniqueIdLength != DatabaseEntry->UniqueIdLength)
+        {
+            continue;
+        }
+
+        if (RtlCompareMemory((PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->UniqueIdOffset),
+                             DeviceInformation->UniqueId->UniqueId,
+                             DatabaseEntry->UniqueIdLength) == DatabaseEntry->UniqueIdLength)
+        {
+            break;
+        }
+    }
+
+    /* If found, create a mount point */
+    if (NextEntry != &(DeviceExtension->DeviceListHead))
+    {
+        MountMgrCreatePointWorker(DeviceExtension, &SymbolicString, &(DeviceInformation->DeviceName));
+    }
+
+    return Status;
+}
+
+/*
+ * @implemented
+ */
 VOID
 NTAPI
 ReconcileThisDatabaseWithMasterWorker(IN PVOID Parameter)
 {
-    UNREFERENCED_PARAMETER(Parameter);
+    ULONG Offset;
+    NTSTATUS Status;
+    PFILE_OBJECT FileObject;
+    PDEVICE_OBJECT DeviceObject;
+    PMOUNTDEV_UNIQUE_ID UniqueId;
+    PDATABASE_ENTRY DatabaseEntry;
+    HANDLE DatabaseHandle, Handle;
+    IO_STATUS_BLOCK IoStatusBlock;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    PDEVICE_INFORMATION ListDeviceInfo;
+    PLIST_ENTRY Entry, EntryInfo, NextEntry;
+    PASSOCIATED_DEVICE_ENTRY AssociatedDevice;
+    BOOLEAN HardwareErrors, Restart, FailedFinding;
+    WCHAR FileNameBuffer[0x8], SymbolicNameBuffer[100];
+    UNICODE_STRING ReparseFile, FileName, SymbolicName, VolumeName;
+    FILE_REPARSE_POINT_INFORMATION ReparsePointInformation, SavedReparsePointInformation;
+    PDEVICE_EXTENSION DeviceExtension = ((PRECONCILE_WORK_ITEM_CONTEXT)Parameter)->DeviceExtension;
+    PDEVICE_INFORMATION DeviceInformation = ((PRECONCILE_WORK_ITEM_CONTEXT)Parameter)->DeviceInformation;
+
+    /* We're unloading, do nothing */
+    if (Unloading)
+    {
+        return;
+    }
+
+    /* Lock remote DB */
+    if (!NT_SUCCESS(WaitForRemoteDatabaseSemaphore(DeviceExtension)))
+    {
+        return;
+    }
+
+    /* Recheck for unloading */
+    if (Unloading)
+    {
+        goto ReleaseRDS;
+    }
+
+    /* Find the DB to reconcile */
+    KeWaitForSingleObject(&DeviceExtension->DeviceLock, Executive, KernelMode, FALSE, NULL);
+    for (Entry = DeviceExtension->DeviceListHead.Flink;
+         Entry != &DeviceExtension->DeviceListHead;
+         Entry = Entry->Flink)
+    {
+        ListDeviceInfo = CONTAINING_RECORD(Entry, DEVICE_INFORMATION, DeviceListEntry);
+        if (ListDeviceInfo == DeviceInformation)
+        {
+            break;
+        }
+    }
+
+    /* If not found, or if removable, bail out */
+    if (Entry == &DeviceExtension->DeviceListHead || DeviceInformation->Removable)
+    {
+        KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+        goto ReleaseRDS;
+    }
+
+    /* Get our device object */
+    Status = IoGetDeviceObjectPointer(&ListDeviceInfo->DeviceName, FILE_READ_ATTRIBUTES, &FileObject, &DeviceObject);
+    if (!NT_SUCCESS(Status))
+    {
+        KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+        goto ReleaseRDS;
+    }
+
+    if (DeviceObject->Flags & 1)
+    {
+        _InterlockedExchangeAdd(&ListDeviceInfo->MountState, 1u);
+    }
+
+    ObDereferenceObject(FileObject);
+
+    /* Force default: no DB, and need for reconcile */
+    DeviceInformation->NeedsReconcile = TRUE;
+    DeviceInformation->NoDatabase = TRUE;
+    FailedFinding = FALSE;
+
+    /* Remove any associated device that refers to the DB to reconcile */
+    for (Entry = DeviceExtension->DeviceListHead.Flink;
+         Entry != &DeviceExtension->DeviceListHead;
+         Entry = Entry->Flink)
+    {
+        ListDeviceInfo = CONTAINING_RECORD(Entry, DEVICE_INFORMATION, DeviceListEntry);
+
+        EntryInfo = ListDeviceInfo->AssociatedDevicesHead.Flink;
+        while (EntryInfo != &ListDeviceInfo->AssociatedDevicesHead)
+        {
+            AssociatedDevice = CONTAINING_RECORD(EntryInfo, ASSOCIATED_DEVICE_ENTRY, AssociatedDevicesEntry);
+            NextEntry = EntryInfo->Flink;
+
+            if (AssociatedDevice->DeviceInformation == DeviceInformation)
+            {
+                RemoveEntryList(&AssociatedDevice->AssociatedDevicesEntry);
+                FreePool(AssociatedDevice->String.Buffer);
+                FreePool(AssociatedDevice);
+            }
+
+            EntryInfo = NextEntry;
+        }
+    }
+
+    /* Open the remote database */
+    DatabaseHandle = OpenRemoteDatabase(DeviceInformation, FALSE);
+
+    /* Prepare a string with reparse point index */
+    ReparseFile.Length = DeviceInformation->DeviceName.Length + ReparseIndex.Length;
+    ReparseFile.MaximumLength = ReparseFile.Length + sizeof(UNICODE_NULL);
+    ReparseFile.Buffer = AllocatePool(ReparseFile.MaximumLength);
+    if (ReparseFile.Buffer == NULL)
+    {
+        if (DatabaseHandle != 0)
+        {
+            CloseRemoteDatabase(DatabaseHandle);
+        }
+        KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+        goto ReleaseRDS;
+    }
+
+    
+    RtlCopyMemory(ReparseFile.Buffer, DeviceInformation->DeviceName.Buffer,
+                  DeviceInformation->DeviceName.Length);
+    RtlCopyMemory((PVOID)((ULONG_PTR)ReparseFile.Buffer + DeviceInformation->DeviceName.Length),
+                  ReparseFile.Buffer, ReparseFile.Length);
+    ReparseFile.Buffer[ReparseFile.Length / sizeof(WCHAR)] = UNICODE_NULL;
+
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &ReparseFile,
+                               OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE,
+                               NULL,
+                               NULL);
+
+    /* Open reparse point directory */
+    HardwareErrors = IoSetThreadHardErrorMode(FALSE);
+    Status = ZwOpenFile(&Handle,
+                        FILE_GENERIC_READ,
+                        &ObjectAttributes,
+                        &IoStatusBlock,
+                        FILE_SHARE_READ | FILE_SHARE_WRITE,
+                        FILE_SYNCHRONOUS_IO_ALERT);
+    IoSetThreadHardErrorMode(HardwareErrors);
+
+    FreePool(ReparseFile.Buffer);
+
+    if (!NT_SUCCESS(Status))
+    {
+        if (DatabaseHandle != 0)
+        {
+            TruncateRemoteDatabase(DatabaseHandle, 0);
+            CloseRemoteDatabase(DatabaseHandle);
+        }
+        KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+        goto ReleaseRDS;
+    }
+
+    /* Query reparse point information
+     * We only pay attention to mout point
+     */
+    RtlZeroMemory(FileNameBuffer, sizeof(FileNameBuffer));
+    FileName.Buffer = FileNameBuffer;
+    FileName.Length = sizeof(FileNameBuffer);
+    FileName.MaximumLength = sizeof(FileNameBuffer);
+    ((PULONG)FileNameBuffer)[0] = IO_REPARSE_TAG_MOUNT_POINT;
+    Status = ZwQueryDirectoryFile(Handle,
+                                  NULL,
+                                  NULL,
+                                  NULL,
+                                  &IoStatusBlock,
+                                  &ReparsePointInformation,
+                                  sizeof(FILE_REPARSE_POINT_INFORMATION),
+                                  FileReparsePointInformation,
+                                  TRUE,
+                                  &FileName,
+                                  FALSE);
+    if (!NT_SUCCESS(Status))
+    {
+        ZwClose(Handle);
+        if (DatabaseHandle != 0)
+        {
+            TruncateRemoteDatabase(DatabaseHandle, 0);
+            CloseRemoteDatabase(DatabaseHandle);
+        }
+        KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+        goto ReleaseRDS;
+    }
+
+    /* If we failed to open the remote DB previously,
+     * retry this time allowing migration (and thus, creation if required)
+     */
+    if (DatabaseHandle == 0)
+    {
+        DatabaseHandle = OpenRemoteDatabase(DeviceInformation, TRUE);
+        if (DatabaseHandle == 0)
+        {
+            KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+            goto ReleaseRDS;
+        }
+    }
+
+    KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+
+    /* Reset all the references to our DB entries */
+    Offset = 0;
+    for (;;)
+    {
+        DatabaseEntry = GetRemoteDatabaseEntry(DatabaseHandle, Offset);
+        if (DatabaseEntry == NULL)
+        {
+            break;
+        }
+
+        DatabaseEntry->EntryReferences = 0;
+        Status = WriteRemoteDatabaseEntry(DatabaseHandle, Offset, DatabaseEntry);
+        if (!NT_SUCCESS(Status))
+        {
+            FreePool(DatabaseEntry);
+            goto CloseReparse;
+        }
+
+        Offset += DatabaseEntry->EntrySize;
+        FreePool(DatabaseEntry);
+    }
+
+    /* Init string for QueryVolumeName call */
+    SymbolicName.MaximumLength = sizeof(SymbolicNameBuffer);
+    SymbolicName.Length = 0;
+    SymbolicName.Buffer = SymbolicNameBuffer;
+    Restart = TRUE;
+
+    /* Start looping on reparse points */
+    for (;;)
+    {
+        RtlCopyMemory(&SavedReparsePointInformation, &ReparsePointInformation, sizeof(FILE_REPARSE_POINT_INFORMATION));
+        Status = ZwQueryDirectoryFile(Handle,
+                                      NULL,
+                                      NULL,
+                                      NULL,
+                                      &IoStatusBlock,
+                                      &ReparsePointInformation,
+                                      sizeof(FILE_REPARSE_POINT_INFORMATION),
+                                      FileReparsePointInformation,
+                                      TRUE,
+                                      Restart ? &FileName : NULL,
+                                      Restart);
+        /* Restart only once */
+        if (Restart)
+        {
+            Restart = FALSE;
+        }
+        else
+        {
+            /* If we get the same one, we're done, bail out */
+            if (ReparsePointInformation.FileReference == SavedReparsePointInformation.FileReference &&
+                ReparsePointInformation.Tag == SavedReparsePointInformation.Tag)
+            {
+                break;
+            }
+        }
+
+        /* If querying failed, or if onloading, or if not returning mount points, bail out */
+        if (!NT_SUCCESS(Status) || Unloading || ReparsePointInformation.Tag != IO_REPARSE_TAG_MOUNT_POINT)
+        {
+            break;
+        }
+
+        /* Get the volume name associated to the mount point */
+        Status = QueryVolumeName(Handle, &ReparsePointInformation, 0, &SymbolicName, &VolumeName);
+        if (!NT_SUCCESS(Status))
+        {
+            continue;
+        }
+
+        /* Browse the DB to find the name */
+        Offset = 0;
+        for (;;)
+        {
+            UNICODE_STRING DbName;
+
+            DatabaseEntry = GetRemoteDatabaseEntry(DatabaseHandle, Offset);
+            if (DatabaseEntry == NULL)
+            {
+                break;
+            }
+
+            DbName.MaximumLength = DatabaseEntry->SymbolicNameLength;
+            DbName.Length = DbName.MaximumLength;
+            DbName.Buffer = (PWSTR)((ULONG_PTR)DatabaseEntry + DatabaseEntry->SymbolicNameOffset);
+            /* Found, we're done! */
+            if (RtlEqualUnicodeString(&DbName, &SymbolicName, TRUE))
+            {
+                break;
+            }
+
+            Offset += DatabaseEntry->EntrySize;
+            FreePool(DatabaseEntry);
+        }
+
+        /* If we found the mount point.... */
+        if (DatabaseEntry != NULL)
+        {
+            /* If it was referenced, reference it once more and update to remote */
+            if (DatabaseEntry->EntryReferences)
+            {
+                ++DatabaseEntry->EntryReferences;
+                Status = WriteRemoteDatabaseEntry(DatabaseHandle, Offset, DatabaseEntry);
+                if (!NT_SUCCESS(Status))
+                {
+                    goto FreeDBEntry;
+                }
+
+                FreePool(DatabaseEntry);
+            }
+            else
+            {
+                /* Query the Unique ID associated to that mount point in case it changed */
+                KeWaitForSingleObject(&DeviceExtension->DeviceLock, Executive, KernelMode, FALSE, NULL);
+                Status = QueryUniqueIdFromMaster(DeviceExtension, &SymbolicName, &UniqueId);
+                if (!NT_SUCCESS(Status))
+                {
+                    /* If we failed doing so, reuse the old Unique ID and push it to master */
+                    Status = WriteUniqueIdToMaster(DeviceExtension, DatabaseEntry);
+                    if (!NT_SUCCESS(Status))
+                    {
+                        goto ReleaseDeviceLock;
+                    }
+
+                    /* And then, reference & write the entry */
+                    ++DatabaseEntry->EntryReferences;
+                    Status = WriteRemoteDatabaseEntry(DatabaseHandle, Offset, DatabaseEntry);
+                    if (!NT_SUCCESS(Status))
+                    {
+                        goto ReleaseDeviceLock;
+                    }
+
+                    KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+                    FreePool(DatabaseEntry);
+                }
+                /* If the Unique ID didn't change */
+                else if (UniqueId->UniqueIdLength == DatabaseEntry->UniqueIdLength &&
+                         RtlCompareMemory(UniqueId->UniqueId,
+                                          (PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->UniqueIdOffset),
+                                          UniqueId->UniqueIdLength) == UniqueId->UniqueIdLength)
+                {
+                    /* Reference the entry, and update to remote */
+                    ++DatabaseEntry->EntryReferences;
+                    Status = WriteRemoteDatabaseEntry(DatabaseHandle, Offset, DatabaseEntry);
+                    if (!NT_SUCCESS(Status))
+                    {
+                        goto FreeUniqueId;
+                    }
+
+                    FreePool(UniqueId);
+                    KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+                    FreePool(DatabaseEntry);
+                }
+                /* Would, by chance, the Unique ID be present elsewhere? */
+                else if (IsUniqueIdPresent(DeviceExtension, DatabaseEntry))
+                {
+                    /* Push the ID to master */
+                    Status = WriteUniqueIdToMaster(DeviceExtension, DatabaseEntry);
+                    if (!NT_SUCCESS(Status))
+                    {
+                        goto FreeUniqueId;
+                    }
+
+                    /* And then, reference & write the entry */
+                    ++DatabaseEntry->EntryReferences;
+                    Status = WriteRemoteDatabaseEntry(DatabaseHandle, Offset, DatabaseEntry);
+                    if (!NT_SUCCESS(Status))
+                    {
+                        goto FreeUniqueId;
+                    }
+
+                    FreePool(UniqueId);
+                    KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+                    FreePool(DatabaseEntry);
+                }
+                else
+                {
+                    /* OK, at that point, we're facing a totally unknown unique ID
+                     * So, get rid of the old entry, and recreate a new one with
+                     * the know unique ID
+                     */
+                    Status = DeleteRemoteDatabaseEntry(DatabaseHandle, Offset);
+                    if (!NT_SUCCESS(Status))
+                    {
+                        goto FreeUniqueId;
+                    }
+
+                    FreePool(DatabaseEntry);
+                    /* Allocate a new entry big enough */
+                    DatabaseEntry = AllocatePool(UniqueId->UniqueIdLength + SymbolicName.Length + sizeof(DATABASE_ENTRY));
+                    if (DatabaseEntry == NULL)
+                    {
+                       goto FreeUniqueId;
+                    }
+
+                    /* Configure it */
+                    DatabaseEntry->EntrySize = UniqueId->UniqueIdLength + SymbolicName.Length + sizeof(DATABASE_ENTRY);
+                    DatabaseEntry->EntryReferences = 1;
+                    DatabaseEntry->SymbolicNameOffset = sizeof(DATABASE_ENTRY);
+                    DatabaseEntry->SymbolicNameLength = SymbolicName.Length;
+                    DatabaseEntry->UniqueIdOffset = SymbolicName.Length + sizeof(DATABASE_ENTRY);
+                    DatabaseEntry->UniqueIdLength = UniqueId->UniqueIdLength;
+                    RtlCopyMemory((PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->SymbolicNameOffset), SymbolicName.Buffer, DatabaseEntry->SymbolicNameLength);
+                    RtlCopyMemory((PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->UniqueIdOffset), UniqueId->UniqueId, UniqueId->UniqueIdLength);
+
+                    /* And write it remotely */
+                    Status = AddRemoteDatabaseEntry(DatabaseHandle, DatabaseEntry);
+                    if (!NT_SUCCESS(Status))
+                    {
+                       FreePool(DatabaseEntry);
+                       goto FreeUniqueId;
+                    }
+
+                    FreePool(UniqueId);
+                    FreePool(DatabaseEntry);
+                    KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+                }
+            }
+        }
+        else
+        {
+            /* We failed finding it remotely
+             * So, let's allocate a new remote DB entry
+             */
+            KeWaitForSingleObject(&DeviceExtension->DeviceLock, Executive, KernelMode, FALSE, NULL);
+            /* To be able to do so, we need the device Unique ID, ask master */
+            Status = QueryUniqueIdFromMaster(DeviceExtension, &SymbolicName, &UniqueId);
+            KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+            if (NT_SUCCESS(Status))
+            {
+                /* Allocate a new entry big enough */
+                DatabaseEntry = AllocatePool(UniqueId->UniqueIdLength + SymbolicName.Length + sizeof(DATABASE_ENTRY));
+                if (DatabaseEntry != NULL)
+                {
+                    /* Configure it */
+                    DatabaseEntry->EntrySize = UniqueId->UniqueIdLength + SymbolicName.Length + sizeof(DATABASE_ENTRY);
+                    DatabaseEntry->EntryReferences = 1;
+                    DatabaseEntry->SymbolicNameOffset = sizeof(DATABASE_ENTRY);
+                    DatabaseEntry->SymbolicNameLength = SymbolicName.Length;
+                    DatabaseEntry->UniqueIdOffset = SymbolicName.Length + sizeof(DATABASE_ENTRY);
+                    DatabaseEntry->UniqueIdLength = UniqueId->UniqueIdLength;
+                    RtlCopyMemory((PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->SymbolicNameOffset), SymbolicName.Buffer, DatabaseEntry->SymbolicNameLength);
+                    RtlCopyMemory((PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->UniqueIdOffset), UniqueId->UniqueId, UniqueId->UniqueIdLength);
+
+                    /* And write it remotely */
+                    Status = AddRemoteDatabaseEntry(DatabaseHandle, DatabaseEntry);
+                    FreePool(DatabaseEntry);
+                    FreePool(UniqueId);
+
+                    if (!NT_SUCCESS(Status))
+                    {
+                        goto FreeVolume;
+                    }
+                }
+                else
+                {
+                    FreePool(UniqueId);
+                }
+            }
+        }
+
+        /* Find info about the device associated associated with the mount point */
+        KeWaitForSingleObject(&DeviceExtension->DeviceLock, Executive, KernelMode, FALSE, NULL);
+        Status = FindDeviceInfo(DeviceExtension, &SymbolicName, FALSE, &ListDeviceInfo);
+        if (!NT_SUCCESS(Status))
+        {
+            FailedFinding = TRUE;
+            FreePool(VolumeName.Buffer);
+        }
+        else
+        {
+            /* Associate the device with the currrent DB */
+            AssociatedDevice = AllocatePool(sizeof(ASSOCIATED_DEVICE_ENTRY));
+            if (AssociatedDevice == NULL)
+            {
+                FreePool(VolumeName.Buffer);
+            }
+            else
+            {
+                AssociatedDevice->DeviceInformation = DeviceInformation;
+                AssociatedDevice->String.Length = VolumeName.Length;
+                AssociatedDevice->String.MaximumLength = VolumeName.MaximumLength;
+                AssociatedDevice->String.Buffer = VolumeName.Buffer;
+                InsertTailList(&ListDeviceInfo->AssociatedDevicesHead, &AssociatedDevice->AssociatedDevicesEntry);
+            }
+
+            /* If we don't have to skip notifications, notify */
+            if (!ListDeviceInfo->SkipNotifications)
+            {
+                PostOnlineNotification(DeviceExtension, &ListDeviceInfo->SymbolicName);
+            }
+        }
+
+        KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+    }
+
+    /* We don't need mount points any longer */
+    ZwClose(Handle);
+
+    /* Look for the DB again */
+    KeWaitForSingleObject(&DeviceExtension->DeviceLock, Executive, KernelMode, FALSE, NULL);
+    for (Entry = DeviceExtension->DeviceListHead.Flink;
+         Entry != &DeviceExtension->DeviceListHead;
+         Entry = Entry->Flink)
+    {
+        ListDeviceInfo = CONTAINING_RECORD(Entry, DEVICE_INFORMATION, DeviceListEntry);
+        if (ListDeviceInfo == DeviceInformation)
+        {
+            break;
+        }
+    }
+
+    if (Entry == &DeviceExtension->DeviceListHead)
+    {
+        ListDeviceInfo = NULL;
+    }
+
+    /* Start the pruning loop */
+    Offset = 0;
+    for (;;)
+    {
+        /* Get the entry */
+        DatabaseEntry = GetRemoteDatabaseEntry(DatabaseHandle, Offset);
+        if (DatabaseEntry == NULL)
+        {
+            break;
+        }
+
+        /* It's not referenced anylonger? Prune it */
+        if (DatabaseEntry->EntryReferences == 0)
+        {
+            Status = DeleteRemoteDatabaseEntry(DatabaseHandle, Offset);
+            if (!NT_SUCCESS(Status))
+            {
+                FreePool(DatabaseEntry);
+                KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+                goto CloseRDB;
+            }
+        }
+        /* Update the Unique IDs to reflect the changes we might have done previously */
+        else
+        {
+            if (ListDeviceInfo != NULL)
+            {
+                UpdateReplicatedUniqueIds(ListDeviceInfo, DatabaseEntry);
+            }
+
+            Offset += DatabaseEntry->EntrySize;
+        }
+
+        FreePool(DatabaseEntry);
+    }
+
+    /* We do have a DB now :-) */
+    if (ListDeviceInfo != NULL && !FailedFinding)
+    {
+        DeviceInformation->NoDatabase = FALSE;
+    }
+
+    KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+
+    goto CloseRDB;
+
+FreeUniqueId:
+    FreePool(UniqueId);
+ReleaseDeviceLock:
+    KeReleaseSemaphore(&DeviceExtension->DeviceLock, IO_NO_INCREMENT, 1, FALSE);
+FreeDBEntry:
+    FreePool(DatabaseEntry);
+FreeVolume:
+    FreePool(VolumeName.Buffer);
+CloseReparse:
+    ZwClose(Handle);
+CloseRDB:
+    CloseRemoteDatabase(DatabaseHandle);
+ReleaseRDS:
+    ReleaseRemoteDatabaseSemaphore(DeviceExtension);
     return;
 }
 
@@ -1247,165 +1986,6 @@ OpenRemoteDatabase(IN PDEVICE_INFORMATION DeviceInformation,
     return Database;
 }
 
-/*
- * @implemented
- */
-NTSTATUS
-NTAPI
-QueryUniqueIdQueryRoutine(IN PWSTR ValueName,
-                          IN ULONG ValueType,
-                          IN PVOID ValueData,
-                          IN ULONG ValueLength,
-                          IN PVOID Context,
-                          IN PVOID EntryContext)
-{
-    PMOUNTDEV_UNIQUE_ID IntUniqueId;
-    PMOUNTDEV_UNIQUE_ID * UniqueId;
-
-    UNREFERENCED_PARAMETER(ValueName);
-    UNREFERENCED_PARAMETER(ValueType);
-    UNREFERENCED_PARAMETER(EntryContext);
-
-    /* Sanity check */
-    if (ValueLength >= 0x10000)
-    {
-        return STATUS_SUCCESS;
-    }
-
-    /* Allocate the Unique ID */
-    IntUniqueId = AllocatePool(sizeof(UniqueId) + ValueLength);
-    if (IntUniqueId)
-    {
-        /* Copy data & return */
-        IntUniqueId->UniqueIdLength = (USHORT)ValueLength;
-        RtlCopyMemory(&(IntUniqueId->UniqueId), ValueData, ValueLength);
-
-        UniqueId = Context;
-        *UniqueId = IntUniqueId;
-    }
-
-    return STATUS_SUCCESS;
-}
-
-/*
- * @implemented
- */
-NTSTATUS
-QueryUniqueIdFromMaster(IN PDEVICE_EXTENSION DeviceExtension,
-                        IN PUNICODE_STRING SymbolicName,
-                        OUT PMOUNTDEV_UNIQUE_ID * UniqueId)
-{
-    NTSTATUS Status;
-    PDEVICE_INFORMATION DeviceInformation;
-    RTL_QUERY_REGISTRY_TABLE QueryTable[2];
-
-    /* Query the unique ID */
-    RtlZeroMemory(QueryTable, sizeof(QueryTable));
-    QueryTable[0].QueryRoutine = QueryUniqueIdQueryRoutine;
-    QueryTable[0].Name = SymbolicName->Buffer;
-
-    *UniqueId = NULL;
-    RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE,
-                           DatabasePath,
-                           QueryTable,
-                           UniqueId,
-                           NULL);
-    /* Unique ID found, no need to go farther */
-    if (*UniqueId)
-    {
-        return STATUS_SUCCESS;
-    }
-
-    /* Otherwise, find associate device information */
-    Status = FindDeviceInfo(DeviceExtension, SymbolicName, FALSE, &DeviceInformation);
-    if (!NT_SUCCESS(Status))
-    {
-        return Status;
-    }
-
-    *UniqueId = AllocatePool(DeviceInformation->UniqueId->UniqueIdLength + sizeof(MOUNTDEV_UNIQUE_ID));
-    if (!*UniqueId)
-    {
-        return STATUS_INSUFFICIENT_RESOURCES;
-    }
-
-    /* Return this unique ID (better than nothing) */
-    (*UniqueId)->UniqueIdLength = DeviceInformation->UniqueId->UniqueIdLength;
-    RtlCopyMemory(&((*UniqueId)->UniqueId), &(DeviceInformation->UniqueId->UniqueId), (*UniqueId)->UniqueIdLength);
-
-    return STATUS_SUCCESS;
-}
-
-/*
- * @implemented
- */
-NTSTATUS
-WriteUniqueIdToMaster(IN PDEVICE_EXTENSION DeviceExtension,
-                      IN PDATABASE_ENTRY DatabaseEntry)
-{
-    NTSTATUS Status;
-    PWCHAR SymbolicName;
-    PLIST_ENTRY NextEntry;
-    UNICODE_STRING SymbolicString;
-    PDEVICE_INFORMATION DeviceInformation;
-
-    /* Create symbolic name from database entry */
-    SymbolicName = AllocatePool(DatabaseEntry->SymbolicNameLength + sizeof(WCHAR));
-    if (!SymbolicName)
-    {
-        return STATUS_INSUFFICIENT_RESOURCES;
-    }
-
-    RtlCopyMemory(SymbolicName,
-                  (PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->SymbolicNameOffset),
-                  DatabaseEntry->SymbolicNameLength);
-    SymbolicName[DatabaseEntry->SymbolicNameLength / sizeof(WCHAR)] = UNICODE_NULL;
-
-    /* Associate the unique ID with the name from remote database */
-    Status = RtlWriteRegistryValue(RTL_REGISTRY_ABSOLUTE,
-                                   DatabasePath,
-                                   SymbolicName,
-                                   REG_BINARY,
-                                   (PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->UniqueIdOffset),
-                                   DatabaseEntry->UniqueIdLength);
-    FreePool(SymbolicName);
-
-    /* Reget symbolic name */
-    SymbolicString.Length = DatabaseEntry->SymbolicNameLength;
-    SymbolicString.MaximumLength = DatabaseEntry->SymbolicNameLength;
-    SymbolicString.Buffer = (PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->SymbolicNameOffset);
-
-    /* Find the device using this unique ID */
-    for (NextEntry = DeviceExtension->DeviceListHead.Flink;
-         NextEntry != &(DeviceExtension->DeviceListHead);
-         NextEntry = NextEntry->Flink)
-    {
-        DeviceInformation = CONTAINING_RECORD(NextEntry,
-                                              DEVICE_INFORMATION,
-                                              DeviceListEntry);
-
-        if (DeviceInformation->UniqueId->UniqueIdLength != DatabaseEntry->UniqueIdLength)
-        {
-            continue;
-        }
-
-        if (RtlCompareMemory((PVOID)((ULONG_PTR)DatabaseEntry + DatabaseEntry->UniqueIdOffset),
-                             DeviceInformation->UniqueId->UniqueId,
-                             DatabaseEntry->UniqueIdLength) == DatabaseEntry->UniqueIdLength)
-        {
-            break;
-        }
-    }
-
-    /* If found, create a mount point */
-    if (NextEntry != &(DeviceExtension->DeviceListHead))
-    {
-        MountMgrCreatePointWorker(DeviceExtension, &SymbolicString, &(DeviceInformation->DeviceName));
-    }
-
-    return Status;
-}
-
 /*
  * @implemented
  */
index 01d12e3..c243169 100644 (file)
@@ -420,6 +420,18 @@ HasNoDriveLetterEntry(
     IN PMOUNTDEV_UNIQUE_ID UniqueId
 );
 
+VOID
+UpdateReplicatedUniqueIds(
+    IN PDEVICE_INFORMATION DeviceInformation,
+    IN PDATABASE_ENTRY DatabaseEntry 
+);
+
+BOOLEAN
+IsUniqueIdPresent(
+    IN PDEVICE_EXTENSION DeviceExtension,
+    IN PDATABASE_ENTRY DatabaseEntry 
+);
+
 /* point.c */
 NTSTATUS
 MountMgrCreatePointWorker(
index dd2e4f0..df02f0f 100644 (file)
@@ -43,11 +43,6 @@ LONG Unloading;
 
 static const WCHAR Cunc[] = L"\\??\\C:";
 
-/*
- * TODO:
- * - ReconcileThisDatabaseWithMasterWorker
- */
-
 /*
  * @implemented
  */