[NTOSKRNL/USERINIT/WIN32K]
[reactos.git] / reactos / ntoskrnl / config / cmboot.c
index 1416c4d..6af7e3c 100644 (file)
 #include "ntoskrnl.h"
 #define NDEBUG
 #include "debug.h"
+
+/* GLOBALS ********************************************************************/
+
+extern ULONG InitSafeBootMode;
+
 /* FUNCTIONS ******************************************************************/
 
 HCELL_INDEX
@@ -383,6 +387,7 @@ CmpFindDrivers(IN PHHIVE Hive,
                IN PLIST_ENTRY DriverListHead)
 {
     HCELL_INDEX ServicesCell, ControlCell, GroupOrderCell, DriverCell;
+    HCELL_INDEX SafeBootCell = HCELL_NIL;
     UNICODE_STRING Name;
     ULONG i;
     WCHAR Buffer[128];
@@ -416,6 +421,31 @@ CmpFindDrivers(IN PHHIVE Hive,
     GroupOrderCell = CmpFindSubKeyByName(Hive, Node, &Name);
     if (GroupOrderCell == HCELL_NIL) return FALSE;
 
+    /* Get Safe Boot cell */
+    if(InitSafeBootMode)
+    {
+        /* Open the Safe Boot key */
+        RtlInitUnicodeString(&Name, L"SafeBoot");
+        Node = HvGetCell(Hive, ControlCell);
+        ASSERT(Node);
+        SafeBootCell = CmpFindSubKeyByName(Hive, Node, &Name);
+        if (SafeBootCell == HCELL_NIL) return FALSE;
+
+        /* Open the correct start key (depending on the mode) */
+        Node = HvGetCell(Hive, SafeBootCell);
+        ASSERT(Node);
+        switch(InitSafeBootMode)
+        {
+            /* NOTE: Assumes MINIMAL (1) and DSREPAIR (3) load same items */
+            case 1:
+            case 3: RtlInitUnicodeString(&Name, L"Minimal"); break;
+            case 2: RtlInitUnicodeString(&Name, L"Network"); break;
+            default: return FALSE;
+        }
+        SafeBootCell = CmpFindSubKeyByName(Hive, Node, &Name);
+        if(SafeBootCell == HCELL_NIL) return FALSE;
+    }
+
     /* Build the root registry path */
     RtlInitEmptyUnicodeString(&KeyPath, Buffer, sizeof(Buffer));
     RtlAppendUnicodeToString(&KeyPath, L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\");
@@ -425,8 +455,9 @@ CmpFindDrivers(IN PHHIVE Hive,
     DriverCell = CmpFindSubKeyByNumber(Hive, ServicesNode, i);
     while (DriverCell != HCELL_NIL)
     {
-        /* Make sure it's a driver of this start type */
-        if (CmpIsLoadType(Hive, DriverCell, LoadType))
+        /* Make sure it's a driver of this start type AND is "safe" to load */
+        if (CmpIsLoadType(Hive, DriverCell, LoadType) &&
+            CmpIsSafe(Hive, SafeBootCell, DriverCell))
         {
             /* Add it to the list */
             CmpAddDriverToList(Hive,
@@ -687,4 +718,106 @@ CmpResolveDriverDependencies(IN PLIST_ENTRY DriverListHead)
     return TRUE;
 }
 
+BOOLEAN
+NTAPI
+INIT_FUNCTION
+CmpIsSafe(IN PHHIVE Hive,
+          IN HCELL_INDEX SafeBootCell,
+          IN HCELL_INDEX DriverCell)
+{
+    PCM_KEY_NODE SafeBootNode;
+    PCM_KEY_NODE DriverNode;
+    PCM_KEY_VALUE KeyValue;
+    HCELL_INDEX CellIndex;
+    ULONG Length = 0;
+    UNICODE_STRING Name;
+    PWCHAR OriginalName;
+    ASSERT(Hive->ReleaseCellRoutine == NULL);
+
+    /* Driver key node (mandatory) */
+    ASSERT(DriverCell != HCELL_NIL);
+    DriverNode = HvGetCell(Hive, DriverCell);
+    ASSERT(DriverNode);
+
+    /* Safe boot key node (optional but return TRUE if not present) */
+    if(SafeBootCell == HCELL_NIL) return TRUE;
+    SafeBootNode = HvGetCell(Hive, SafeBootCell);
+    if(!SafeBootNode) return FALSE;
+
+    /* Search by the name from the group */
+    RtlInitUnicodeString(&Name, L"Group");
+    CellIndex = CmpFindValueByName(Hive, DriverNode, &Name);
+    if(CellIndex != HCELL_NIL)
+    {
+        KeyValue = HvGetCell(Hive, CellIndex);
+        ASSERT(KeyValue);
+        if (KeyValue->Type == REG_SZ || KeyValue->Type == REG_EXPAND_SZ)
+        {
+            /* Compose the search 'key' */
+            Name.Buffer = (PWCHAR)CmpValueToData(Hive, KeyValue, &Length);
+            if (!Name.Buffer) return FALSE;
+            Name.Length = Length - sizeof(UNICODE_NULL);
+            Name.MaximumLength = Name.Length;
+            /* Search for corresponding key in the Safe Boot key */
+            CellIndex = CmpFindSubKeyByName(Hive, SafeBootNode, &Name);
+            if(CellIndex != HCELL_NIL) return TRUE;
+        }
+    }
+
+    /* Group has not been found - find driver name */
+    Name.Length = DriverNode->Flags & KEY_COMP_NAME ?
+                              CmpCompressedNameSize(DriverNode->Name,
+                                                    DriverNode->NameLength) :
+                              DriverNode->NameLength;
+    Name.MaximumLength = Name.Length;
+    /* Now allocate the buffer for it and copy the name */
+    Name.Buffer = CmpAllocate(Name.Length, FALSE, TAG_CM);
+    if (!Name.Buffer) return FALSE;
+    if (DriverNode->Flags & KEY_COMP_NAME)
+    {
+        /* Compressed name */
+        CmpCopyCompressedName(Name.Buffer,
+                              Name.Length,
+                              DriverNode->Name,
+                              DriverNode->NameLength);
+    }
+    else
+    {
+        /* Normal name */
+        RtlCopyMemory(Name.Buffer, DriverNode->Name, DriverNode->NameLength);
+    }
+    CellIndex = CmpFindSubKeyByName(Hive, SafeBootNode, &Name);
+    RtlFreeUnicodeString(&Name);
+    if(CellIndex != HCELL_NIL) return TRUE;
+
+    /* Not group or driver name - search by image name */
+    RtlInitUnicodeString(&Name, L"ImagePath");
+    CellIndex = CmpFindValueByName(Hive, DriverNode, &Name);
+    if(CellIndex != HCELL_NIL)
+    {
+        KeyValue = HvGetCell(Hive, CellIndex);
+        ASSERT(KeyValue);
+        if (KeyValue->Type == REG_SZ || KeyValue->Type == REG_EXPAND_SZ)
+        {
+            /* Compose the search 'key' */
+            OriginalName = (PWCHAR)CmpValueToData(Hive, KeyValue, &Length);
+            if (!OriginalName) return FALSE;
+            /* Get the base image file name */
+            Name.Buffer = wcsrchr(OriginalName, L'\\');
+            if (!Name.Buffer) return FALSE;
+            ++Name.Buffer;
+            /* Length of the base name must be >=1 */
+            Name.Length = Length - ((PUCHAR)Name.Buffer - (PUCHAR)OriginalName)
+                                 - sizeof(UNICODE_NULL);
+            if(Name.Length < 1) return FALSE;
+            Name.MaximumLength = Name.Length;
+            /* Search for corresponding key in the Safe Boot key */
+            CellIndex = CmpFindSubKeyByName(Hive, SafeBootNode, &Name);
+            if(CellIndex != HCELL_NIL) return TRUE;
+        }
+    }
+    /* Nothing found - nothing else to search */
+    return FALSE;
+}
+
 /* EOF */