[NTOSKRNL]
authorTimo Kreuzer <timo.kreuzer@reactos.org>
Thu, 2 Dec 2010 14:37:16 +0000 (14:37 +0000)
committerTimo Kreuzer <timo.kreuzer@reactos.org>
Thu, 2 Dec 2010 14:37:16 +0000 (14:37 +0000)
patch by Samuel Serapion:
Implement MemorySectionName case for NtQueryVirtualMemory. Protect buffer access with SEH.

See issue #5753 for more details.

svn path=/trunk/; revision=49898

reactos/ntoskrnl/mm/ARM3/virtual.c

index 4e69411..508db86 100644 (file)
@@ -2339,35 +2339,21 @@ NtResetWriteWatch(IN HANDLE ProcessHandle,
 
 NTSTATUS
 NTAPI
-NtQueryVirtualMemory(IN HANDLE ProcessHandle,
-                     IN PVOID BaseAddress,
-                     IN MEMORY_INFORMATION_CLASS MemoryInformationClass,
-                     OUT PVOID MemoryInformation,
-                     IN SIZE_T MemoryInformationLength,
-                     OUT PSIZE_T ReturnLength)
+MiQueryMemoryBasicInformation(IN HANDLE ProcessHandle,
+                              IN PVOID BaseAddress,
+                              OUT PVOID MemoryInformation,
+                              IN SIZE_T MemoryInformationLength,
+                              OUT PSIZE_T ReturnLength)
 {
     PEPROCESS TargetProcess;
-    NTSTATUS Status;
+    NTSTATUS Status = STATUS_SUCCESS;
     PMMVAD Vad = NULL;
     PVOID Address, NextAddress;
     BOOLEAN Found = FALSE;
     ULONG NewProtect, NewState, BaseVpn;
     MEMORY_BASIC_INFORMATION MemoryInfo;
     KAPC_STATE ApcState;
-    DPRINT("Querying class %d about address: %p\n", MemoryInformationClass, BaseAddress);
-
-    /* Only this class is supported for now */
-    ASSERT(MemoryInformationClass == MemoryBasicInformation);
-    
-    /* Validate the size information of the class */
-    if (MemoryInformationLength < sizeof(MEMORY_BASIC_INFORMATION))
-    {
-        /* The size is invalid */
-        return STATUS_INFO_LENGTH_MISMATCH;
-    }
-
-    /* Bail out if the address is invalid */
-    if (BaseAddress > MM_HIGHEST_USER_ADDRESS) return STATUS_INVALID_PARAMETER;
+    KPROCESSOR_MODE PreviousMode = ExGetPreviousMode();
 
     /* Check for illegal addresses in user-space, or the shared memory area */
     if ((BaseAddress > MM_HIGHEST_VAD_ADDRESS) ||
@@ -2396,11 +2382,27 @@ NtQueryVirtualMemory(IN HANDLE ProcessHandle,
             MemoryInfo.RegionSize = (ULONG_PTR)MemoryInfo.AllocationBase - (ULONG_PTR)Address;
         }
 
-        /* Return the data (FIXME: Use SEH) */
-        *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
-        if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+        /* Return the data, NtQueryInformation already probed it*/
+        if (PreviousMode != KernelMode)
+        {
+            _SEH2_TRY
+            {
+                *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
+                if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+            }
+             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+            {
+                Status = _SEH2_GetExceptionCode();
+            }
+            _SEH2_END;
+        }
+        else
+        {
+            *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
+            if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+        }
 
-        return STATUS_SUCCESS;
+        return Status;
     }
 
     /* Check if this is for a local or remote process */
@@ -2507,11 +2509,27 @@ NtQueryVirtualMemory(IN HANDLE ProcessHandle,
         MemoryInfo.Protect = PAGE_NOACCESS;
         MemoryInfo.Type = 0;
 
-        /* Return the data (FIXME: Use SEH) */
-        *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
-        if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+        /* Return the data, NtQueryInformation already probed it*/
+        if (PreviousMode != KernelMode)
+        {
+            _SEH2_TRY
+            {
+                *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
+                if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+            }
+             _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+            {
+                Status = _SEH2_GetExceptionCode();
+            }
+            _SEH2_END;
+        }
+        else
+        {
+            *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
+            if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+        }
 
-        return STATUS_SUCCESS;
+        return Status;
     }
 
     /* This must be a VM VAD */
@@ -2539,7 +2557,7 @@ NtQueryVirtualMemory(IN HANDLE ProcessHandle,
         Address = NextAddress;
     }
 
-    /* Now that we know the last VA address, calculate hte region size */
+    /* Now that we know the last VA address, calculate the region size */
     MemoryInfo.RegionSize = ((ULONG_PTR)Address - (ULONG_PTR)MemoryInfo.BaseAddress);
 
     /* Check if we were attached */
@@ -2550,17 +2568,179 @@ NtQueryVirtualMemory(IN HANDLE ProcessHandle,
         ObDereferenceObject(TargetProcess);
     }
 
-    /* Return the data (FIXME: Use SEH) */
-    *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
-    if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
-    
+    /* Return the data, NtQueryInformation already probed it*/
+    if (PreviousMode != KernelMode)
+    {
+        _SEH2_TRY
+        {
+            *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
+            if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+        }
+         _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            Status = _SEH2_GetExceptionCode();
+        }
+        _SEH2_END;
+    }
+    else
+    {
+        *(PMEMORY_BASIC_INFORMATION)MemoryInformation = MemoryInfo;
+        if (ReturnLength) *ReturnLength = sizeof(MEMORY_BASIC_INFORMATION);
+    }
+
     /* All went well */
     DPRINT("Base: %p AllocBase: %p Protect: %lx AllocProtect: %lx "
             "State: %lx Type: %lx Size: %lx\n",
             MemoryInfo.BaseAddress, MemoryInfo.AllocationBase,
             MemoryInfo.AllocationProtect, MemoryInfo.Protect,
             MemoryInfo.State, MemoryInfo.Type, MemoryInfo.RegionSize);
-    return STATUS_SUCCESS;
+
+    return Status;
+}
+
+NTSTATUS
+NTAPI
+MiQueryMemorySectionName(IN HANDLE ProcessHandle,
+                         IN PVOID BaseAddress,
+                         OUT PVOID MemoryInformation,
+                         IN SIZE_T MemoryInformationLength,
+                         OUT PSIZE_T ReturnLength)
+{
+    PEPROCESS Process;
+    NTSTATUS Status;
+    WCHAR ModuleFileNameBuffer[MAX_PATH] = {0};
+    UNICODE_STRING ModuleFileName;
+    PMEMORY_SECTION_NAME SectionName = NULL;
+    KPROCESSOR_MODE PreviousMode = ExGetPreviousMode();
+
+    Status = ObReferenceObjectByHandle(ProcessHandle,
+                                       PROCESS_QUERY_INFORMATION,
+                                       NULL,
+                                       PreviousMode,
+                                       (PVOID*)(&Process),
+                                       NULL);
+
+    if (!NT_SUCCESS(Status))
+    {
+        DPRINT("MiQueryMemorySectionName: ObReferenceObjectByHandle returned %x\n",Status);
+        return Status;
+    }
+
+    RtlInitEmptyUnicodeString(&ModuleFileName, ModuleFileNameBuffer, sizeof(ModuleFileNameBuffer));
+    Status = MmGetFileNameForAddress(BaseAddress, &ModuleFileName);
+
+    if (NT_SUCCESS(Status))
+    {
+        SectionName = MemoryInformation;
+        if (PreviousMode != KernelMode)
+        {
+            _SEH2_TRY
+            {
+                RtlInitUnicodeString(&SectionName->SectionFileName, SectionName->NameBuffer);
+                SectionName->SectionFileName.MaximumLength = MemoryInformationLength;
+                RtlCopyUnicodeString(&SectionName->SectionFileName, &ModuleFileName);
+
+                if (ReturnLength) *ReturnLength = ModuleFileName.Length;
+
+            }
+            _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+            {
+                Status = _SEH2_GetExceptionCode();
+            }
+            _SEH2_END;
+        }
+        else
+        {
+            RtlInitUnicodeString(&SectionName->SectionFileName, SectionName->NameBuffer);
+            SectionName->SectionFileName.MaximumLength = MemoryInformationLength;
+            RtlCopyUnicodeString(&SectionName->SectionFileName, &ModuleFileName);
+
+            if (ReturnLength) *ReturnLength = ModuleFileName.Length;
+
+        }
+    }
+    ObDereferenceObject(Process);
+    return Status;
+}
+
+NTSTATUS
+NTAPI
+NtQueryVirtualMemory(IN HANDLE ProcessHandle,
+                     IN PVOID BaseAddress,
+                     IN MEMORY_INFORMATION_CLASS MemoryInformationClass,
+                     OUT PVOID MemoryInformation,
+                     IN SIZE_T MemoryInformationLength,
+                     OUT PSIZE_T ReturnLength)
+{
+    NTSTATUS Status = STATUS_SUCCESS;
+    KPROCESSOR_MODE PreviousMode;
+
+    DPRINT("Querying class %d about address: %p\n", MemoryInformationClass, BaseAddress);
+
+    /* Bail out if the address is invalid */
+    if (BaseAddress > MM_HIGHEST_USER_ADDRESS) return STATUS_INVALID_PARAMETER;
+
+    /* Probe return buffer */
+    PreviousMode =  ExGetPreviousMode();
+    if (PreviousMode != KernelMode)
+    {
+        _SEH2_TRY
+        {
+            ProbeForWrite(MemoryInformation,
+                          MemoryInformationLength,
+                          sizeof(ULONG_PTR));
+
+            if (ReturnLength) ProbeForWriteSize_t(ReturnLength);
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            Status = _SEH2_GetExceptionCode();
+        }
+        _SEH2_END;
+
+        if (!NT_SUCCESS(Status))
+        {
+            return Status;
+        }
+    }
+
+    switch(MemoryInformationClass)
+    {
+        case MemoryBasicInformation:
+            /* Validate the size information of the class */
+            if (MemoryInformationLength < sizeof(MEMORY_BASIC_INFORMATION))
+            {
+                /* The size is invalid */
+                return STATUS_INFO_LENGTH_MISMATCH;
+            }
+            Status = MiQueryMemoryBasicInformation(ProcessHandle,
+                                                   BaseAddress, 
+                                                   MemoryInformation,
+                                                   MemoryInformationLength,
+                                                   ReturnLength);
+            break;
+
+        case MemorySectionName:
+            /* Validate the size information of the class */
+            if (MemoryInformationLength < sizeof(MEMORY_SECTION_NAME))
+            {
+                /* The size is invalid */
+                return STATUS_INFO_LENGTH_MISMATCH;
+            }
+            Status = MiQueryMemorySectionName(ProcessHandle,
+                                              BaseAddress, 
+                                              MemoryInformation,
+                                              MemoryInformationLength,
+                                              ReturnLength);
+            break;
+        case MemoryWorkingSetList:
+        case MemoryBasicVlmInformation:
+        default:
+            DPRINT1("Unhandled memory information class %d\n", MemoryInformationClass);
+            break;
+    }
+
+    return Status;
 }
 
 /* EOF */