Fix bunch of bugs in GetVolumeNameForVolumeMountPointW. Thanks to w3seek.
[reactos.git] / reactos / lib / kernel32 / file / volume.c
index 27a8ea3..d0cf382 100644 (file)
@@ -13,7 +13,7 @@
  */
 //WINE copyright notice:
 /*
- * DOS drives handling functions 
+ * DOS drives handling functions
  *
  * Copyright 1993 Erik Bos
  * Copyright 1996 Alexandre Julliard
@@ -195,7 +195,7 @@ GetDiskFreeSpaceA (
       if (!(RootPathNameW = FilenameA2W(lpRootPathName, FALSE)))
          return FALSE;
    }
-      
+
        return GetDiskFreeSpaceW (RootPathNameW,
                                    lpSectorsPerCluster,
                                    lpBytesPerSector,
@@ -324,7 +324,7 @@ GetDiskFreeSpaceExW(
     {
         return FALSE;
     }
-   
+
     errCode = NtQueryVolumeInformationFile(hFile,
                                            &IoStatusBlock,
                                            &FileFsSize,
@@ -344,7 +344,7 @@ GetDiskFreeSpaceExW(
        if (lpFreeBytesAvailableToCaller)
         lpFreeBytesAvailableToCaller->QuadPart =
             BytesPerCluster.QuadPart * FileFsSize.AvailableAllocationUnits.QuadPart;
-       
+
        if (lpTotalNumberOfBytes)
         lpTotalNumberOfBytes->QuadPart =
             BytesPerCluster.QuadPart * FileFsSize.TotalAllocationUnits.QuadPart;
@@ -365,7 +365,7 @@ UINT STDCALL
 GetDriveTypeA(LPCSTR lpRootPathName)
 {
    PWCHAR RootPathNameW;
-   
+
    if (!(RootPathNameW = FilenameA2W(lpRootPathName, FALSE)))
       return DRIVE_UNKNOWN;
 
@@ -400,7 +400,7 @@ GetDriveTypeW(LPCWSTR lpRootPathName)
        {
                CloseHandle(hFile);
                SetLastErrorByStatus (errCode);
-               return 0;       
+               return 0;
        }
        CloseHandle(hFile);
 
@@ -444,7 +444,7 @@ GetVolumeInformationA(
        )
 {
   UNICODE_STRING FileSystemNameU;
-  UNICODE_STRING VolumeNameU;
+  UNICODE_STRING VolumeNameU = {0};
   ANSI_STRING VolumeName;
   ANSI_STRING FileSystemName;
   PWCHAR RootPathNameW;
@@ -452,14 +452,17 @@ GetVolumeInformationA(
 
   if (!(RootPathNameW = FilenameA2W(lpRootPathName, FALSE)))
      return FALSE;
-  
+
   if (lpVolumeNameBuffer)
     {
-      VolumeNameU.Length = 0;
       VolumeNameU.MaximumLength = nVolumeNameSize * sizeof(WCHAR);
       VolumeNameU.Buffer = RtlAllocateHeap (RtlGetProcessHeap (),
                                            0,
                                            VolumeNameU.MaximumLength);
+      if (VolumeNameU.Buffer == NULL)
+      {
+          goto FailNoMem;
+      }
     }
 
   if (lpFileSystemNameBuffer)
@@ -469,6 +472,19 @@ GetVolumeInformationA(
       FileSystemNameU.Buffer = RtlAllocateHeap (RtlGetProcessHeap (),
                                                0,
                                                FileSystemNameU.MaximumLength);
+      if (FileSystemNameU.Buffer == NULL)
+      {
+          if (VolumeNameU.Buffer != NULL)
+          {
+              RtlFreeHeap(RtlGetProcessHeap(),
+                          0,
+                          VolumeNameU.Buffer);
+          }
+
+FailNoMem:
+          SetLastError(ERROR_NOT_ENOUGH_MEMORY);
+          return FALSE;
+      }
     }
 
   Result = GetVolumeInformationW (RootPathNameW,
@@ -603,7 +619,7 @@ GetVolumeInformationW(
                                          FileFsVolume,
                                          FS_VOLUME_BUFFER_SIZE,
                                          FileFsVolumeInformation);
-  if ( !NT_SUCCESS(errCode) ) 
+  if ( !NT_SUCCESS(errCode) )
     {
       DPRINT("Status: %x\n", errCode);
       CloseHandle(hFile);
@@ -618,8 +634,8 @@ GetVolumeInformationW(
     {
       if (nVolumeNameSize * sizeof(WCHAR) >= FileFsVolume->VolumeLabelLength + sizeof(WCHAR))
         {
-         memcpy(lpVolumeNameBuffer, 
-                FileFsVolume->VolumeLabel, 
+         memcpy(lpVolumeNameBuffer,
+                FileFsVolume->VolumeLabel,
                 FileFsVolume->VolumeLabelLength);
          lpVolumeNameBuffer[FileFsVolume->VolumeLabelLength / sizeof(WCHAR)] = 0;
        }
@@ -652,8 +668,8 @@ GetVolumeInformationW(
     {
       if (nFileSystemNameSize * sizeof(WCHAR) >= FileFsAttribute->FileSystemNameLength + sizeof(WCHAR))
         {
-         memcpy(lpFileSystemNameBuffer, 
-                FileFsAttribute->FileSystemName, 
+         memcpy(lpFileSystemNameBuffer,
+                FileFsAttribute->FileSystemName,
                 FileFsAttribute->FileSystemNameLength);
          lpFileSystemNameBuffer[FileFsAttribute->FileSystemNameLength / sizeof(WCHAR)] = 0;
        }
@@ -689,7 +705,7 @@ SetVolumeLabelA (
       if (!(VolumeNameW = FilenameA2W(lpVolumeName, TRUE)))
          return FALSE;
    }
-      
+
    Result = SetVolumeLabelW (RootPathNameW,
                              VolumeNameW);
 
@@ -718,12 +734,17 @@ SetVolumeLabelW(
    ULONG LabelLength;
    HANDLE hFile;
    NTSTATUS Status;
-   
+
    LabelLength = wcslen(lpVolumeName) * sizeof(WCHAR);
    LabelInfo = RtlAllocateHeap(RtlGetProcessHeap(),
                               0,
                               sizeof(FILE_FS_LABEL_INFORMATION) +
                               LabelLength);
+   if (LabelInfo == NULL)
+   {
+       SetLastError(ERROR_NOT_ENOUGH_MEMORY);
+       return FALSE;
+   }
    LabelInfo->VolumeLabelLength = LabelLength;
    memcpy(LabelInfo->VolumeLabel,
          lpVolumeName,
@@ -737,7 +758,7 @@ SetVolumeLabelW(
                    LabelInfo);
         return FALSE;
    }
-   
+
    Status = NtSetVolumeInformationFile(hFile,
                                       &IoStatusBlock,
                                       LabelInfo,
@@ -761,4 +782,223 @@ SetVolumeLabelW(
    return TRUE;
 }
 
+/**
+ * @name GetVolumeNameForVolumeMountPointW
+ *
+ * Return an unique volume name for a drive root or mount point.
+ *
+ * @param VolumeMountPoint
+ *        Pointer to string that contains either root drive name or
+ *        mount point name.
+ * @param VolumeName
+ *        Pointer to buffer that is filled with resulting unique
+ *        volume name on success.
+ * @param VolumeNameLength
+ *        Size of VolumeName buffer in TCHARs.
+ *
+ * @return
+ *     TRUE when the function succeeds and the VolumeName buffer is filled,
+ *     FALSE otherwise.
+ */
+
+BOOL WINAPI
+GetVolumeNameForVolumeMountPointW(
+   IN LPCWSTR VolumeMountPoint,
+   OUT LPWSTR VolumeName,
+   IN DWORD VolumeNameLength)
+{
+   UNICODE_STRING NtFileName;
+   OBJECT_ATTRIBUTES ObjectAttributes;
+   HANDLE FileHandle;
+   IO_STATUS_BLOCK Iosb;
+   PVOID Buffer;
+   ULONG BufferLength;
+   PMOUNTDEV_NAME MountDevName;
+   PMOUNTMGR_MOUNT_POINT MountPoint;
+   ULONG MountPointSize;
+   PMOUNTMGR_MOUNT_POINTS MountPoints;
+   ULONG Index;
+   PUCHAR SymbolicLinkName;
+   NTSTATUS Status;
+
+   /*
+    * First step is to convert the passed volume mount point name to
+    * an NT acceptable name.
+    */
+
+   if (!RtlDosPathNameToNtPathName_U(VolumeName, &NtFileName, NULL, NULL))
+   {
+      SetLastError(ERROR_PATH_NOT_FOUND);
+      return FALSE;
+   }
+
+   if (NtFileName.Length > sizeof(WCHAR) &&
+       NtFileName.Buffer[(NtFileName.Length / sizeof(WCHAR)) - 1] == '\\')
+   {
+      NtFileName.Length -= sizeof(WCHAR);
+   }
+
+   /*
+    * Query mount point device name which we will later use for determining
+    * the volume name.
+    */
+
+   InitializeObjectAttributes(&ObjectAttributes, &NtFileName, 0, NULL, NULL);
+   Status = NtOpenFile(&FileHandle, FILE_READ_ATTRIBUTES | SYNCHRONIZE,
+                       &ObjectAttributes, &Iosb,
+                       FILE_SHARE_READ | FILE_SHARE_WRITE,
+                       FILE_SYNCHRONOUS_IO_NONALERT);
+   RtlFreeUnicodeString(&NtFileName);
+   if (!NT_SUCCESS(Status))
+   {
+      SetLastErrorByStatus(Status);
+      return FALSE;
+   }
+
+   BufferLength = sizeof(MOUNTDEV_NAME) + 50 * sizeof(WCHAR);
+   do
+   {
+      MountDevName = Buffer = RtlAllocateHeap(GetProcessHeap(), 0, BufferLength);
+      if (Buffer == NULL)
+      {
+         NtClose(FileHandle);
+         SetLastErrorByStatus(Status);
+         return FALSE;
+      }
+
+      Status = NtDeviceIoControlFile(FileHandle, NULL, NULL, NULL, &Iosb,
+                                     IOCTL_MOUNTDEV_QUERY_DEVICE_NAME,
+                                     NULL, 0, Buffer, BufferLength);
+      if (Status == STATUS_BUFFER_OVERFLOW)
+      {
+         BufferLength = sizeof(MOUNTDEV_NAME) + MountDevName->NameLength;
+         RtlFreeHeap(GetProcessHeap(), 0, Buffer);
+         continue;
+      }
+      else if (!NT_SUCCESS(Status))
+      {
+         RtlFreeHeap(GetProcessHeap(), 0, Buffer);
+         NtClose(FileHandle);
+         SetLastErrorByStatus(Status);
+         return FALSE;
+      }
+   }
+   while (!NT_SUCCESS(Status));
+
+   NtClose(FileHandle);
+
+   /*
+    * Get the mount point information from mount manager.
+    */
+
+   MountPointSize = MountDevName->NameLength + sizeof(MOUNTMGR_MOUNT_POINT);
+   MountPoint = RtlAllocateHeap(GetProcessHeap(), 0, MountPointSize);
+   if (MountPoint == NULL)
+   {
+      RtlFreeHeap(GetProcessHeap(), 0, MountDevName);
+      SetLastError(ERROR_NOT_ENOUGH_MEMORY);
+      return FALSE;
+   }
+   RtlZeroMemory(MountPoint, sizeof(MOUNTMGR_MOUNT_POINT));
+   MountPoint->DeviceNameOffset = sizeof(MOUNTMGR_MOUNT_POINT);
+   MountPoint->DeviceNameLength = MountDevName->NameLength;
+   RtlCopyMemory(MountPoint + 1, MountDevName->Name, MountDevName->NameLength);
+   RtlFreeHeap(GetProcessHeap(), 0, MountDevName);
+
+   RtlInitUnicodeString(&NtFileName, L"\\??\\MountPointManager");
+   InitializeObjectAttributes(&ObjectAttributes, &NtFileName, 0, NULL, NULL);
+   Status = NtOpenFile(&FileHandle, FILE_GENERIC_READ, &ObjectAttributes,
+                       &Iosb, FILE_SHARE_READ | FILE_SHARE_WRITE,
+                       FILE_SYNCHRONOUS_IO_NONALERT);
+   if (!NT_SUCCESS(Status))
+   {
+      SetLastErrorByStatus(Status);
+      RtlFreeHeap(GetProcessHeap(), 0, MountPoint);
+      return FALSE;
+   }
+
+   BufferLength = sizeof(MOUNTMGR_MOUNT_POINTS);
+   do
+   {
+      MountPoints = Buffer = RtlAllocateHeap(GetProcessHeap(), 0, BufferLength);
+      if (Buffer == NULL)
+      {
+         RtlFreeHeap(GetProcessHeap(), 0, MountPoint);
+         NtClose(FileHandle);
+         SetLastErrorByStatus(Status);
+         return FALSE;
+      }
+
+      Status = NtDeviceIoControlFile(FileHandle, NULL, NULL, NULL, &Iosb,
+                                     IOCTL_MOUNTMGR_QUERY_POINTS,
+                                     MountPoint, MountPointSize,
+                                     Buffer, BufferLength);
+      if (Status == STATUS_BUFFER_OVERFLOW)
+      {
+         BufferLength = MountPoints->Size;
+         RtlFreeHeap(GetProcessHeap(), 0, Buffer);
+         continue;
+      }
+      else if (!NT_SUCCESS(Status))
+      {
+         RtlFreeHeap(GetProcessHeap(), 0, MountPoint);
+         RtlFreeHeap(GetProcessHeap(), 0, Buffer);
+         SetLastErrorByStatus(Status);
+         return FALSE;
+      }
+   }
+   while (!NT_SUCCESS(Status));
+
+   RtlFreeHeap(GetProcessHeap(), 0, MountPoint);
+   NtClose(FileHandle);
+
+   /*
+    * Now we've gathered info about all mount points mapped to our device, so
+    * select the correct one and copy it into the output buffer.
+    */
+
+   for (Index = 0; Index < MountPoints->NumberOfMountPoints; Index++)
+   {
+      MountPoint = MountPoints->MountPoints + Index;
+      SymbolicLinkName = (PUCHAR)MountPoints + MountPoint->SymbolicLinkNameOffset;
+      
+      /*
+       * Check for "\\?\Volume{xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}\"
+       * (with the last slash being optional) style symbolic links.
+       */
+
+      if (MountPoint->SymbolicLinkNameLength == 48 * sizeof(WCHAR) ||
+          (MountPoint->SymbolicLinkNameLength == 49 * sizeof(WCHAR) &&
+           SymbolicLinkName[48] == L'\\'))
+      {
+         if (RtlCompareMemory(SymbolicLinkName, L"\\??\\Volume{",
+                              11 * sizeof(WCHAR)) == 11 * sizeof(WCHAR) &&
+             SymbolicLinkName[19] == L'-' && SymbolicLinkName[24] == L'-' &&
+             SymbolicLinkName[29] == L'-' && SymbolicLinkName[34] == L'-' &&
+             SymbolicLinkName[47] == L'}')
+         {
+            if (VolumeNameLength >= MountPoint->SymbolicLinkNameLength / sizeof(WCHAR))
+            {
+               RtlCopyMemory(VolumeName,
+                             (PUCHAR)MountPoints + MountPoint->SymbolicLinkNameOffset,
+                             MountPoint->SymbolicLinkNameLength);
+               VolumeName[1] = L'\\';
+               RtlFreeHeap(GetProcessHeap(), 0, MountPoints);
+               return TRUE;
+            }
+
+            RtlFreeHeap(GetProcessHeap(), 0, MountPoints);
+            SetLastError(ERROR_FILENAME_EXCED_RANGE);
+
+            return FALSE;
+         }
+      }
+   }
+
+   RtlFreeHeap(GetProcessHeap(), 0, MountPoints);
+   SetLastError(ERROR_INVALID_PARAMETER);
+
+   return FALSE;
+}
+
 /* EOF */