Fix the ntstrsafe.h tests to make sense, actually succeed using ntstrsafe.h from...
[reactos.git] / modules / rostests / kmtests / rtl / RtlStrSafe.c
index 807a036..82c2fe5 100644 (file)
@@ -1,8 +1,9 @@
 /*
  * PROJECT:         ReactOS kernel-mode tests
- * LICENSE:         GPL-2.0+ (https://spdx.org/licenses/GPL-2.0+)
+ * LICENSE:         GPL-2.0-or-later (https://spdx.org/licenses/GPL-2.0-or-later)
  * PURPOSE:         Test for ntstrsafe.h functions
  * COPYRIGHT:       Copyright 2018 HernĂ¡n Di Pietro <hernan.di.pietro@gmail.com>
+ *                  Copyright 2019 Colin Finck <colin@reactos.org>
  */
 
 #define KMT_EMULATE_KERNEL
 
 #define TESTAPI static void
 
+static const WCHAR FormatStringInts[] = L"%d %d %d";
+static const WCHAR FormatStringIntsResult[] = L"1 2 3";
+static const WCHAR FormatStringStrs[] = L"%s %s %s";
+
+
 TESTAPI
 Test_RtlUnicodeStringPrintf()
 {
-    WCHAR Buffer[1024];
-    WCHAR OvrBuffer[1024];
-    WCHAR BufferSmall[2];
-    WCHAR BufferSmall2[7];
+    NTSTATUS Status;
+    PWSTR pBuffer = NULL;
+    size_t BufferSize;
+    size_t EqualBytes;
     UNICODE_STRING UsString;
-    const WCHAR FormatStringInts[] = L"%d %d %d";
-    const WCHAR FormatStringStrs[] = L"%s %s %s";
-    const WCHAR Result[] = L"1 2 3";
-    UNICODE_STRING UsStringNull;
 
-    /* No zeros (Don't assume UNICODE_STRINGS are NULL terminated) */
-
-    RtlFillMemory(Buffer, sizeof(Buffer), 0xAA);
-    RtlFillMemory(BufferSmall, sizeof(BufferSmall), 0xAA);
-    RtlFillMemory(BufferSmall2, sizeof(BufferSmall2), 0xAA);
+    KmtStartSeh();
 
     /* STATUS_SUCCESS test */
+    BufferSize = 6 * sizeof(WCHAR);
+    pBuffer = KmtAllocateGuarded(BufferSize);
+    if (!pBuffer)
+        goto Cleanup;
+
+    RtlFillMemory(pBuffer, BufferSize, 0xAA);
+    RtlInitEmptyUnicodeString(&UsString, pBuffer, BufferSize);
+
+    Status = RtlUnicodeStringPrintf(&UsString, FormatStringInts, 1, 2, 3);
+    EqualBytes = RtlCompareMemory(UsString.Buffer, FormatStringIntsResult, sizeof(FormatStringIntsResult));
+    ok_eq_hex(Status, STATUS_SUCCESS);
+    ok_eq_size(EqualBytes, sizeof(FormatStringIntsResult));
+    ok_eq_uint(UsString.Length, sizeof(FormatStringIntsResult) - sizeof(WCHAR));
+    ok_eq_uint(UsString.MaximumLength, BufferSize);    
+
+    KmtFreeGuarded(pBuffer);
+    pBuffer = NULL;
 
-    UsString.Buffer = Buffer;
-    UsString.Length = 0;
-    UsString.MaximumLength = sizeof(Buffer);
-
-    ok_eq_hex(RtlUnicodeStringPrintf(&UsString, FormatStringInts, 1, 2, 3), STATUS_SUCCESS);
-    ok_eq_uint(UsString.Length, sizeof(Result) - sizeof(WCHAR));
-    ok_eq_uint(UsString.MaximumLength, sizeof(Buffer));
-    ok_eq_wchar(UsString.Buffer[0], L'1');
-    ok_eq_wchar(UsString.Buffer[1], L' ');
-    ok_eq_wchar(UsString.Buffer[2], L'2');
-    ok_eq_wchar(UsString.Buffer[3], L' ');
-    ok_eq_wchar(UsString.Buffer[4], L'3');
-    ok_eq_wchar(UsString.Buffer[5], (WCHAR) 0);
-    
     /* STATUS_BUFFER_OVERFLOW tests */
+    BufferSize = 2 * sizeof(WCHAR);
+    pBuffer = KmtAllocateGuarded(BufferSize);
+    if (!pBuffer)
+        goto Cleanup;
 
-    UsString.Buffer = BufferSmall;
-    UsString.Length = 0;
-    UsString.MaximumLength = sizeof(BufferSmall);
-    
-    ok_eq_hex(RtlUnicodeStringPrintf(&UsString, FormatStringStrs, L"AAA", L"BBB", L"CCC"), STATUS_BUFFER_OVERFLOW);
+    RtlInitEmptyUnicodeString(&UsString, pBuffer, BufferSize);
+
+    Status = RtlUnicodeStringPrintf(&UsString, FormatStringStrs, L"AAA", L"BBB", L"CCC");
+    EqualBytes = RtlCompareMemory(UsString.Buffer, L"AA", BufferSize);
+    ok_eq_hex(Status, STATUS_BUFFER_OVERFLOW);
+    ok_eq_size(EqualBytes, BufferSize);
     ok_eq_uint(UsString.Length, UsString.MaximumLength);
-    ok_eq_char(UsString.Buffer[0], L'A');
-    ok_eq_char(UsString.Buffer[1], (WCHAR)0);
 
-    UsString.Buffer = BufferSmall2;
-    UsString.Length = 0;
-    UsString.MaximumLength = sizeof(BufferSmall2);
+    KmtFreeGuarded(pBuffer);
+    pBuffer = NULL;
+
 
-    ok_eq_hex(RtlUnicodeStringPrintf(&UsString, FormatStringStrs, L"0123", L"4567", L"89AB"), STATUS_BUFFER_OVERFLOW);
+    BufferSize = 7 * sizeof(WCHAR);
+    pBuffer = KmtAllocateGuarded(BufferSize);
+    if (!pBuffer)
+        goto Cleanup;
+
+    RtlInitEmptyUnicodeString(&UsString, pBuffer, BufferSize);
+
+    Status = RtlUnicodeStringPrintf(&UsString, FormatStringStrs, L"0123", L"4567", L"89AB");
+    EqualBytes = RtlCompareMemory(UsString.Buffer, L"0123 45", BufferSize);
+    ok_eq_hex(Status, STATUS_BUFFER_OVERFLOW);
+    ok_eq_size(EqualBytes, BufferSize);
     ok_eq_uint(UsString.Length, UsString.MaximumLength);
-    ok_eq_char(UsString.Buffer[0], L'0');
-    ok_eq_char(UsString.Buffer[1], L'1');
-    ok_eq_char(UsString.Buffer[2], L'2');
-    ok_eq_char(UsString.Buffer[3], L'3');
-    ok_eq_char(UsString.Buffer[4], L' ');
-    ok_eq_char(UsString.Buffer[5], L'4');
-    ok_eq_char(UsString.Buffer[6], (WCHAR) 0);
-
-    ///* STATUS_INVALID_PARAMETER tests */
-
-    ok_eq_hex(RtlUnicodeStringPrintf(NULL, FormatStringStrs, L"AAA", L"BBB", L"CCC"), STATUS_INVALID_PARAMETER);
-    ok_eq_hex(RtlUnicodeStringPrintf(&UsString, NULL, L"AAA", L"BBB", L"CCC"), STATUS_INVALID_PARAMETER);
-
-    UsStringNull.Buffer = (PWCH)NULL;
-    UsStringNull.Length = 0;
-    UsStringNull.MaximumLength = 0;
-    ok_eq_bool(RtlUnicodeStringPrintf(&UsStringNull, FormatStringStrs, L"AAA", L"BBB", L"CCC"), STATUS_INVALID_PARAMETER);
-
-    /* Test  for buffer overruns */
-
-    RtlFillMemory(Buffer, sizeof(Buffer), 0xAA);
-    RtlFillMemory(OvrBuffer, sizeof(OvrBuffer), 0xAA);
-    UsString.Buffer = Buffer;
-    UsString.Length = 0;
-    UsString.MaximumLength = 16 * sizeof(WCHAR);
-
-    ok_eq_hex(RtlUnicodeStringPrintf(&UsString, FormatStringStrs, L"abc", L"def", L"ghi"), STATUS_SUCCESS);
-    ok_eq_uint(UsString.Length, sizeof(L"abc def ghi") -sizeof(WCHAR));
-    ok_eq_char(UsString.Buffer[11], (WCHAR)0);
-    ok_eq_uint(0, memcmp(OvrBuffer + 12, Buffer + 12, sizeof(Buffer) - (12 * sizeof(WCHAR))));
+
+    KmtFreeGuarded(pBuffer);
+    pBuffer = NULL;
+
+    // Note: RtlUnicodeStringPrintf returns STATUS_BUFFER_OVERFLOW here while RtlUnicodeStringPrintfEx returns STATUS_INVALID_PARAMETER!
+    // Documented on MSDN and verified with the Win10 version of ntstrsafe.h
+    RtlInitEmptyUnicodeString(&UsString, NULL, 0);
+    Status = RtlUnicodeStringPrintf(&UsString, FormatStringStrs, L"AAA", L"BBB", L"CCC");
+    ok_eq_hex(Status, STATUS_BUFFER_OVERFLOW);
+
+
+Cleanup:
+    if (pBuffer)
+        KmtFreeGuarded(pBuffer);
+
+    // None of these functions should have crashed.
+    KmtEndSeh(STATUS_SUCCESS);
 }
 
 TESTAPI
 Test_RtlUnicodeStringPrintfEx()
 {
-    WCHAR Buffer[32];
-    WCHAR BufferSmall[8] = { 0 };
-    WCHAR OvrBuffer[1024];
-    UNICODE_STRING UsString, RemString;
-    const WCHAR FormatStringInts[] = L"%d %d %d";
-    const WCHAR FormatStringStrs[] = L"%s %s %s";
-    const WCHAR Result[] = L"1 2 3";
+    NTSTATUS Status;
+    PWSTR pBuffer = NULL;
+    size_t BufferSize;
+    size_t EqualBytes;
+    UNICODE_STRING RemString;
+    UNICODE_STRING UsString;
+    WCHAR FillResult[10];
 
-    /* No zeros (Don't assume UNICODE_STRINGS are NULL terminated) */
+    RtlFillMemory(FillResult, sizeof(FillResult), 0xAA);
 
-    RtlFillMemory(Buffer, sizeof(Buffer), 0xAA);
+    KmtStartSeh();
 
-    UsString.Buffer = Buffer;
-    UsString.Length = 0;
-    UsString.MaximumLength = sizeof(Buffer);
+    /* STATUS_SUCCESS test, fill behind flag: low-byte as fill character */
+    BufferSize = sizeof(FormatStringIntsResult) - sizeof(UNICODE_NULL) + sizeof(FillResult);
+    pBuffer = KmtAllocateGuarded(BufferSize);
+    if (!pBuffer)
+        goto Cleanup;
+
+    RtlInitEmptyUnicodeString(&UsString, pBuffer, BufferSize);
+    RtlInitEmptyUnicodeString(&RemString, NULL, 0);
+
+    Status = RtlUnicodeStringPrintfEx(&UsString, &RemString, STRSAFE_FILL_BEHIND | 0xAA, FormatStringInts, 1, 2, 3);
+    EqualBytes = RtlCompareMemory(UsString.Buffer, FormatStringIntsResult, sizeof(FormatStringIntsResult) - sizeof(WCHAR));
+    ok_eq_hex(Status, STATUS_SUCCESS);
+    ok_eq_size(EqualBytes, sizeof(FormatStringIntsResult) - sizeof(WCHAR));
+    ok_eq_uint(UsString.Length, sizeof(FormatStringIntsResult) - sizeof(WCHAR));
+    ok_eq_uint(UsString.MaximumLength, BufferSize);
+
+    ok_eq_pointer(RemString.Buffer, &UsString.Buffer[UsString.Length / sizeof(WCHAR)]);
+    ok_eq_uint(RemString.Length, 0);
+    ok_eq_uint(RemString.MaximumLength, UsString.MaximumLength - UsString.Length);
 
-    RemString.Buffer = (PWCH)NULL;
-    RemString.Length = 0;
-    RemString.MaximumLength = 0;
+    EqualBytes = RtlCompareMemory(RemString.Buffer, FillResult, RemString.MaximumLength);
+    ok_eq_size(EqualBytes, sizeof(FillResult));
+
+    KmtFreeGuarded(pBuffer);
+    pBuffer = NULL;
 
-    /* STATUS_SUCCESS test, fill behind flag: low-byte as fill character */
 
-    ok_eq_hex(RtlUnicodeStringPrintfEx(&UsString, &RemString, STRSAFE_FILL_BEHIND | 0xFF, FormatStringInts, 1, 2, 3), STATUS_SUCCESS);
-    
-    ok_eq_uint(UsString.Length, sizeof(Result) - sizeof(WCHAR));
-    ok_eq_uint(UsString.MaximumLength, sizeof(Buffer));
-    ok_eq_uint(memcmp(UsString.Buffer, Result, sizeof(Result) - sizeof(WCHAR)), 0);
+    /* STATUS_BUFFER_OVERFLOW test  */
+    BufferSize = 8 * sizeof(WCHAR);
+    pBuffer = KmtAllocateGuarded(BufferSize);
+    if (!pBuffer)
+        goto Cleanup;
+
+    RtlInitEmptyUnicodeString(&UsString, pBuffer, BufferSize);
+    RtlInitEmptyUnicodeString(&RemString, NULL, 0);
+
+    Status = RtlUnicodeStringPrintfEx(&UsString, &RemString, 0, FormatStringStrs, L"AAA", L"BBB", L"CCC");
+    EqualBytes = RtlCompareMemory(UsString.Buffer, L"AAA BBB ", UsString.Length);
+    ok_eq_hex(Status, STATUS_BUFFER_OVERFLOW);
+    ok_eq_size(EqualBytes, UsString.Length);
+    ok_eq_uint(UsString.Length, UsString.MaximumLength);
 
-    ok_eq_pointer(RemString.Buffer, UsString.Buffer + (UsString.Length / sizeof(WCHAR)));
+    ok_eq_pointer(RemString.Buffer, &UsString.Buffer[UsString.Length / sizeof(WCHAR)]);
     ok_eq_uint(RemString.Length, 0);
-    ok_eq_uint(RemString.MaximumLength, UsString.MaximumLength - UsString.Length);
+    ok_eq_uint(RemString.MaximumLength, 0);
 
-    /* STATUS_BUFFER_OVERFLOW test  */
+    KmtFreeGuarded(pBuffer);
+    pBuffer = NULL;
 
-    UsString.Buffer = BufferSmall;
-    UsString.Length = 0;
-    UsString.MaximumLength = sizeof(BufferSmall);
 
-    RemString.Buffer = (PWCH)NULL;
-    RemString.Length = 0;
-    RemString.MaximumLength = 0;
-    
-    ok_eq_hex(RtlUnicodeStringPrintfEx(&UsString, &RemString, 0, FormatStringStrs, L"AAA", L"BBB", L"CCC"), STATUS_BUFFER_OVERFLOW);
-    ok_eq_uint(UsString.Length, UsString.MaximumLength);
-    ok_eq_char(UsString.Buffer[0], L'A');
-    ok_eq_char(UsString.Buffer[1], L'A');
-    ok_eq_char(UsString.Buffer[2], L'A');
-    ok_eq_char(UsString.Buffer[3], L' ');
-    ok_eq_char(UsString.Buffer[4], L'B');
-    ok_eq_char(UsString.Buffer[5], L'B');
-    ok_eq_char(UsString.Buffer[6], L'B');
-    ok_eq_char(UsString.Buffer[7], (WCHAR)0);
-
-    // Takes \0 into account
-    ok_eq_pointer(RemString.Buffer, UsString.Buffer + (UsString.Length - 1) / sizeof(WCHAR));
-    ok_eq_uint(RemString.Length, 0); 
-    ok_eq_uint(RemString.MaximumLength, 2);
-
-    /* Test  for buffer overruns */
-
-    RtlFillMemory(Buffer, sizeof(Buffer), 0xAA);
-    RtlFillMemory(OvrBuffer, sizeof(OvrBuffer), 0xAA);
-    UsString.Buffer = Buffer;
-    UsString.Length = 0;
-    UsString.MaximumLength = 16 * sizeof(WCHAR);
-
-    ok_eq_hex(RtlUnicodeStringPrintfEx(&UsString, &RemString, 0, FormatStringStrs, L"abc", L"def", L"ghi"), STATUS_SUCCESS);
-    ok_eq_uint(UsString.Length, sizeof(L"abc def ghi") - sizeof(WCHAR));
-    ok_eq_char(UsString.Buffer[11], (WCHAR)0);
-    ok_eq_uint(0, memcmp(OvrBuffer + 12, Buffer + 12, sizeof(Buffer) - (12 * sizeof(WCHAR))));
+    // Note: RtlUnicodeStringPrintf returns STATUS_BUFFER_OVERFLOW here while RtlUnicodeStringPrintfEx returns STATUS_INVALID_PARAMETER!
+    // Documented on MSDN and verified with the Win10 version of ntstrsafe.h
+    RtlInitEmptyUnicodeString(&UsString, NULL, 0);
+    Status = RtlUnicodeStringPrintfEx(&UsString, NULL, 0, FormatStringStrs, L"AAA", L"BBB", L"CCC");
+    ok_eq_hex(Status, STATUS_INVALID_PARAMETER);
+
+
+Cleanup:
+    if (pBuffer)
+        KmtFreeGuarded(pBuffer);
+
+    // None of these functions should have crashed.
+    KmtEndSeh(STATUS_SUCCESS);
 }
 
 START_TEST(RtlStrSafe)