[AUTOCHK] Let the timeout for disk repair to be configured
[reactos.git] / base / system / autochk / autochk.c
index d1cd6a4..5a71ae7 100644 (file)
@@ -1,10 +1,12 @@
-/* PROJECT:         ReactOS Kernel
+/*
+ * PROJECT:         ReactOS Kernel
  * LICENSE:         GPL - See COPYING in the top level directory
  * FILE:            base/system/autochk/autochk.c
  * PURPOSE:         Filesystem checker
  * PROGRAMMERS:     Aleksey Bragin
  *                  Eric Kohl
  *                  HervĂ© Poussineau
+ *                  Pierre Schweitzer
  */
 
 /* INCLUDES *****************************************************************/
@@ -13,6 +15,7 @@
 #define WIN32_NO_STATUS
 #include <windef.h>
 #include <winbase.h>
+#include <ntddkbd.h>
 #define NTOS_MODE_USER
 #include <ndk/exfuncs.h>
 #include <ndk/iofuncs.h>
 #include <ndk/umfuncs.h>
 #include <fmifs/fmifs.h>
 
+#include <fslib/vfatlib.h>
+#include <fslib/ext2lib.h>
+#include <fslib/ntfslib.h>
+#include <fslib/cdfslib.h>
+#include <fslib/btrfslib.h>
+#include <fslib/ffslib.h>
+#include <fslib/reiserfslib.h>
+
 #define NDEBUG
 #include <debug.h>
 
 
 #define FS_ATTRIBUTE_BUFFER_SIZE (MAX_PATH * sizeof(WCHAR) + sizeof(FILE_FS_ATTRIBUTE_INFORMATION))
 
+typedef struct _FILESYSTEM_CHKDSK
+{
+    WCHAR Name[10];
+    CHKDSKEX ChkdskFunc;
+} FILESYSTEM_CHKDSK, *PFILESYSTEM_CHKDSK;
+
+FILESYSTEM_CHKDSK FileSystems[10] =
+{
+    { L"FAT", VfatChkdsk },
+    { L"FAT32", VfatChkdsk },
+    { L"NTFS", NtfsChkdsk },
+    { L"EXT2", Ext2Chkdsk },
+    { L"EXT3", Ext2Chkdsk },
+    { L"EXT4", Ext2Chkdsk },
+    { L"Btrfs", BtrfsChkdskEx },
+    { L"RFSD", ReiserfsChkdsk },
+    { L"FFS", FfsChkdsk },
+    { L"CDFS", CdfsChkdsk },
+};
+
+HANDLE KeyboardHandle;
 
 /* FUNCTIONS ****************************************************************/
 //
@@ -106,6 +138,67 @@ OpenDirectory(
     return hFile;
 }
 
+static NTSTATUS
+OpenKeyboard(VOID)
+{
+    OBJECT_ATTRIBUTES ObjectAttributes;
+    IO_STATUS_BLOCK IoStatusBlock;
+    UNICODE_STRING KeyboardName = RTL_CONSTANT_STRING(L"\\Device\\KeyboardClass0");
+
+    /* Just open the class driver */
+    InitializeObjectAttributes(&ObjectAttributes,
+                               &KeyboardName,
+                               0,
+                               NULL,
+                               NULL);
+    return  NtOpenFile(&KeyboardHandle,
+                       FILE_ALL_ACCESS,
+                       &ObjectAttributes,
+                       &IoStatusBlock,
+                       FILE_OPEN,
+                       0);
+}
+
+static NTSTATUS
+WaitForKeyboard(
+    IN LONG TimeOut)
+{
+    NTSTATUS Status;
+    IO_STATUS_BLOCK IoStatusBlock;
+    LARGE_INTEGER Offset, Timeout;
+    KEYBOARD_INPUT_DATA InputData;
+
+    /* Attempt to read from the keyboard */
+    Offset.QuadPart = 0;
+    Status = NtReadFile(KeyboardHandle,
+                        NULL,
+                        NULL,
+                        NULL,
+                        &IoStatusBlock,
+                        &InputData,
+                        sizeof(KEYBOARD_INPUT_DATA),
+                        &Offset,
+                        NULL);
+    if (Status == STATUS_PENDING)
+    {
+        /* Wait TimeOut seconds */
+        Timeout.QuadPart = TimeOut * -10000000;
+        Status = NtWaitForSingleObject(KeyboardHandle, FALSE, &Timeout);
+        /* The user didn't enter anything, cancel the read */
+        if (Status == STATUS_TIMEOUT)
+        {
+            NtCancelIoFile(KeyboardHandle, &IoStatusBlock);
+        }
+        /* Else, return some status */
+        else
+        {
+            Status = IoStatusBlock.Status;
+        }
+    }
+
+    return Status;
+}
+
 static NTSTATUS
 GetFileSystem(
     IN LPCWSTR Drive,
@@ -234,7 +327,7 @@ ChkdskCallback(
 
     case DONE:
         Status = (PBOOLEAN)Argument;
-        if (*Status == TRUE)
+        if (*Status != FALSE)
         {
             PrintString("Autochk was unable to complete successfully.\r\n\r\n");
             // Error = TRUE;
@@ -244,53 +337,23 @@ ChkdskCallback(
     return TRUE;
 }
 
-/* Load the provider associated with this file system */
-static PVOID
-LoadProvider(
-    IN PWCHAR FileSystem)
-{
-    UNICODE_STRING ProviderDll;
-    PVOID BaseAddress;
-    NTSTATUS Status;
-
-    /* FIXME: add more providers here */
-
-    if (wcscmp(FileSystem, L"NTFS") == 0)
-    {
-      RtlInitUnicodeString(&ProviderDll, L"untfs.dll");
-    }
-    else if (wcscmp(FileSystem, L"FAT") == 0
-             || wcscmp(FileSystem, L"FAT32") == 0)
-    {
-      RtlInitUnicodeString(&ProviderDll, L"ufat.dll");
-    }
-    else
-    {
-      return NULL;
-    }
-
-    Status = LdrLoadDll(NULL, NULL, &ProviderDll, &BaseAddress);
-    if (!NT_SUCCESS(Status))
-        return NULL;
-    return BaseAddress;
-}
-
 static NTSTATUS
 CheckVolume(
-    IN PWCHAR DrivePath)
+    IN PWCHAR DrivePath,
+    IN LONG TimeOut)
 {
     WCHAR FileSystem[128];
-    ANSI_STRING ChkdskFunctionName = RTL_CONSTANT_STRING("ChkdskEx");
-    PVOID Provider;
-    CHKDSKEX ChkdskFunc;
     WCHAR NtDrivePath[64];
     UNICODE_STRING DrivePathU;
     NTSTATUS Status;
+    DWORD Count;
+
+    PrintString("  Verifying volume %S\r\n", DrivePath);
 
     /* Get the file system */
     Status = GetFileSystem(DrivePath,
                            FileSystem,
-                           sizeof(FileSystem) / sizeof(FileSystem[0]));
+                           ARRAYSIZE(FileSystem));
     if (!NT_SUCCESS(Status))
     {
         DPRINT1("GetFileSystem() failed with status 0x%08lx\n", Status);
@@ -298,47 +361,91 @@ CheckVolume(
         return Status;
     }
 
-    /* Load the provider which will do the chkdsk */
-    Provider = LoadProvider(FileSystem);
-    if (Provider == NULL)
+    PrintString("  Its filesystem type is %S\r\n", FileSystem);
+
+    /* Call provider */
+    for (Count = 0; Count < sizeof(FileSystems) / sizeof(FileSystems[0]); ++Count)
     {
-        DPRINT1("LoadProvider() failed\n");
-        PrintString("  Unable to verify a %S volume\r\n", FileSystem);
-        return STATUS_DLL_NOT_FOUND;
+        if (wcscmp(FileSystem, FileSystems[Count].Name) != 0)
+        {
+            continue;
+        }
+
+        swprintf(NtDrivePath, L"\\??\\");
+        wcscat(NtDrivePath, DrivePath);
+        NtDrivePath[wcslen(NtDrivePath)-1] = 0;
+        RtlInitUnicodeString(&DrivePathU, NtDrivePath);
+
+        DPRINT1("AUTOCHK: Checking %wZ\n", &DrivePathU);
+        /* First, check whether the volume is dirty */
+        Status = FileSystems[Count].ChkdskFunc(&DrivePathU,
+                                               FALSE, // FixErrors
+                                               TRUE, // Verbose
+                                               TRUE, // CheckOnlyIfDirty
+                                               FALSE,// ScanDrive
+                                               ChkdskCallback);
+        /* It is */
+        if (Status == STATUS_DISK_CORRUPT_ERROR)
+        {
+            NTSTATUS WaitStatus;
+
+            /* Let the user decide whether to repair */
+            PrintString("  %S needs to be checked\r\n", DrivePath);
+            PrintString("  You can skip it, but be advised it is not recommanded\r\n");
+            PrintString("  To skip disk checking press any key in %d second(s)\r\n", TimeOut);
+
+            /* Timeout == fix it! */
+            WaitStatus = WaitForKeyboard(TimeOut);
+            if (WaitStatus == STATUS_TIMEOUT)
+            {
+                Status = FileSystems[Count].ChkdskFunc(&DrivePathU,
+                                                       TRUE, // FixErrors
+                                                       TRUE, // Verbose
+                                                       TRUE, // CheckOnlyIfDirty
+                                                       FALSE,// ScanDrive
+                                                       ChkdskCallback);
+            }
+            else
+            {
+                PrintString("  %S checking has been skipped\r\n", DrivePath);
+            }
+        }
+        break;
     }
 
-    /* Get the Chkdsk function address */
-    Status = LdrGetProcedureAddress(Provider,
-                                    &ChkdskFunctionName,
-                                    0,
-                                    (PVOID*)&ChkdskFunc);
-    if (!NT_SUCCESS(Status))
+    if (Count == sizeof(FileSystems) / sizeof(FileSystems[0]))
     {
-        DPRINT1("LdrGetProcedureAddress() failed with status 0x%08lx\n", Status);
+        DPRINT1("File system not supported\n");
         PrintString("  Unable to verify a %S volume\r\n", FileSystem);
-        LdrUnloadDll(Provider);
-        return Status;
+        return STATUS_DLL_NOT_FOUND;
     }
 
-    /* Call provider */
-    // PrintString("  Verifying volume %S\r\n", DrivePath);
-    swprintf(NtDrivePath, L"\\??\\");
-    wcscat(NtDrivePath, DrivePath);
-    NtDrivePath[wcslen(NtDrivePath)-1] = 0;
-    RtlInitUnicodeString(&DrivePathU, NtDrivePath);
-
-    DPRINT("AUTOCHK: Checking %wZ\n", &DrivePathU);
-    Status = ChkdskFunc(&DrivePathU,
-                        TRUE, // FixErrors
-                        TRUE, // Verbose
-                        TRUE, // CheckOnlyIfDirty
-                        FALSE,// ScanDrive
-                        ChkdskCallback);
-
-    LdrUnloadDll(Provider);
     return Status;
 }
 
+static VOID
+QueryTimeout(
+    IN OUT PLONG TimeOut)
+{
+    RTL_QUERY_REGISTRY_TABLE QueryTable[2];
+
+    RtlZeroMemory(QueryTable, sizeof(QueryTable));
+    QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT;
+    QueryTable[0].Name = L"AutoChkTimeOut";
+    QueryTable[0].EntryContext = TimeOut;
+
+    RtlQueryRegistryValues(RTL_REGISTRY_CONTROL, L"Session Manager", QueryTable, NULL, NULL);
+    /* See: https://docs.microsoft.com/en-us/windows-server/administration/windows-commands/autochk */
+    if (*TimeOut > 259200)
+    {
+        *TimeOut = 259200;
+    }
+    else if (*TimeOut < 0)
+    {
+        *TimeOut = 0;
+    }
+}
+
 /* Native image's entry point */
 int
 _cdecl
@@ -351,6 +458,7 @@ _main(int argc,
     ULONG i;
     NTSTATUS Status;
     WCHAR DrivePath[128];
+    LONG TimeOut;
 
     // Win2003 passes the only param - "*". Probably means to check all drives
     /*
@@ -359,6 +467,10 @@ _main(int argc,
         DPRINT("Param %d: %s\n", i, argv[i]);
     */
 
+    /* Query timeout */
+    TimeOut = 3;
+    QueryTimeout(&TimeOut);
+
     /* FIXME: We should probably use here the mount manager to be
      * able to check volumes which don't have a drive letter.
      */
@@ -375,15 +487,27 @@ _main(int argc,
         return 1;
     }
 
+    /* Open keyboard */
+    Status = OpenKeyboard();
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT1("OpenKeyboard() failed with status 0x%08lx\n", Status);
+        return 1;
+    }
+
     for (i = 0; i < 26; i++)
     {
         if ((DeviceMap.Query.DriveMap & (1 << i))
          && (DeviceMap.Query.DriveType[i] == DOSDEVICE_DRIVE_FIXED))
         {
             swprintf(DrivePath, L"%c:\\", L'A'+i);
-            CheckVolume(DrivePath);
+            CheckVolume(DrivePath, TimeOut);
         }
     }
+
+    /* Close keyboard */
+    NtClose(KeyboardHandle);
+
     // PrintString("  Done\r\n\r\n");
     return 0;
 }