[RTL]
[reactos.git] / reactos / lib / rtl / registry.c
index d6e4458..ed158e8 100644 (file)
@@ -10,6 +10,9 @@
 /* INCLUDES *****************************************************************/
 
 #include <rtl.h>
+
+#include <ndk/cmfuncs.h>
+
 #define NDEBUG
 #include <debug.h>
 
@@ -38,7 +41,7 @@ RtlpQueryRegistryDirect(IN ULONG ValueType,
                         IN ULONG ValueLength,
                         IN PVOID Buffer)
 {
-    USHORT ActualLength = (USHORT)ValueLength;
+    USHORT ActualLength;
     PUNICODE_STRING ReturnString = Buffer;
     PULONG Length = Buffer;
     ULONG RealLength;
@@ -49,7 +52,10 @@ RtlpQueryRegistryDirect(IN ULONG ValueType,
         (ValueType == REG_MULTI_SZ))
     {
         /* Normalize the length */
-        if (ValueLength > MAXUSHORT) ValueLength = MAXUSHORT;
+        if (ValueLength > MAXUSHORT)
+            ActualLength = MAXUSHORT;
+        else
+            ActualLength = (USHORT)ValueLength;
 
         /* Check if the return string has been allocated */
         if (!ReturnString->Buffer)
@@ -116,8 +122,9 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
                              IN PVOID Context,
                              IN PVOID Environment)
 {
-    ULONG InfoLength, Length, c;
-    LONG RequiredLength, SpareLength;
+    ULONG InfoLength;
+    SIZE_T Length, SpareLength, c;
+    ULONG RequiredLength;
     PCHAR SpareData, DataEnd;
     ULONG Type;
     PWCHAR Name, p, ValueEnd;
@@ -204,10 +211,10 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
 
             /* Check if we have space to copy the data */
             RequiredLength = KeyValueInfo->NameLength + sizeof(UNICODE_NULL);
-            if (SpareLength < RequiredLength)
+            if ((SpareData > DataEnd) || (SpareLength < RequiredLength))
             {
                 /* Fail and return the missing length */
-                *InfoSize = SpareData - (PCHAR)KeyValueInfo + RequiredLength;
+                *InfoSize = (ULONG)(SpareData - (PCHAR)KeyValueInfo) + RequiredLength;
                 return STATUS_BUFFER_TOO_SMALL;
             }
 
@@ -241,7 +248,8 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
         {
             /* Prepare defaults */
             Status = STATUS_SUCCESS;
-            ValueEnd = (PWSTR)((ULONG_PTR)Data + Length) - sizeof(UNICODE_NULL);
+            /* Skip the last two UNICODE_NULL chars (the terminating null string) */
+            ValueEnd = (PWSTR)((ULONG_PTR)Data + Length - 2 * sizeof(UNICODE_NULL));
             p = Data;
 
             /* Loop all strings */
@@ -257,11 +265,11 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
                     /* Do the query */
                     Status = RtlpQueryRegistryDirect(REG_SZ,
                                                      Data,
-                                                     Length,
+                                                     (ULONG)Length,
                                                      QueryTable->EntryContext);
-                    QueryTable->EntryContext = (PVOID)((ULONG_PTR)QueryTable->
-                                                       EntryContext +
-                                                       sizeof(UNICODE_STRING));
+                    QueryTable->EntryContext =
+                        (PVOID)((ULONG_PTR)QueryTable->EntryContext +
+                                sizeof(UNICODE_STRING));
                 }
                 else
                 {
@@ -269,7 +277,7 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
                     Status = QueryTable->QueryRoutine(Name,
                                                       REG_SZ,
                                                       Data,
-                                                      Length,
+                                                      (ULONG)Length,
                                                       Context,
                                                       QueryTable->EntryContext);
                 }
@@ -311,7 +319,7 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
             if (FoundExpander)
             {
                 /* Setup the source string */
-                RtlInitEmptyUnicodeString(&Source, Data, Length);
+                RtlInitEmptyUnicodeString(&Source, Data, (USHORT)Length);
                 Source.Length = Source.MaximumLength - sizeof(UNICODE_NULL);
 
                 /* Setup the desination string */
@@ -326,21 +334,21 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
                 else if (SpareLength <= MAXUSHORT)
                 {
                     /* This is the good case, where we fit into a string */
-                    Destination.MaximumLength = SpareLength;
-                    Destination.Buffer[SpareLength / 2 - 1] = UNICODE_NULL;
+                    Destination.MaximumLength = (USHORT)SpareLength;
+                    Destination.Buffer[SpareLength / sizeof(WCHAR) - 1] = UNICODE_NULL;
                 }
                 else
                 {
                     /* We can't fit into a string, so truncate */
                     Destination.MaximumLength = MAXUSHORT;
-                    Destination.Buffer[MAXUSHORT / 2 - 1] = UNICODE_NULL;
+                    Destination.Buffer[MAXUSHORT / sizeof(WCHAR) - 1] = UNICODE_NULL;
                 }
 
                 /* Expand the strings and set our type as one string */
                 Status = RtlExpandEnvironmentStrings_U(Environment,
                                                        &Source,
                                                        &Destination,
-                                                       (PULONG)&RequiredLength);
+                                                       &RequiredLength);
                 Type = REG_SZ;
 
                 /* Check for success */
@@ -356,9 +364,8 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
                     if (Status == STATUS_BUFFER_TOO_SMALL)
                     {
                         /* Set the required missing length */
-                        *InfoSize = SpareData -
-                                    (PCHAR)KeyValueInfo +
-                                    RequiredLength;
+                        *InfoSize = (ULONG)(SpareData - (PCHAR)KeyValueInfo) +
+                                           RequiredLength;
 
                         /* Notify debugger */
                         DPRINT1("RTL: Expand variables for %wZ failed - "
@@ -391,7 +398,7 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
         /* Return the data */
         Status = RtlpQueryRegistryDirect(Type,
                                          Data,
-                                         Length,
+                                         (ULONG)Length,
                                          QueryTable->EntryContext);
     }
     else
@@ -400,7 +407,7 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
         Status = QueryTable->QueryRoutine(Name,
                                           Type,
                                           Data,
-                                          Length,
+                                          (ULONG)Length,
                                           Context,
                                           QueryTable->EntryContext);
     }
@@ -409,12 +416,15 @@ RtlpCallQueryRegistryRoutine(IN PRTL_QUERY_REGISTRY_TABLE QueryTable,
     return (Status == STATUS_BUFFER_TOO_SMALL) ? STATUS_SUCCESS : Status;
 }
 
+_Success_(return!=NULL || BufferSize==0)
+_When_(BufferSize!=NULL,__drv_allocatesMem(Mem))
 PVOID
 NTAPI
-RtlpAllocDeallocQueryBuffer(IN OUT PSIZE_T BufferSize,
-                            IN PVOID OldBuffer,
-                            IN SIZE_T OldBufferSize,
-                            OUT PNTSTATUS Status)
+RtlpAllocDeallocQueryBuffer(
+    _In_opt_ PSIZE_T BufferSize,
+    _In_opt_ __drv_freesMem(Mem) PVOID OldBuffer,
+    _In_ SIZE_T OldBufferSize,
+    _Out_opt_ _On_failure_(_Post_satisfies_(*Status < 0)) PNTSTATUS Status)
 {
     PVOID Buffer = NULL;
 
@@ -541,6 +551,20 @@ RtlpGetRegistryHandle(IN ULONG RelativeTo,
     return Status;
 }
 
+FORCEINLINE
+VOID
+RtlpCloseRegistryHandle(
+    _In_ ULONG RelativeTo,
+    _In_ HANDLE KeyHandle)
+{
+    /* Did the caller pass a key handle? */
+    if (!(RelativeTo & RTL_REGISTRY_HANDLE))
+    {
+        /* We opened the key in RtlpGetRegistryHandle, so close it now */
+        ZwClose(KeyHandle);
+    }
+}
+
 /* PUBLIC FUNCTIONS **********************************************************/
 
 /*
@@ -562,7 +586,7 @@ RtlCheckRegistryKey(IN ULONG RelativeTo,
                                    &KeyHandle);
     if (!NT_SUCCESS(Status)) return Status;
 
-    /* All went well, close the handle and return success */
+    /* Close the handle even for RTL_REGISTRY_HANDLE */
     ZwClose(KeyHandle);
     return STATUS_SUCCESS;
 }
@@ -586,8 +610,8 @@ RtlCreateRegistryKey(IN ULONG RelativeTo,
                                    &KeyHandle);
     if (!NT_SUCCESS(Status)) return Status;
 
-    /* All went well, close the handle and return success */
-    ZwClose(KeyHandle);
+    /* All went well, close the handle and return status */
+    RtlpCloseRegistryHandle(RelativeTo, KeyHandle);
     return STATUS_SUCCESS;
 }
 
@@ -616,8 +640,8 @@ RtlDeleteRegistryValue(IN ULONG RelativeTo,
     RtlInitUnicodeString(&Name, ValueName);
     Status = ZwDeleteValueKey(KeyHandle, &Name);
 
-    /* All went well, close the handle and return status */
-    ZwClose(KeyHandle);
+    /* Close the handle and return status */
+    RtlpCloseRegistryHandle(RelativeTo, KeyHandle);
     return Status;
 }
 
@@ -654,8 +678,8 @@ RtlWriteRegistryValue(IN ULONG RelativeTo,
                            ValueData,
                            ValueLength);
 
-    /* All went well, close the handle and return status */
-    ZwClose(KeyHandle);
+    /* Close the handle and return status */
+    RtlpCloseRegistryHandle(RelativeTo, KeyHandle);
     return Status;
 }
 
@@ -756,7 +780,7 @@ RtlFormatCurrentUserKeyPath(OUT PUNICODE_STRING KeyPath)
     /* Initialize a string */
     RtlInitEmptyUnicodeString(KeyPath,
                               RtlpAllocateStringMemory(Length, TAG_USTR),
-                              Length);
+                              (USHORT)Length);
     if (!KeyPath->Buffer)
     {
         /* Free the string and fail */
@@ -834,13 +858,13 @@ RtlpNtEnumerateSubKey(IN HANDLE KeyHandle,
                             KeyInfo,
                             BufferLength,
                             &ReturnedLength);
-    if (NT_SUCCESS(Status))
+    if (NT_SUCCESS(Status) && (KeyInfo != NULL))
     {
         /* Check if the name fits */
         if (KeyInfo->NameLength <= SubKeyName->MaximumLength)
         {
             /* Set the length */
-            SubKeyName->Length = KeyInfo->NameLength;
+            SubKeyName->Length = (USHORT)KeyInfo->NameLength;
 
             /* Copy it */
             RtlMoveMemory(SubKeyName->Buffer,
@@ -1002,7 +1026,7 @@ RtlQueryRegistryValues(IN ULONG RelativeTo,
     if (!KeyValueInfo)
     {
         /* Close the handle if we have one and fail */
-        if (!(RelativeTo & RTL_REGISTRY_HANDLE)) ZwClose(KeyHandle);
+        RtlpCloseRegistryHandle(RelativeTo, KeyHandle);
         return Status;
     }
 
@@ -1094,7 +1118,7 @@ RtlQueryRegistryValues(IN ULONG RelativeTo,
                                          &KeyValueName,
                                          KeyValueFullInformation,
                                          KeyValueInfo,
-                                         InfoSize,
+                                         (ULONG)InfoSize,
                                          &ResultLength);
                 if (Status == STATUS_BUFFER_OVERFLOW)
                 {
@@ -1111,7 +1135,7 @@ RtlQueryRegistryValues(IN ULONG RelativeTo,
                         /* Setup a default */
                         KeyValueInfo->Type = REG_NONE;
                         KeyValueInfo->DataLength = 0;
-                        ResultLength = InfoSize;
+                        ResultLength = (ULONG)InfoSize;
 
                         /* Call the query routine */
                         Status = RtlpCallQueryRegistryRoutine(QueryTable,
@@ -1146,12 +1170,12 @@ RtlQueryRegistryValues(IN ULONG RelativeTo,
                     if (KeyValueInfo->Type == REG_MULTI_SZ)
                     {
                         /* Add a null-char */
-                        ((PWCHAR)KeyValueInfo)[ResultLength / 2] = UNICODE_NULL;
+                        ((PWCHAR)KeyValueInfo)[ResultLength / sizeof(WCHAR)] = UNICODE_NULL;
                         KeyValueInfo->DataLength += sizeof(UNICODE_NULL);
                     }
 
                     /* Call the query routine */
-                    ResultLength = InfoSize;
+                    ResultLength = (ULONG)InfoSize;
                     Status = RtlpCallQueryRegistryRoutine(QueryTable,
                                                           KeyValueInfo,
                                                           &ResultLength,
@@ -1212,7 +1236,7 @@ ProcessValues:
                                              Value,
                                              KeyValueFullInformation,
                                              KeyValueInfo,
-                                             InfoSize,
+                                             (ULONG)InfoSize,
                                              &ResultLength);
                 if (Status == STATUS_BUFFER_OVERFLOW)
                 {
@@ -1242,7 +1266,7 @@ ProcessValues:
                 if (NT_SUCCESS(Status))
                 {
                     /* Call the query routine */
-                    ResultLength = InfoSize;
+                    ResultLength = (ULONG)InfoSize;
                     Status = RtlpCallQueryRegistryRoutine(QueryTable,
                                                           KeyValueInfo,
                                                           &ResultLength,
@@ -1303,7 +1327,7 @@ ProcessValues:
     }
 
     /* Check if we need to close our handle */
-    if ((KeyHandle) && !(RelativeTo & RTL_REGISTRY_HANDLE)) ZwClose(KeyHandle);
+    if (KeyHandle) RtlpCloseRegistryHandle(RelativeTo, KeyHandle);
     if ((CurrentKey) && (CurrentKey != KeyHandle)) ZwClose(CurrentKey);
 
     /* Free our buffer and return status */