[KERNEL32]: Reimplement GetDiskFreeSpaceExA() to make it w2k3 compliant
[reactos.git] / dll / win32 / kernel32 / client / file / disk.c
index 2397247..1b7d1e3 100644 (file)
@@ -1,7 +1,7 @@
 /*
  * COPYRIGHT:       See COPYING in the top level directory
  * PROJECT:         ReactOS system libraries
- * FILE:            lib/kernel32/file/disk.c
+ * FILE:            dll/win32/kernel32/client/file/disk.c
  * PURPOSE:         Disk and Drive functions
  * PROGRAMMER:      Ariadne ( ariadne@xs4all.nl)
  *                  Erik Bos, Alexandre Julliard :
  */
 
 #include <k32.h>
+#include <strsafe.h>
+
 #define NDEBUG
 #include <debug.h>
 DEBUG_CHANNEL(kernel32file);
 
 #define MAX_DOS_DRIVES 26
-HANDLE WINAPI InternalOpenDirW(IN LPCWSTR DirName, IN BOOLEAN Write);
+
+HANDLE
+WINAPI
+InternalOpenDirW(IN LPCWSTR DirName,
+                 IN BOOLEAN Write)
+{
+    UNICODE_STRING NtPathU;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    NTSTATUS errCode;
+    IO_STATUS_BLOCK IoStatusBlock;
+    HANDLE hFile;
+
+    if (!RtlDosPathNameToNtPathName_U(DirName, &NtPathU, NULL, NULL))
+    {
+        WARN("Invalid path\n");
+        SetLastError(ERROR_BAD_PATHNAME);
+        return INVALID_HANDLE_VALUE;
+    }
+
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &NtPathU,
+                               OBJ_CASE_INSENSITIVE,
+                               NULL,
+                               NULL);
+
+    errCode = NtCreateFile(&hFile,
+                           Write ? FILE_GENERIC_WRITE : FILE_GENERIC_READ,
+                           &ObjectAttributes,
+                           &IoStatusBlock,
+                           NULL,
+                           0,
+                           FILE_SHARE_READ | FILE_SHARE_WRITE,
+                           FILE_OPEN,
+                           0,
+                           NULL,
+                           0);
+
+    RtlFreeHeap(RtlGetProcessHeap(), 0, NtPathU.Buffer);
+
+    if (!NT_SUCCESS(errCode))
+    {
+        BaseSetLastNTError(errCode);
+        return INVALID_HANDLE_VALUE;
+    }
+
+    return hFile;
+}
 
 /*
  * @implemented
@@ -35,32 +83,32 @@ WINAPI
 GetLogicalDriveStringsA(IN DWORD nBufferLength,
                         IN LPSTR lpBuffer)
 {
-   DWORD drive, count;
-   DWORD dwDriveMap;
-   LPSTR p;
+    DWORD drive, count;
+    DWORD dwDriveMap;
+    LPSTR p;
 
-   dwDriveMap = GetLogicalDrives();
+    dwDriveMap = GetLogicalDrives();
 
-   for (drive = count = 0; drive < MAX_DOS_DRIVES; drive++)
-     {
-       if (dwDriveMap & (1<<drive))
-          count++;
-     }
+    for (drive = count = 0; drive < MAX_DOS_DRIVES; drive++)
+    {
+        if (dwDriveMap & (1<<drive))
+            count++;
+    }
 
 
-   if ((count * 4) + 1 > nBufferLength) return ((count * 4) + 1);
+    if ((count * 4) + 1 > nBufferLength) return ((count * 4) + 1);
 
-       p = lpBuffer;
+    p = lpBuffer;
 
-       for (drive = 0; drive < MAX_DOS_DRIVES; drive++)
-         if (dwDriveMap & (1<<drive))
-         {
-            *p++ = 'A' + (UCHAR)drive;
-            *p++ = ':';
-            *p++ = '\\';
-            *p++ = '\0';
-         }
-       *p = '\0';
+    for (drive = 0; drive < MAX_DOS_DRIVES; drive++)
+        if (dwDriveMap & (1<<drive))
+        {
+            *p++ = 'A' + (UCHAR)drive;
+            *p++ = ':';
+            *p++ = '\\';
+            *p++ = '\0';
+        }
+    *p = '\0';
 
     return (count * 4);
 }
@@ -74,17 +122,17 @@ WINAPI
 GetLogicalDriveStringsW(IN DWORD nBufferLength,
                         IN LPWSTR lpBuffer)
 {
-   DWORD drive, count;
-   DWORD dwDriveMap;
-   LPWSTR p;
+    DWORD drive, count;
+    DWORD dwDriveMap;
+    LPWSTR p;
 
-   dwDriveMap = GetLogicalDrives();
+    dwDriveMap = GetLogicalDrives();
 
-   for (drive = count = 0; drive < MAX_DOS_DRIVES; drive++)
-     {
-       if (dwDriveMap & (1<<drive))
-          count++;
-     }
+    for (drive = count = 0; drive < MAX_DOS_DRIVES; drive++)
+    {
+        if (dwDriveMap & (1<<drive))
+            count++;
+    }
 
     if ((count * 4) + 1 > nBufferLength) return ((count * 4) + 1);
 
@@ -110,24 +158,24 @@ DWORD
 WINAPI
 GetLogicalDrives(VOID)
 {
-       NTSTATUS Status;
-       PROCESS_DEVICEMAP_INFORMATION ProcessDeviceMapInfo;
-
-       /* Get the Device Map for this Process */
-       Status = NtQueryInformationProcess(NtCurrentProcess(),
-                                          ProcessDeviceMap,
-                                          &ProcessDeviceMapInfo,
-                                          sizeof(ProcessDeviceMapInfo),
-                                          NULL);
-
-       /* Return the Drive Map */
-       if (!NT_SUCCESS(Status))
-       {
-               BaseSetLastNTError(Status);
-               return 0;
-       }
-
-        return ProcessDeviceMapInfo.Query.DriveMap;
+    NTSTATUS Status;
+    PROCESS_DEVICEMAP_INFORMATION ProcessDeviceMapInfo;
+
+    /* Get the Device Map for this Process */
+    Status = NtQueryInformationProcess(NtCurrentProcess(),
+                                       ProcessDeviceMap,
+                                       &ProcessDeviceMapInfo,
+                                       sizeof(ProcessDeviceMapInfo),
+                                       NULL);
+
+    /* Return the Drive Map */
+    if (!NT_SUCCESS(Status))
+    {
+        BaseSetLastNTError(Status);
+        return 0;
+    }
+
+    return ProcessDeviceMapInfo.Query.DriveMap;
 }
 
 /*
@@ -141,19 +189,24 @@ GetDiskFreeSpaceA(IN LPCSTR lpRootPathName,
                   OUT LPDWORD lpNumberOfFreeClusters,
                   OUT LPDWORD lpTotalNumberOfClusters)
 {
-   PWCHAR RootPathNameW=NULL;
-
-   if (lpRootPathName)
-   {
-      if (!(RootPathNameW = FilenameA2W(lpRootPathName, FALSE)))
-         return FALSE;
-   }
-
-       return GetDiskFreeSpaceW (RootPathNameW,
-                                   lpSectorsPerCluster,
-                                   lpBytesPerSector,
-                                   lpNumberOfFreeClusters,
-                                   lpTotalNumberOfClusters);
+    PCSTR RootPath;
+    PUNICODE_STRING RootPathU;
+
+    RootPath = lpRootPathName;
+    if (RootPath == NULL)
+    {
+        RootPath = "\\";
+    }
+
+    RootPathU = Basep8BitStringToStaticUnicodeString(RootPath);
+    if (RootPathU == NULL)
+    {
+        return FALSE;
+    }
+
+    return GetDiskFreeSpaceW(RootPathU->Buffer, lpSectorsPerCluster,
+                             lpBytesPerSector, lpNumberOfFreeClusters,
+                             lpTotalNumberOfClusters);
 }
 
 /*
@@ -167,50 +220,131 @@ GetDiskFreeSpaceW(IN LPCWSTR lpRootPathName,
                   OUT LPDWORD lpNumberOfFreeClusters,
                   OUT LPDWORD lpTotalNumberOfClusters)
 {
-    FILE_FS_SIZE_INFORMATION FileFsSize;
+    BOOL Below2GB;
+    PCWSTR RootPath;
+    NTSTATUS Status;
+    HANDLE RootHandle;
+    UNICODE_STRING FileName;
     IO_STATUS_BLOCK IoStatusBlock;
-    WCHAR RootPathName[MAX_PATH];
-    HANDLE hFile;
-    NTSTATUS errCode;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    FILE_FS_SIZE_INFORMATION FileFsSize;
 
-    if (lpRootPathName)
+    /* If no path provided, get root path */
+    RootPath = lpRootPathName;
+    if (lpRootPathName == NULL)
     {
-        wcsncpy (RootPathName, lpRootPathName, 3);
+        RootPath = L"\\";
     }
-    else
+
+    /* Convert the path to NT path */
+    if (!RtlDosPathNameToNtPathName_U(RootPath, &FileName, NULL, NULL))
     {
-        GetCurrentDirectoryW (MAX_PATH, RootPathName);
+        SetLastError(ERROR_PATH_NOT_FOUND);
+        return FALSE;
     }
-    RootPathName[3] = 0;
 
-  hFile = InternalOpenDirW(RootPathName, FALSE);
-  if (INVALID_HANDLE_VALUE == hFile)
+    /* Open it for disk space query! */
+    InitializeObjectAttributes(&ObjectAttributes, &FileName,
+                               OBJ_CASE_INSENSITIVE, NULL, NULL);
+    Status = NtOpenFile(&RootHandle, SYNCHRONIZE, &ObjectAttributes, &IoStatusBlock,
+                        FILE_SHARE_READ | FILE_SHARE_WRITE,
+                        FILE_DIRECTORY_FILE | FILE_SYNCHRONOUS_IO_NONALERT | FILE_OPEN_FOR_FREE_SPACE_QUERY);
+    if (!NT_SUCCESS(Status))
     {
-      SetLastError(ERROR_PATH_NOT_FOUND);
-      return FALSE;
+        BaseSetLastNTError(Status);
+        RtlFreeHeap(RtlGetProcessHeap(), 0, FileName.Buffer);
+        if (lpBytesPerSector != NULL)
+        {
+            *lpBytesPerSector = 0;
+        }
+
+        return FALSE;
     }
 
-    errCode = NtQueryVolumeInformationFile(hFile,
-                                           &IoStatusBlock,
-                                           &FileFsSize,
-                                           sizeof(FILE_FS_SIZE_INFORMATION),
-                                           FileFsSizeInformation);
-    if (!NT_SUCCESS(errCode))
+    /* We don't need the name any longer */
+    RtlFreeHeap(RtlGetProcessHeap(), 0, FileName.Buffer);
+
+    /* Query disk space! */
+    Status = NtQueryVolumeInformationFile(RootHandle, &IoStatusBlock, &FileFsSize,
+                                          sizeof(FILE_FS_SIZE_INFORMATION),
+                                          FileFsSizeInformation);
+    NtClose(RootHandle);
+    if (!NT_SUCCESS(Status))
     {
-        CloseHandle(hFile);
-        BaseSetLastNTError (errCode);
+        BaseSetLastNTError(Status);
         return FALSE;
     }
 
-    if (lpSectorsPerCluster)
+    /* Are we in some compatibility mode where size must be below 2GB? */
+    Below2GB = ((NtCurrentPeb()->AppCompatFlags.LowPart & GetDiskFreeSpace2GB) == GetDiskFreeSpace2GB);
+
+    /* If we're to overflow output, make sure we return the maximum */
+    if (FileFsSize.TotalAllocationUnits.HighPart != 0)
+    {
+        FileFsSize.TotalAllocationUnits.LowPart = -1;
+    }
+
+    if (FileFsSize.AvailableAllocationUnits.HighPart != 0)
+    {
+        FileFsSize.AvailableAllocationUnits.LowPart = -1;
+    }
+
+    /* Return what user asked for */
+    if (lpSectorsPerCluster != NULL)
+    {
         *lpSectorsPerCluster = FileFsSize.SectorsPerAllocationUnit;
-    if (lpBytesPerSector)
+    }
+
+    if (lpBytesPerSector != NULL)
+    {
         *lpBytesPerSector = FileFsSize.BytesPerSector;
-    if (lpNumberOfFreeClusters)
-        *lpNumberOfFreeClusters = FileFsSize.AvailableAllocationUnits.u.LowPart;
-    if (lpTotalNumberOfClusters)
-        *lpTotalNumberOfClusters = FileFsSize.TotalAllocationUnits.u.LowPart;
-    CloseHandle(hFile);
+    }
+
+    if (lpNumberOfFreeClusters != NULL)
+    {
+        if (!Below2GB)
+        {
+            *lpNumberOfFreeClusters = FileFsSize.AvailableAllocationUnits.LowPart;
+        }
+        /* If we have to remain below 2GB... */
+        else
+        {
+            DWORD FreeClusters;
+
+            /* Compute how many clusters there are in less than 2GB: 2 * 1024 * 1024 * 1024- 1 */
+            FreeClusters = 0x7FFFFFFF / (FileFsSize.SectorsPerAllocationUnit * FileFsSize.BytesPerSector);
+            /* If that's higher than what was queried, then return the queried value, it's OK! */
+            if (FreeClusters > FileFsSize.AvailableAllocationUnits.LowPart)
+            {
+                FreeClusters = FileFsSize.AvailableAllocationUnits.LowPart;
+            }
+
+            *lpNumberOfFreeClusters = FreeClusters;
+        }
+    }
+
+    if (lpTotalNumberOfClusters != NULL)
+    {
+        if (!Below2GB)
+        {
+            *lpTotalNumberOfClusters = FileFsSize.TotalAllocationUnits.LowPart;
+        }
+        /* If we have to remain below 2GB... */
+        else
+        {
+            DWORD TotalClusters;
+
+            /* Compute how many clusters there are in less than 2GB: 2 * 1024 * 1024 * 1024- 1 */
+            TotalClusters = 0x7FFFFFFF / (FileFsSize.SectorsPerAllocationUnit * FileFsSize.BytesPerSector);
+            /* If that's higher than what was queried, then return the queried value, it's OK! */
+            if (TotalClusters > FileFsSize.TotalAllocationUnits.LowPart)
+            {
+                TotalClusters = FileFsSize.TotalAllocationUnits.LowPart;
+            }
+
+            *lpTotalNumberOfClusters = TotalClusters;
+        }
+    }
 
     return TRUE;
 }
@@ -225,18 +359,23 @@ GetDiskFreeSpaceExA(IN LPCSTR lpDirectoryName OPTIONAL,
                     OUT PULARGE_INTEGER lpTotalNumberOfBytes,
                     OUT PULARGE_INTEGER lpTotalNumberOfFreeBytes)
 {
-   PWCHAR DirectoryNameW=NULL;
-
-       if (lpDirectoryName)
-       {
-      if (!(DirectoryNameW = FilenameA2W(lpDirectoryName, FALSE)))
-         return FALSE;
-       }
-
-   return GetDiskFreeSpaceExW (DirectoryNameW ,
-                                     lpFreeBytesAvailableToCaller,
-                                     lpTotalNumberOfBytes,
-                                     lpTotalNumberOfFreeBytes);
+    PCSTR RootPath;
+    PUNICODE_STRING RootPathU;
+
+    RootPath = lpDirectoryName;
+    if (RootPath == NULL)
+    {
+        RootPath = "\\";
+    }
+
+    RootPathU = Basep8BitStringToStaticUnicodeString(RootPath);
+    if (RootPathU == NULL)
+    {
+        return FALSE;
+    }
+
+    return GetDiskFreeSpaceExW(RootPathU->Buffer, lpFreeBytesAvailableToCaller,
+                              lpTotalNumberOfBytes, lpTotalNumberOfFreeBytes);
 }
 
 /*
@@ -249,103 +388,117 @@ GetDiskFreeSpaceExW(IN LPCWSTR lpDirectoryName OPTIONAL,
                     OUT PULARGE_INTEGER lpTotalNumberOfBytes,
                     OUT PULARGE_INTEGER lpTotalNumberOfFreeBytes)
 {
-    union
-    {
-        FILE_FS_SIZE_INFORMATION FsSize;
-        FILE_FS_FULL_SIZE_INFORMATION FsFullSize;
-    } FsInfo;
-    IO_STATUS_BLOCK IoStatusBlock;
-    ULARGE_INTEGER BytesPerCluster;
-    HANDLE hFile;
+    PCWSTR RootPath;
     NTSTATUS Status;
+    HANDLE RootHandle;
+    UNICODE_STRING FileName;
+    DWORD BytesPerAllocationUnit;
+    IO_STATUS_BLOCK IoStatusBlock;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    FILE_FS_SIZE_INFORMATION FileFsSize;
 
+    /* If no path provided, get root path */
+    RootPath = lpDirectoryName;
     if (lpDirectoryName == NULL)
-        lpDirectoryName = L"\\";
+    {
+        RootPath = L"\\";
+    }
 
-    hFile = InternalOpenDirW(lpDirectoryName, FALSE);
-    if (INVALID_HANDLE_VALUE == hFile)
+    /* Convert the path to NT path */
+    if (!RtlDosPathNameToNtPathName_U(RootPath, &FileName, NULL, NULL))
     {
+        SetLastError(ERROR_PATH_NOT_FOUND);
         return FALSE;
     }
 
-    if (lpFreeBytesAvailableToCaller != NULL || lpTotalNumberOfBytes != NULL)
+    /* Open it for disk space query! */
+    InitializeObjectAttributes(&ObjectAttributes, &FileName,
+                               OBJ_CASE_INSENSITIVE, NULL, NULL);
+    Status = NtOpenFile(&RootHandle, SYNCHRONIZE, &ObjectAttributes, &IoStatusBlock,
+                        FILE_SHARE_READ | FILE_SHARE_WRITE,
+                        FILE_DIRECTORY_FILE | FILE_SYNCHRONOUS_IO_NONALERT | FILE_OPEN_FOR_FREE_SPACE_QUERY);
+    if (!NT_SUCCESS(Status))
     {
-        /* To get the free space available to the user associated with the
-           current thread, try FileFsFullSizeInformation. If this is not
-           supported by the file system, fall back to FileFsSize */
+        BaseSetLastNTError(Status);
+        /* If error conversion lead to file not found, override to use path not found
+         * which is more accurate
+         */
+        if (GetLastError() == ERROR_FILE_NOT_FOUND)
+        {
+            SetLastError(ERROR_PATH_NOT_FOUND);
+        }
 
-        Status = NtQueryVolumeInformationFile(hFile,
-                                              &IoStatusBlock,
-                                              &FsInfo.FsFullSize,
-                                              sizeof(FsInfo.FsFullSize),
-                                              FileFsFullSizeInformation);
+        RtlFreeHeap(RtlGetProcessHeap(), 0, FileName.Buffer);
+
+        return FALSE;
+    }
 
+    RtlFreeHeap(RtlGetProcessHeap(), 0, FileName.Buffer);
+
+    /* If user asks for lpTotalNumberOfFreeBytes, try to use full size information */
+    if (lpTotalNumberOfFreeBytes != NULL)
+    {
+        FILE_FS_FULL_SIZE_INFORMATION FileFsFullSize;
+
+        /* Issue the full fs size request */
+        Status = NtQueryVolumeInformationFile(RootHandle, &IoStatusBlock, &FileFsFullSize,
+                                              sizeof(FILE_FS_FULL_SIZE_INFORMATION),
+                                              FileFsFullSizeInformation);
+        /* If it succeed, complete out buffers */
         if (NT_SUCCESS(Status))
         {
-            /* Close the handle before returning data
-               to avoid a handle leak in case of a fault! */
-            CloseHandle(hFile);
+            /* We can close here, we'll return */
+            NtClose(RootHandle);
 
-            BytesPerCluster.QuadPart =
-                FsInfo.FsFullSize.BytesPerSector * FsInfo.FsFullSize.SectorsPerAllocationUnit;
+            /* Compute the size of an AU */
+            BytesPerAllocationUnit = FileFsFullSize.SectorsPerAllocationUnit * FileFsFullSize.BytesPerSector;
 
+            /* And then return what was asked */
             if (lpFreeBytesAvailableToCaller != NULL)
             {
-                lpFreeBytesAvailableToCaller->QuadPart =
-                    BytesPerCluster.QuadPart * FsInfo.FsFullSize.CallerAvailableAllocationUnits.QuadPart;
+                lpFreeBytesAvailableToCaller->QuadPart = FileFsFullSize.CallerAvailableAllocationUnits.QuadPart * BytesPerAllocationUnit;
             }
 
             if (lpTotalNumberOfBytes != NULL)
             {
-                lpTotalNumberOfBytes->QuadPart =
-                    BytesPerCluster.QuadPart * FsInfo.FsFullSize.TotalAllocationUnits.QuadPart;
+                lpTotalNumberOfBytes->QuadPart = FileFsFullSize.TotalAllocationUnits.QuadPart * BytesPerAllocationUnit;
             }
 
-            if (lpTotalNumberOfFreeBytes != NULL)
-            {
-                lpTotalNumberOfFreeBytes->QuadPart =
-                    BytesPerCluster.QuadPart * FsInfo.FsFullSize.ActualAvailableAllocationUnits.QuadPart;
-            }
+            /* No need to check for nullness ;-) */
+            lpTotalNumberOfFreeBytes->QuadPart = FileFsFullSize.ActualAvailableAllocationUnits.QuadPart * BytesPerAllocationUnit;
 
             return TRUE;
         }
     }
 
-    Status = NtQueryVolumeInformationFile(hFile,
-                                          &IoStatusBlock,
-                                          &FsInfo.FsSize,
-                                          sizeof(FsInfo.FsSize),
+    /* Otherwise, fallback to normal size information */
+    Status = NtQueryVolumeInformationFile(RootHandle, &IoStatusBlock,
+                                          &FileFsSize, sizeof(FILE_FS_SIZE_INFORMATION),
                                           FileFsSizeInformation);
-
-    /* Close the handle before returning data
-       to avoid a handle leak in case of a fault! */
-    CloseHandle(hFile);
-
+    NtClose(RootHandle);
     if (!NT_SUCCESS(Status))
     {
-        BaseSetLastNTError (Status);
+        BaseSetLastNTError(Status);
         return FALSE;
     }
 
-    BytesPerCluster.QuadPart =
-        FsInfo.FsSize.BytesPerSector * FsInfo.FsSize.SectorsPerAllocationUnit;
+    /* Compute the size of an AU */
+    BytesPerAllocationUnit = FileFsSize.SectorsPerAllocationUnit * FileFsSize.BytesPerSector;
 
-    if (lpFreeBytesAvailableToCaller)
+    /* And then return what was asked, available is free, the same! */
+    if (lpFreeBytesAvailableToCaller != NULL)
     {
-        lpFreeBytesAvailableToCaller->QuadPart =
-            BytesPerCluster.QuadPart * FsInfo.FsSize.AvailableAllocationUnits.QuadPart;
+        lpFreeBytesAvailableToCaller->QuadPart = FileFsSize.AvailableAllocationUnits.QuadPart * BytesPerAllocationUnit;
     }
 
-    if (lpTotalNumberOfBytes)
+    if (lpTotalNumberOfBytes != NULL)
     {
-        lpTotalNumberOfBytes->QuadPart =
-            BytesPerCluster.QuadPart * FsInfo.FsSize.TotalAllocationUnits.QuadPart;
+        lpTotalNumberOfBytes->QuadPart = FileFsSize.TotalAllocationUnits.QuadPart * BytesPerAllocationUnit;
     }
 
-    if (lpTotalNumberOfFreeBytes)
+    if (lpTotalNumberOfFreeBytes != NULL)
     {
-        lpTotalNumberOfFreeBytes->QuadPart =
-            BytesPerCluster.QuadPart * FsInfo.FsSize.AvailableAllocationUnits.QuadPart;
+        lpTotalNumberOfFreeBytes->QuadPart = FileFsSize.AvailableAllocationUnits.QuadPart * BytesPerAllocationUnit;
     }
 
     return TRUE;
@@ -358,12 +511,15 @@ UINT
 WINAPI
 GetDriveTypeA(IN LPCSTR lpRootPathName)
 {
-   PWCHAR RootPathNameW;
+    PWCHAR RootPathNameW;
+
+    if (!lpRootPathName)
+        return GetDriveTypeW(NULL);
 
-   if (!(RootPathNameW = FilenameA2W(lpRootPathName, FALSE)))
-      return DRIVE_UNKNOWN;
+    if (!(RootPathNameW = FilenameA2W(lpRootPathName, FALSE)))
+        return DRIVE_UNKNOWN;
 
-   return GetDriveTypeW(RootPathNameW);
+    return GetDriveTypeW(RootPathNameW);
 }
 
 /*
@@ -373,52 +529,128 @@ UINT
 WINAPI
 GetDriveTypeW(IN LPCWSTR lpRootPathName)
 {
-       FILE_FS_DEVICE_INFORMATION FileFsDevice;
-       IO_STATUS_BLOCK IoStatusBlock;
-
-       HANDLE hFile;
-       NTSTATUS errCode;
-
-       hFile = InternalOpenDirW(lpRootPathName, FALSE);
-       if (hFile == INVALID_HANDLE_VALUE)
-       {
-           return DRIVE_NO_ROOT_DIR;   /* According to WINE regression tests */
-       }
-
-       errCode = NtQueryVolumeInformationFile (hFile,
-                                               &IoStatusBlock,
-                                               &FileFsDevice,
-                                               sizeof(FILE_FS_DEVICE_INFORMATION),
-                                               FileFsDeviceInformation);
-       if (!NT_SUCCESS(errCode))
-       {
-               CloseHandle(hFile);
-               BaseSetLastNTError (errCode);
-               return 0;
-       }
-       CloseHandle(hFile);
-
-        switch (FileFsDevice.DeviceType)
+    FILE_FS_DEVICE_INFORMATION FileFsDevice;
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    IO_STATUS_BLOCK IoStatusBlock;
+    UNICODE_STRING PathName;
+    HANDLE FileHandle;
+    NTSTATUS Status;
+    PWSTR CurrentDir = NULL;
+    PCWSTR lpRootPath;
+
+    if (!lpRootPathName)
+    {
+        /* If NULL is passed, use current directory path */
+        DWORD BufferSize = GetCurrentDirectoryW(0, NULL);
+        CurrentDir = HeapAlloc(GetProcessHeap(), 0, BufferSize * sizeof(WCHAR));
+        if (!CurrentDir)
+            return DRIVE_UNKNOWN;
+        if (!GetCurrentDirectoryW(BufferSize, CurrentDir))
         {
-               case FILE_DEVICE_CD_ROM:
-               case FILE_DEVICE_CD_ROM_FILE_SYSTEM:
-                       return DRIVE_CDROM;
-               case FILE_DEVICE_VIRTUAL_DISK:
-                       return DRIVE_RAMDISK;
-               case FILE_DEVICE_NETWORK_FILE_SYSTEM:
-                       return DRIVE_REMOTE;
-               case FILE_DEVICE_DISK:
-               case FILE_DEVICE_DISK_FILE_SYSTEM:
-                       if (FileFsDevice.Characteristics & FILE_REMOTE_DEVICE)
-                               return DRIVE_REMOTE;
-                       if (FileFsDevice.Characteristics & FILE_REMOVABLE_MEDIA)
-                               return DRIVE_REMOVABLE;
-                       return DRIVE_FIXED;
+            HeapFree(GetProcessHeap(), 0, CurrentDir);
+            return DRIVE_UNKNOWN;
         }
 
-        ERR("Returning DRIVE_UNKNOWN for device type %lu\n", FileFsDevice.DeviceType);
+        if (wcslen(CurrentDir) > 3)
+            CurrentDir[3] = 0;
+
+        lpRootPath = CurrentDir;
+    }
+    else
+    {
+        size_t Length = wcslen(lpRootPathName);
+
+        TRACE("lpRootPathName: %S\n", lpRootPathName);
+
+        lpRootPath = lpRootPathName;
+        if (Length == 2)
+        {
+            WCHAR DriveLetter = RtlUpcaseUnicodeChar(lpRootPathName[0]);
+
+            if (DriveLetter >= L'A' && DriveLetter <= L'Z' && lpRootPathName[1] == L':')
+            {
+                Length = (Length + 2) * sizeof(WCHAR);
+
+                CurrentDir = HeapAlloc(GetProcessHeap(), 0, Length);
+                if (!CurrentDir)
+                    return DRIVE_UNKNOWN;
+
+                StringCbPrintfW(CurrentDir, Length, L"%s\\", lpRootPathName);
+
+                lpRootPath = CurrentDir;
+            }
+        }
+    }
+
+    TRACE("lpRootPath: %S\n", lpRootPath);
+
+    if (!RtlDosPathNameToNtPathName_U(lpRootPath, &PathName, NULL, NULL))
+    {
+        if (CurrentDir != NULL)
+            HeapFree(GetProcessHeap(), 0, CurrentDir);
+
+        return DRIVE_NO_ROOT_DIR;
+    }
+
+    TRACE("PathName: %S\n", PathName.Buffer);
+
+    if (CurrentDir != NULL)
+        HeapFree(GetProcessHeap(), 0, CurrentDir);
+
+    if (PathName.Buffer[(PathName.Length >> 1) - 1] != L'\\')
+    {
+        return DRIVE_NO_ROOT_DIR;
+    }
+
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &PathName,
+                               OBJ_CASE_INSENSITIVE,
+                               NULL,
+                               NULL);
+
+    Status = NtOpenFile(&FileHandle,
+                        FILE_READ_ATTRIBUTES | SYNCHRONIZE,
+                        &ObjectAttributes,
+                        &IoStatusBlock,
+                        FILE_SHARE_READ | FILE_SHARE_WRITE,
+                        FILE_SYNCHRONOUS_IO_NONALERT);
+
+    RtlFreeHeap(RtlGetProcessHeap(), 0, PathName.Buffer);
+    if (!NT_SUCCESS(Status))
+        return DRIVE_NO_ROOT_DIR; /* According to WINE regression tests */
+
+    Status = NtQueryVolumeInformationFile(FileHandle,
+                                          &IoStatusBlock,
+                                          &FileFsDevice,
+                                          sizeof(FILE_FS_DEVICE_INFORMATION),
+                                          FileFsDeviceInformation);
+    NtClose(FileHandle);
+    if (!NT_SUCCESS(Status))
+    {
+        return 0;
+    }
+
+    switch (FileFsDevice.DeviceType)
+    {
+        case FILE_DEVICE_CD_ROM:
+        case FILE_DEVICE_CD_ROM_FILE_SYSTEM:
+            return DRIVE_CDROM;
+        case FILE_DEVICE_VIRTUAL_DISK:
+            return DRIVE_RAMDISK;
+        case FILE_DEVICE_NETWORK_FILE_SYSTEM:
+            return DRIVE_REMOTE;
+        case FILE_DEVICE_DISK:
+        case FILE_DEVICE_DISK_FILE_SYSTEM:
+            if (FileFsDevice.Characteristics & FILE_REMOTE_DEVICE)
+                return DRIVE_REMOTE;
+            if (FileFsDevice.Characteristics & FILE_REMOVABLE_MEDIA)
+                return DRIVE_REMOVABLE;
+        return DRIVE_FIXED;
+    }
+
+    ERR("Returning DRIVE_UNKNOWN for device type %lu\n", FileFsDevice.DeviceType);
 
-       return DRIVE_UNKNOWN;
+    return DRIVE_UNKNOWN;
 }
 
 /* EOF */