[SMSS] Use RTL string-safe functions in critical places. Add validity checks for...
[reactos.git] / base / system / smss / sminit.c
index d671bb7..8960010 100644 (file)
@@ -789,12 +789,13 @@ SmpTranslateSystemPartitionInformation(VOID)
     UNICODE_STRING UnicodeString, LinkTarget, SearchString, SystemPartition;
     OBJECT_ATTRIBUTES ObjectAttributes;
     HANDLE KeyHandle, LinkHandle;
-    CHAR ValueBuffer[512 + sizeof(KEY_VALUE_PARTIAL_INFORMATION)];
-    PKEY_VALUE_PARTIAL_INFORMATION PartialInfo = (PVOID)ValueBuffer;
     ULONG Length, Context;
-    CHAR DirInfoBuffer[512 + sizeof(OBJECT_DIRECTORY_INFORMATION)];
-    POBJECT_DIRECTORY_INFORMATION DirInfo = (PVOID)DirInfoBuffer;
+    size_t StrLength;
     WCHAR LinkBuffer[MAX_PATH];
+    CHAR ValueBuffer[sizeof(KEY_VALUE_PARTIAL_INFORMATION) + 512];
+    PKEY_VALUE_PARTIAL_INFORMATION PartialInfo = (PVOID)ValueBuffer;
+    CHAR DirInfoBuffer[sizeof(OBJECT_DIRECTORY_INFORMATION) + 512];
+    POBJECT_DIRECTORY_INFORMATION DirInfo = (PVOID)DirInfoBuffer;
 
     /* Open the setup key */
     RtlInitUnicodeString(&UnicodeString, L"\\Registry\\Machine\\System\\Setup");
@@ -806,7 +807,7 @@ SmpTranslateSystemPartitionInformation(VOID)
     Status = NtOpenKey(&KeyHandle, KEY_READ, &ObjectAttributes);
     if (!NT_SUCCESS(Status))
     {
-        DPRINT1("SMSS: can't open system setup key for reading: 0x%x\n", Status);
+        DPRINT1("SMSS: Cannot open system setup key for reading: 0x%x\n", Status);
         return;
     }
 
@@ -819,14 +820,22 @@ SmpTranslateSystemPartitionInformation(VOID)
                              sizeof(ValueBuffer),
                              &Length);
     NtClose(KeyHandle);
-    if (!NT_SUCCESS(Status))
+    if (!NT_SUCCESS(Status) ||
+        ((PartialInfo->Type != REG_SZ) && (PartialInfo->Type != REG_EXPAND_SZ)))
     {
-        DPRINT1("SMSS: can't query SystemPartition value: 0x%x\n", Status);
+        DPRINT1("SMSS: Cannot query SystemPartition value (Type %lu, Status 0x%x)\n",
+                PartialInfo->Type, Status);
         return;
     }
 
-    /* Initialize the system partition string string */
-    RtlInitUnicodeString(&SystemPartition, (PWCHAR)PartialInfo->Data);
+    /* Initialize the system partition string */
+    RtlInitEmptyUnicodeString(&SystemPartition,
+                              (PWCHAR)PartialInfo->Data,
+                              PartialInfo->DataLength);
+    RtlStringCbLengthW(SystemPartition.Buffer,
+                       SystemPartition.MaximumLength,
+                       &StrLength);
+    SystemPartition.Length = (USHORT)StrLength;
 
     /* Enumerate the directory looking for the symbolic link string */
     RtlInitUnicodeString(&SearchString, L"SymbolicLink");
@@ -840,7 +849,7 @@ SmpTranslateSystemPartitionInformation(VOID)
                                     NULL);
     if (!NT_SUCCESS(Status))
     {
-        DPRINT1("SMSS: can't find drive letter for system partition\n");
+        DPRINT1("SMSS: Cannot find drive letter for system partition\n");
         return;
     }
 
@@ -872,8 +881,8 @@ SmpTranslateSystemPartitionInformation(VOID)
                 /* Check if it matches the string we had found earlier */
                 if ((NT_SUCCESS(Status)) &&
                     ((RtlEqualUnicodeString(&SystemPartition,
-                                           &LinkTarget,
-                                           TRUE)) ||
+                                            &LinkTarget,
+                                            TRUE)) ||
                     ((RtlPrefixUnicodeString(&SystemPartition,
                                              &LinkTarget,
                                              TRUE)) &&
@@ -896,7 +905,7 @@ SmpTranslateSystemPartitionInformation(VOID)
     } while (NT_SUCCESS(Status));
     if (!NT_SUCCESS(Status))
     {
-        DPRINT1("SMSS: can't find drive letter for system partition\n");
+        DPRINT1("SMSS: Cannot find drive letter for system partition\n");
         return;
     }
 
@@ -911,7 +920,7 @@ SmpTranslateSystemPartitionInformation(VOID)
     Status = NtOpenKey(&KeyHandle, KEY_ALL_ACCESS, &ObjectAttributes);
     if (!NT_SUCCESS(Status))
     {
-        DPRINT1("SMSS: can't open software setup key for writing: 0x%x\n",
+        DPRINT1("SMSS: Cannot open software setup key for writing: 0x%x\n",
                 Status);
         return;
     }
@@ -1349,7 +1358,7 @@ NTAPI
 SmpProcessModuleImports(IN PVOID Unused,
                         IN PCHAR ImportName)
 {
-    ULONG Length = 0, Chars;
+    ULONG Length = 0;
     WCHAR Buffer[MAX_PATH];
     PWCHAR DllName, DllValue;
     ANSI_STRING ImportString;
@@ -1365,20 +1374,22 @@ SmpProcessModuleImports(IN PVOID Unused,
     Status = RtlAnsiStringToUnicodeString(&ImportUnicodeString, &ImportString, FALSE);
     if (!NT_SUCCESS(Status)) return;
 
-    /* Loop in case we find a forwarder */
-    ImportUnicodeString.MaximumLength = ImportUnicodeString.Length + sizeof(UNICODE_NULL);
+    /* Loop to find the DLL file extension */
     while (Length < ImportUnicodeString.Length)
     {
         if (ImportUnicodeString.Buffer[Length / sizeof(WCHAR)] == L'.') break;
         Length += sizeof(WCHAR);
     }
 
-    /* Break up the values as needed */
+    /*
+     * Break up the values as needed; the buffer acquires the form:
+     * "dll_name.dll\0dll_name\0"
+     */
     DllValue = ImportUnicodeString.Buffer;
-    DllName = &ImportUnicodeString.Buffer[ImportUnicodeString.MaximumLength / sizeof(WCHAR)];
-    Chars = Length >> 1;
-    wcsncpy(DllName, ImportUnicodeString.Buffer, Chars);
-    DllName[Chars] = 0;
+    DllName = &ImportUnicodeString.Buffer[(ImportUnicodeString.Length + sizeof(UNICODE_NULL)) / sizeof(WCHAR)];
+    RtlStringCbCopyNW(DllName,
+                      ImportUnicodeString.MaximumLength - (ImportUnicodeString.Length + sizeof(UNICODE_NULL)),
+                      ImportUnicodeString.Buffer, Length);
 
     /* Add the DLL to the list */
     SmpSaveRegistryValue(&SmpKnownDllsList, DllName, DllValue, TRUE);
@@ -1652,9 +1663,11 @@ SmpCreateDynamicEnvironmentVariables(VOID)
     OBJECT_ATTRIBUTES ObjectAttributes;
     UNICODE_STRING ValueName, DestinationString;
     HANDLE KeyHandle, KeyHandle2;
-    ULONG ResultLength;
     PWCHAR ValueData;
-    WCHAR ValueBuffer[512], ValueBuffer2[512];
+    ULONG ResultLength;
+    size_t StrLength;
+    WCHAR ValueBuffer[sizeof(KEY_VALUE_PARTIAL_INFORMATION) + 512];
+    WCHAR ValueBuffer2[sizeof(KEY_VALUE_PARTIAL_INFORMATION) + 512];
     PKEY_VALUE_PARTIAL_INFORMATION PartialInfo = (PVOID)ValueBuffer;
     PKEY_VALUE_PARTIAL_INFORMATION PartialInfo2 = (PVOID)ValueBuffer2;
 
@@ -1798,15 +1811,27 @@ SmpCreateDynamicEnvironmentVariables(VOID)
                              PartialInfo,
                              sizeof(ValueBuffer),
                              &ResultLength);
-    if (!NT_SUCCESS(Status))
+    if (!NT_SUCCESS(Status) ||
+        ((PartialInfo->Type != REG_SZ) && (PartialInfo->Type != REG_EXPAND_SZ)))
     {
         NtClose(KeyHandle2);
         NtClose(KeyHandle);
-        DPRINT1("SMSS: Unable to read %wZ\\%wZ - %x\n",
-                &DestinationString, &ValueName, Status);
+        DPRINT1("SMSS: Unable to read %wZ\\%wZ (Type %lu, Status 0x%x)\n",
+                &DestinationString, &ValueName, PartialInfo->Type, Status);
         return Status;
     }
 
+    /* Initialize the string so that it can be large enough
+     * to contain both the identifier and the vendor strings. */
+    RtlInitEmptyUnicodeString(&DestinationString,
+                              (PWCHAR)PartialInfo->Data,
+                              sizeof(ValueBuffer) -
+                                FIELD_OFFSET(KEY_VALUE_PARTIAL_INFORMATION, Data));
+    RtlStringCbLengthW(DestinationString.Buffer,
+                       PartialInfo->DataLength,
+                       &StrLength);
+    DestinationString.Length = (USHORT)StrLength;
+
     /* As well as the vendor... */
     RtlInitUnicodeString(&ValueName, L"VendorIdentifier");
     Status = NtQueryValueKey(KeyHandle2,
@@ -1816,23 +1841,27 @@ SmpCreateDynamicEnvironmentVariables(VOID)
                              sizeof(ValueBuffer2),
                              &ResultLength);
     NtClose(KeyHandle2);
-    if (NT_SUCCESS(Status))
+    if (NT_SUCCESS(Status) &&
+        ((PartialInfo2->Type == REG_SZ) || (PartialInfo2->Type == REG_EXPAND_SZ)))
     {
         /* To combine it into a single string */
-        swprintf((PWCHAR)PartialInfo->Data + wcslen((PWCHAR)PartialInfo->Data),
-                 L", %s",
-                 (PWCHAR)PartialInfo2->Data);
+        RtlStringCbPrintfW(DestinationString.Buffer + DestinationString.Length / sizeof(WCHAR),
+                           DestinationString.MaximumLength - DestinationString.Length,
+                           L", %.*s",
+                           PartialInfo2->DataLength / sizeof(WCHAR),
+                           (PWCHAR)PartialInfo2->Data);
+        DestinationString.Length = (USHORT)(wcslen(DestinationString.Buffer) * sizeof(WCHAR));
     }
 
     /* So that we can set this as the PROCESSOR_IDENTIFIER variable */
     RtlInitUnicodeString(&ValueName, L"PROCESSOR_IDENTIFIER");
-    DPRINT("Setting %wZ to %S\n", &ValueName, PartialInfo->Data);
+    DPRINT("Setting %wZ to %wZ\n", &ValueName, &DestinationString);
     Status = NtSetValueKey(KeyHandle,
                            &ValueName,
                            0,
                            REG_SZ,
-                           PartialInfo->Data,
-                           (wcslen((PWCHAR)PartialInfo->Data) + 1) * sizeof(WCHAR));
+                           DestinationString.Buffer,
+                           DestinationString.Length + sizeof(UNICODE_NULL));
     if (!NT_SUCCESS(Status))
     {
         DPRINT1("SMSS: Failed writing %wZ environment variable - %x\n",
@@ -1922,7 +1951,9 @@ SmpCreateDynamicEnvironmentVariables(VOID)
                                  sizeof(ValueBuffer),
                                  &ResultLength);
         NtClose(KeyHandle2);
-        if (NT_SUCCESS(Status))
+        if (NT_SUCCESS(Status) &&
+            (PartialInfo->Type == REG_DWORD) &&
+            (PartialInfo->DataLength >= sizeof(ULONG)))
         {
             /* Convert from the integer value to the correct specifier */
             RtlInitUnicodeString(&ValueName, L"SAFEBOOT_OPTION");
@@ -1957,7 +1988,8 @@ SmpCreateDynamicEnvironmentVariables(VOID)
         }
         else
         {
-            DPRINT1("SMSS: Failed querying safeboot option = %x\n", Status);
+            DPRINT1("SMSS: Failed to query SAFEBOOT option (Type %lu, Status 0x%x)\n",
+                    PartialInfo->Type, Status);
         }
     }