[ADVAPI32]
[reactos.git] / reactos / dll / win32 / advapi32 / reg / reg.c
index a090cbf..cfc8ba1 100644 (file)
@@ -1916,9 +1916,9 @@ RegpApplyRestrictions(DWORD dwFlags,
             {
                 DWORD cbExpect = 0;
 
-                if ((dwFlags & RRF_RT_DWORD) == RRF_RT_DWORD)
+                if ((dwFlags & RRF_RT_ANY) == RRF_RT_DWORD)
                     cbExpect = 4;
-                else if ((dwFlags & RRF_RT_QWORD) == RRF_RT_QWORD)
+                else if ((dwFlags & RRF_RT_ANY) == RRF_RT_QWORD)
                     cbExpect = 8;
 
                 if (cbExpect && cbData != cbExpect)
@@ -2793,14 +2793,14 @@ RegEnumValueA(HKEY hKey,
 
     status = NtEnumerateValueKey( KeyHandle, index, KeyValueFullInformation,
                                   buffer, total_size, &total_size );
-    if (status && status != STATUS_BUFFER_OVERFLOW) goto done;
+    if (status && (status != STATUS_BUFFER_OVERFLOW) && (status != STATUS_BUFFER_TOO_SMALL)) goto done;
 
     /* we need to fetch the contents for a string type even if not requested,
      * because we need to compute the length of the ASCII string. */
     if (value || data || is_string(info->Type))
     {
         /* retry with a dynamically allocated buffer */
-        while (status == STATUS_BUFFER_OVERFLOW)
+        while ((status == STATUS_BUFFER_OVERFLOW) || (status == STATUS_BUFFER_TOO_SMALL))
         {
             if (buf_ptr != buffer) HeapFree( GetProcessHeap(), 0, buf_ptr );
             if (!(buf_ptr = HeapAlloc( GetProcessHeap(), 0, total_size )))
@@ -2819,14 +2819,14 @@ RegEnumValueA(HKEY hKey,
         {
             ULONG len;
             RtlUnicodeToMultiByteSize( &len, (WCHAR *)(buf_ptr + info->DataOffset),
-                                       total_size - info->DataOffset );
+                                       info->DataLength );
             if (data && len)
             {
                 if (len > *count) status = STATUS_BUFFER_OVERFLOW;
                 else
                 {
                     RtlUnicodeToMultiByteN( (PCHAR)data, len, NULL, (WCHAR *)(buf_ptr + info->DataOffset),
-                                            total_size - info->DataOffset );
+                                            info->DataLength );
                     /* if the type is REG_SZ and data is not 0-terminated
                      * and there is enough space in the buffer NT appends a \0 */
                     if (len < *count && data[len-1]) data[len] = 0;
@@ -2836,8 +2836,8 @@ RegEnumValueA(HKEY hKey,
         }
         else if (data)
         {
-            if (total_size - info->DataOffset > *count) status = STATUS_BUFFER_OVERFLOW;
-            else memcpy( data, buf_ptr + info->DataOffset, total_size - info->DataOffset );
+            if (info->DataLength > *count) status = STATUS_BUFFER_OVERFLOW;
+            else memcpy( data, buf_ptr + info->DataOffset, info->DataLength );
         }
 
         if (value && !status)
@@ -2928,12 +2928,12 @@ RegEnumValueW(HKEY hKey,
 
     status = NtEnumerateValueKey( KeyHandle, index, KeyValueFullInformation,
                                   buffer, total_size, &total_size );
-    if (status && status != STATUS_BUFFER_OVERFLOW) goto done;
+    if (status && (status != STATUS_BUFFER_OVERFLOW) && (status != STATUS_BUFFER_TOO_SMALL)) goto done;
 
     if (value || data)
     {
         /* retry with a dynamically allocated buffer */
-        while (status == STATUS_BUFFER_OVERFLOW)
+        while ((status == STATUS_BUFFER_OVERFLOW) || (status == STATUS_BUFFER_TOO_SMALL))
         {
             if (buf_ptr != buffer) HeapFree( GetProcessHeap(), 0, buf_ptr );
             if (!(buf_ptr = HeapAlloc( GetProcessHeap(), 0, total_size )))
@@ -2962,17 +2962,17 @@ RegEnumValueW(HKEY hKey,
 
         if (data)
         {
-            if (total_size - info->DataOffset > *count)
+            if (info->DataLength > *count)
             {
                 status = STATUS_BUFFER_OVERFLOW;
                 goto overflow;
             }
-            memcpy( data, buf_ptr + info->DataOffset, total_size - info->DataOffset );
-            if (total_size - info->DataOffset <= *count-sizeof(WCHAR) && is_string(info->Type))
+            memcpy( data, buf_ptr + info->DataOffset, info->DataLength );
+            if (is_string(info->Type) && info->DataLength <= *count - sizeof(WCHAR))
             {
                 /* if the type is REG_SZ and data is not 0-terminated
                  * and there is enough space in the buffer NT appends a \0 */
-                WCHAR *ptr = (WCHAR *)(data + total_size - info->DataOffset);
+                WCHAR *ptr = (WCHAR *)(data + info->DataLength);
                 if (ptr > (WCHAR *)data && ptr[-1]) *ptr = 0;
             }
         }
@@ -3272,6 +3272,9 @@ RegOpenKeyA(HKEY hKey,
     TRACE("RegOpenKeyA hKey 0x%x lpSubKey %s phkResult %p\n",
           hKey, lpSubKey, phkResult);
 
+    if (!phkResult)
+        return ERROR_INVALID_PARAMETER;
+
     if (!hKey && lpSubKey && phkResult)
     {
         return ERROR_INVALID_HANDLE;
@@ -3308,6 +3311,9 @@ RegOpenKeyW(HKEY hKey,
     TRACE("RegOpenKeyW hKey 0x%x lpSubKey %S phkResult %p\n",
           hKey, lpSubKey, phkResult);
 
+    if (!phkResult)
+        return ERROR_INVALID_PARAMETER;
+
     if (!hKey && lpSubKey && phkResult)
     {
         return ERROR_INVALID_HANDLE;
@@ -3982,11 +3988,12 @@ RegQueryValueExA(HKEY hKey,
                  LPDWORD lpcbData)
 {
     UNICODE_STRING ValueName;
-    UNICODE_STRING ValueData;
-    ANSI_STRING AnsiString;
+    LPWSTR lpValueBuffer;
     LONG ErrorCode;
     DWORD Length;
     DWORD Type;
+    NTSTATUS Status;
+    ULONG Index;
 
     TRACE("hKey 0x%X  lpValueName %s  lpData 0x%X  lpcbData %d\n",
           hKey, lpValueName, lpData, lpcbData ? *lpcbData : 0);
@@ -3996,60 +4003,70 @@ RegQueryValueExA(HKEY hKey,
         return ERROR_INVALID_PARAMETER;
     }
 
+    Length = (lpcbData == NULL || lpData == NULL) ? 0 : *lpcbData * sizeof(WCHAR);
+
     if (lpData)
     {
-        ValueData.Length = 0;
-        ValueData.MaximumLength = (*lpcbData + 1) * sizeof(WCHAR);
-        ValueData.Buffer = RtlAllocateHeap(ProcessHeap,
-                                           0,
-                                           ValueData.MaximumLength);
-        if (!ValueData.Buffer)
+        lpValueBuffer = RtlAllocateHeap(ProcessHeap,
+                                        0,
+                                        Length + sizeof(WCHAR));
+        if (!lpValueBuffer)
         {
             return ERROR_OUTOFMEMORY;
         }
     }
     else
     {
-        ValueData.Buffer = NULL;
-        ValueData.Length = 0;
-        ValueData.MaximumLength = 0;
+        lpValueBuffer = NULL;
 
         if (lpcbData)
             *lpcbData = 0;
     }
 
-    RtlCreateUnicodeStringFromAsciiz(&ValueName,
-                                     (LPSTR)lpValueName);
+    if(!RtlCreateUnicodeStringFromAsciiz(&ValueName,
+                                         (LPSTR)lpValueName))
+    {
+        ERR("RtlCreateUnicodeStringFromAsciiz failed!\n");
+        ErrorCode = ERROR_OUTOFMEMORY;
+        goto cleanup;
+    }
 
-    Length = (lpcbData == NULL) ? 0 : *lpcbData * sizeof(WCHAR);
     ErrorCode = RegQueryValueExW(hKey,
                                  ValueName.Buffer,
                                  lpReserved,
                                  &Type,
-                                 (lpData == NULL) ? NULL : (LPBYTE)ValueData.Buffer,
+                                 (LPBYTE)lpValueBuffer,
                                  &Length);
     TRACE("ErrorCode %lu\n", ErrorCode);
+
     RtlFreeUnicodeString(&ValueName);
 
     if (ErrorCode == ERROR_SUCCESS ||
         ErrorCode == ERROR_MORE_DATA)
     {
 
-        if ((Type == REG_SZ) || (Type == REG_MULTI_SZ) || (Type == REG_EXPAND_SZ))
+        if (is_string(Type))
         {
-            if (ErrorCode == ERROR_SUCCESS && ValueData.Buffer != NULL)
+            if (ErrorCode == ERROR_SUCCESS && lpValueBuffer != NULL)
             {
-                RtlInitAnsiString(&AnsiString, NULL);
-                AnsiString.Buffer = (LPSTR)lpData;
-                AnsiString.MaximumLength = *lpcbData;
-                ValueData.Length = Length;
-                ValueData.MaximumLength = ValueData.Length + sizeof(WCHAR);
-                RtlUnicodeStringToAnsiString(&AnsiString, &ValueData, FALSE);
+                Status = RtlUnicodeToMultiByteN((PCHAR)lpData, *lpcbData, &Index, (PWCHAR)lpValueBuffer, Length);
+                if (NT_SUCCESS(Status))
+                {
+                    PCHAR szData = (PCHAR)lpData;
+                    if(&szData[Index] < (PCHAR)(lpData + *lpcbData))
+                    {
+                        szData[Index] = '\0';
+                    }
+                }
+                else
+                {
+                    ErrorCode = RtlNtStatusToDosError(Status);
+                }
             }
 
             Length = Length / sizeof(WCHAR);
         }
-        else if (ErrorCode == ERROR_SUCCESS && ValueData.Buffer != NULL)
+        else if (ErrorCode == ERROR_SUCCESS && lpValueBuffer != NULL)
         {
             if (*lpcbData < Length)
             {
@@ -4057,7 +4074,7 @@ RegQueryValueExA(HKEY hKey,
             }
             else
             {
-                RtlMoveMemory(lpData, ValueData.Buffer, Length);
+                RtlMoveMemory(lpData, lpValueBuffer, Length);
             }
         }
 
@@ -4072,9 +4089,10 @@ RegQueryValueExA(HKEY hKey,
         *lpType = Type;
     }
 
-    if (ValueData.Buffer != NULL)
+cleanup:
+    if (lpValueBuffer != NULL)
     {
-        RtlFreeHeap(ProcessHeap, 0, ValueData.Buffer);
+        RtlFreeHeap(ProcessHeap, 0, lpValueBuffer);
     }
 
     return ErrorCode;
@@ -4129,7 +4147,7 @@ RegQueryValueExW(HKEY hkeyorg,
 
     status = NtQueryValueKey( hkey, &name_str, KeyValuePartialInformation,
                               buffer, total_size, &total_size );
-    if (status && status != STATUS_BUFFER_OVERFLOW) goto done;
+    if (!NT_SUCCESS(status) && status != STATUS_BUFFER_OVERFLOW) goto done;
 
     if (data)
     {
@@ -4147,12 +4165,12 @@ RegQueryValueExW(HKEY hkeyorg,
                                       buf_ptr, total_size, &total_size );
         }
 
-        if (!status)
+        if (NT_SUCCESS(status))
         {
             memcpy( data, buf_ptr + info_size, total_size - info_size );
             /* if the type is REG_SZ and data is not 0-terminated
              * and there is enough space in the buffer NT appends a \0 */
-            if (total_size - info_size <= *count-sizeof(WCHAR) && is_string(info->Type))
+            if (is_string(info->Type) && total_size - info_size <= *count-sizeof(WCHAR))
             {
                 WCHAR *ptr = (WCHAR *)(data + total_size - info_size);
                 if (ptr > (WCHAR *)data && ptr[-1]) *ptr = 0;
@@ -4737,12 +4755,16 @@ RegSetValueExA(HKEY hKey,
     LONG ErrorCode;
     LPBYTE pData;
     DWORD DataSize;
+    NTSTATUS Status;
 
-    if (lpValueName != NULL &&
-        strlen(lpValueName) != 0)
+    /* Convert SubKey name to Unicode */
+    if (lpValueName != NULL && lpValueName[0] != '\0')
     {
-        RtlCreateUnicodeStringFromAsciiz(&ValueName,
-                                         (PSTR)lpValueName);
+        BOOL bConverted;
+        bConverted = RtlCreateUnicodeStringFromAsciiz(&ValueName,
+                                                  (PSTR)lpValueName);
+        if(!bConverted)
+            return ERROR_NOT_ENOUGH_MEMORY;
     }
     else
     {
@@ -4751,32 +4773,32 @@ RegSetValueExA(HKEY hKey,
 
     pValueName = (LPWSTR)ValueName.Buffer;
 
-    if (((dwType == REG_SZ) ||
-         (dwType == REG_MULTI_SZ) ||
-         (dwType == REG_EXPAND_SZ)) &&
-        (cbData != 0))
-    {
-        /* NT adds one if the caller forgot the NULL-termination character */
-        if (lpData[cbData - 1] != '\0')
-        {
-            cbData++;
-        }
 
-        RtlInitAnsiString(&AnsiString,
-                          NULL);
+    if (is_string(dwType) && (cbData != 0))
+    {
+        /* Convert ANSI string Data to Unicode */
+        /* If last character NOT zero then increment length */
+        LONG bNoNulledStr = ((lpData[cbData-1] != '\0') ? 1 : 0);
         AnsiString.Buffer = (PSTR)lpData;
-        AnsiString.Length = cbData - 1;
-        AnsiString.MaximumLength = cbData;
-        RtlAnsiStringToUnicodeString(&Data,
+        AnsiString.Length = cbData + bNoNulledStr;
+        AnsiString.MaximumLength = cbData + bNoNulledStr;
+        Status = RtlAnsiStringToUnicodeString(&Data,
                                      &AnsiString,
                                      TRUE);
+
+        if (!NT_SUCCESS(Status))
+        {
+            if (pValueName != NULL)
+                RtlFreeUnicodeString(&ValueName);
+
+            return RtlNtStatusToDosError(Status);
+        }
         pData = (LPBYTE)Data.Buffer;
         DataSize = cbData * sizeof(WCHAR);
     }
     else
     {
-        RtlInitUnicodeString(&Data,
-                             NULL);
+        Data.Buffer = NULL;
         pData = (LPBYTE)lpData;
         DataSize = cbData;
     }
@@ -4787,19 +4809,12 @@ RegSetValueExA(HKEY hKey,
                                dwType,
                                pData,
                                DataSize);
+
     if (pValueName != NULL)
-    {
-        RtlFreeHeap(ProcessHeap,
-                    0,
-                    ValueName.Buffer);
-    }
+        RtlFreeUnicodeString(&ValueName);
 
     if (Data.Buffer != NULL)
-    {
-        RtlFreeHeap(ProcessHeap,
-                    0,
-                    Data.Buffer);
-    }
+        RtlFreeUnicodeString(&Data);
 
     return ErrorCode;
 }
@@ -4830,24 +4845,19 @@ RegSetValueExW(HKEY hKey,
         return RtlNtStatusToDosError(Status);
     }
 
-    if (lpValueName != NULL)
-    {
-        RtlInitUnicodeString(&ValueName,
-                             lpValueName);
-    }
-    else
-    {
-        RtlInitUnicodeString(&ValueName, L"");
-    }
+    RtlInitUnicodeString(&ValueName, lpValueName);
     pValueName = &ValueName;
 
-    if (((dwType == REG_SZ) ||
-         (dwType == REG_MULTI_SZ) ||
-         (dwType == REG_EXPAND_SZ)) &&
-        (cbData != 0) && (*(((PWCHAR)lpData) + (cbData / sizeof(WCHAR)) - 1) != L'\0'))
+    if (is_string(dwType) && (cbData != 0))
     {
-        /* NT adds one if the caller forgot the NULL-termination character */
-        cbData += sizeof(WCHAR);
+        PWSTR pwsData = (PWSTR)lpData;
+
+        if((pwsData[cbData / sizeof(WCHAR) - 1] != L'\0') &&
+            (pwsData[cbData / sizeof(WCHAR)] == L'\0'))
+        {
+            /* Increment length if last character is not zero and next is zero */
+            cbData += sizeof(WCHAR);
+        }
     }
 
     Status = NtSetValueKey(KeyHandle,