- Set the correct type and state in MiQueryVirtualMemory.
[reactos.git] / reactos / ntoskrnl / mm / virtual.c
index 1b198b2..b2d9f85 100644 (file)
@@ -123,7 +123,7 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
    MEMORY_AREA* MemoryArea;
    PMADDRESS_SPACE AddressSpace;
 
-   if (Address < (PVOID)KERNEL_BASE)
+   if (Address < MmSystemRangeStart)
    {
       Status = ObReferenceObjectByHandle(ProcessHandle,
                                          PROCESS_QUERY_INFORMATION,
@@ -175,6 +175,7 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
               switch(MemoryArea->Type)
               {
                  case MEMORY_AREA_VIRTUAL_MEMORY:
+                  case MEMORY_AREA_PEB_OR_TEB:
                      Status = MmQueryAnonMem(MemoryArea, Address, Info,
                                              ResultLength);
                     break;
@@ -183,8 +184,8 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
                                                  ResultLength);
                      break;
                  case MEMORY_AREA_NO_ACCESS:
-                    Info->Type = 0;
-                     Info->State = MEM_FREE;
+                    Info->Type = MEM_PRIVATE;
+                     Info->State = MEM_RESERVE;
                     Info->Protect = MemoryArea->Attributes;
                     Info->AllocationProtect = MemoryArea->Attributes;
                     Info->BaseAddress = MemoryArea->StartingAddress;
@@ -195,7 +196,7 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
                      *ResultLength = sizeof(MEMORY_BASIC_INFORMATION);
                     break;
                  case MEMORY_AREA_SHARED_DATA:
-                    Info->Type = 0;
+                    Info->Type = MEM_PRIVATE;
                      Info->State = MEM_COMMIT;
                     Info->Protect = MemoryArea->Attributes;
                     Info->AllocationProtect = MemoryArea->Attributes;
@@ -207,15 +208,6 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
                      *ResultLength = sizeof(MEMORY_BASIC_INFORMATION);
                     break;
                  case MEMORY_AREA_SYSTEM:
-                    {
-                       static int warned = 0;
-                       if ( !warned )
-                       {
-                         DPRINT1("FIXME: MEMORY_AREA_SYSTEM case incomplete (or possibly wrong) for NtQueryVirtualMemory()\n");
-                         warned = 1;
-                       }
-                    }
-                    /* FIXME - don't have a clue if this is right, but it's better than nothing */
                     Info->Type = 0;
                      Info->State = MEM_COMMIT;
                     Info->Protect = MemoryArea->Attributes;
@@ -228,15 +220,18 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
                      *ResultLength = sizeof(MEMORY_BASIC_INFORMATION);
                     break;
                  case MEMORY_AREA_KERNEL_STACK:
-                    {
-                       static int warned = 0;
-                       if ( !warned )
-                       {
-                         DPRINT1("FIXME: MEMORY_AREA_KERNEL_STACK case incomplete (or possibly wrong) for NtQueryVirtualMemory()\n");
-                         warned = 1;
-                       }
-                    }
-                    /* FIXME - don't have a clue if this is right, but it's better than nothing */
+                    Info->Type = 0;
+                     Info->State = MEM_COMMIT;
+                    Info->Protect = MemoryArea->Attributes;
+                    Info->AllocationProtect = MemoryArea->Attributes;
+                    Info->BaseAddress = MemoryArea->StartingAddress;
+                    Info->AllocationBase = MemoryArea->StartingAddress;
+                    Info->RegionSize = (ULONG_PTR)MemoryArea->EndingAddress -
+                                       (ULONG_PTR)MemoryArea->StartingAddress;
+                     Status = STATUS_SUCCESS;
+                     *ResultLength = sizeof(MEMORY_BASIC_INFORMATION);
+                    break;
+                  case MEMORY_AREA_PAGED_POOL:
                     Info->Type = 0;
                      Info->State = MEM_COMMIT;
                     Info->Protect = MemoryArea->Attributes;
@@ -266,7 +261,7 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
    }
 
    MmUnlockAddressSpace(AddressSpace);
-   if (Address < (PVOID)KERNEL_BASE)
+   if (Address < MmSystemRangeStart)
    {
       ObDereferenceObject(Process);
    }
@@ -282,14 +277,14 @@ MiQueryVirtualMemory (IN HANDLE ProcessHandle,
 NTSTATUS STDCALL
 NtQueryVirtualMemory (IN HANDLE ProcessHandle,
                       IN PVOID Address,
-                      IN CINT VirtualMemoryInformationClass,
+                      IN MEMORY_INFORMATION_CLASS VirtualMemoryInformationClass,
                       OUT PVOID VirtualMemoryInformation,
                       IN ULONG Length,
                       OUT PULONG UnsafeResultLength)
 {
-   NTSTATUS Status;
+   NTSTATUS Status = STATUS_SUCCESS;
    ULONG ResultLength = 0;
-   KPROCESSOR_MODE PrevMode;
+   KPROCESSOR_MODE PreviousMode;
    union
    {
       MEMORY_BASIC_INFORMATION BasicInfo;
@@ -302,9 +297,27 @@ NtQueryVirtualMemory (IN HANDLE ProcessHandle,
           VirtualMemoryInformationClass,VirtualMemoryInformation,
           Length,ResultLength);
 
-   PrevMode =  ExGetPreviousMode();
+   PreviousMode =  ExGetPreviousMode();
+   
+   if (PreviousMode != KernelMode && UnsafeResultLength != NULL)
+     {
+       _SEH_TRY
+         {
+           ProbeForWriteUlong(UnsafeResultLength);
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           return Status;
+         }
+     }
 
-   if (PrevMode == UserMode && Address >= (PVOID)KERNEL_BASE)
+   if (Address >= MmSystemRangeStart)
    {
       DPRINT1("Invalid parameter\n");
       return STATUS_INVALID_PARAMETER;
@@ -317,19 +330,48 @@ NtQueryVirtualMemory (IN HANDLE ProcessHandle,
        Length,
        &ResultLength );
 
-   if (NT_SUCCESS(Status) && ResultLength > 0)
+   if (NT_SUCCESS(Status))
    {
-      Status = MmCopyToCaller(VirtualMemoryInformation, &VirtualMemoryInfo, ResultLength);
-      if (!NT_SUCCESS(Status))
-      {
-         ResultLength = 0;
-      }
-   }
+      if (PreviousMode != KernelMode)
+        {
+          _SEH_TRY
+            {
+              if (ResultLength > 0)
+                {
+                  ProbeForWrite(VirtualMemoryInformation,
+                                ResultLength,
+                                1);
+                  RtlCopyMemory(VirtualMemoryInformation,
+                                &VirtualMemoryInfo,
+                                ResultLength);
+                }
+              if (UnsafeResultLength != NULL)
+                {
+                  *UnsafeResultLength = ResultLength;
+                }
+            }
+          _SEH_HANDLE
+            {
+              Status = _SEH_GetExceptionCode();
+            }
+          _SEH_END;
+        }
+      else
+        {
+          if (ResultLength > 0)
+            {
+              RtlCopyMemory(VirtualMemoryInformation,
+                            &VirtualMemoryInfo,
+                            ResultLength);
+            }
 
-   if (UnsafeResultLength != NULL)
-   {
-      MmCopyToCaller(UnsafeResultLength, &ResultLength, sizeof(ULONG));
+          if (UnsafeResultLength != NULL)
+            {
+              *UnsafeResultLength = ResultLength;
+            }
+        }
    }
+
    return(Status);
 }
 
@@ -380,7 +422,7 @@ MiProtectVirtualMemory(IN PEPROCESS Process,
    else
    {
       /* FIXME: Should we return failure or success in this case? */
-      Status = STATUS_SUCCESS;
+      Status = STATUS_CONFLICTING_ADDRESSES;
    }
 
    MmUnlockAddressSpace(AddressSpace);
@@ -402,17 +444,48 @@ NtProtectVirtualMemory(IN HANDLE ProcessHandle,
                        OUT PULONG UnsafeOldAccessProtection)
 {
    PEPROCESS Process;
-   NTSTATUS Status;
    ULONG OldAccessProtection;
-   PVOID BaseAddress;
-   ULONG NumberOfBytesToProtect;
+   PVOID BaseAddress = NULL;
+   ULONG NumberOfBytesToProtect = 0;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS Status = STATUS_SUCCESS;
+   
+   PreviousMode = ExGetPreviousMode();
+   
+   if (PreviousMode != KernelMode)
+     {
+       _SEH_TRY
+         {
+           ProbeForWritePointer(UnsafeBaseAddress);
+           ProbeForWriteUlong(UnsafeNumberOfBytesToProtect);
+           ProbeForWriteUlong(UnsafeOldAccessProtection);
 
-   Status = MmCopyFromCaller(&BaseAddress, UnsafeBaseAddress, sizeof(PVOID));
-   if (!NT_SUCCESS(Status))
-      return Status;
-   Status = MmCopyFromCaller(&NumberOfBytesToProtect, UnsafeNumberOfBytesToProtect, sizeof(ULONG));
-   if (!NT_SUCCESS(Status))
-      return Status;
+           BaseAddress = *UnsafeBaseAddress;
+           NumberOfBytesToProtect = *UnsafeNumberOfBytesToProtect;
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           return Status;
+         }
+     }
+   else
+     {
+       BaseAddress = *UnsafeBaseAddress;
+       NumberOfBytesToProtect = *UnsafeNumberOfBytesToProtect;
+     }
+
+   if ((ULONG_PTR)BaseAddress + NumberOfBytesToProtect - 1 < (ULONG_PTR)BaseAddress ||
+       (ULONG_PTR)BaseAddress + NumberOfBytesToProtect - 1 >= MmUserProbeAddress)
+     {
+       /* Don't allow to change the protection of a kernel mode address */
+       return STATUS_INVALID_PARAMETER_2;
+     }
 
    /* (tMk 2004.II.5) in Microsoft SDK I read:
     * 'if this parameter is NULL or does not point to a valid variable, the function fails'
@@ -442,9 +515,26 @@ NtProtectVirtualMemory(IN HANDLE ProcessHandle,
 
    ObDereferenceObject(Process);
 
-   MmCopyToCaller(UnsafeOldAccessProtection, &OldAccessProtection, sizeof(ULONG));
-   MmCopyToCaller(UnsafeBaseAddress, &BaseAddress, sizeof(PVOID));
-   MmCopyToCaller(UnsafeNumberOfBytesToProtect, &NumberOfBytesToProtect, sizeof(ULONG));
+   if (PreviousMode != KernelMode)
+     {
+       _SEH_TRY
+         {
+           *UnsafeOldAccessProtection = OldAccessProtection;
+           *UnsafeBaseAddress = BaseAddress;
+           *UnsafeNumberOfBytesToProtect = NumberOfBytesToProtect;
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+     }
+   else
+     {
+       *UnsafeOldAccessProtection = OldAccessProtection;
+       *UnsafeBaseAddress = BaseAddress;
+       *UnsafeNumberOfBytesToProtect = NumberOfBytesToProtect;
+     }
 
    return(Status);
 }
@@ -471,20 +561,49 @@ NtReadVirtualMemory(IN HANDLE ProcessHandle,
 
    PAGED_CODE();
 
+   DPRINT("NtReadVirtualMemory(ProcessHandle %x, BaseAddress %x, "
+          "Buffer %x, NumberOfBytesToRead %d)\n",ProcessHandle,BaseAddress,
+          Buffer,NumberOfBytesToRead);
+
+   if ((ULONG_PTR)BaseAddress + NumberOfBytesToRead - 1 < (ULONG_PTR)BaseAddress ||
+       (ULONG_PTR)BaseAddress + NumberOfBytesToRead - 1 >= MmUserProbeAddress)
+     {
+       /* Don't allow to read from kernel space */
+       return STATUS_ACCESS_VIOLATION;
+     }
+
    PreviousMode = ExGetPreviousMode();
 
+   if (PreviousMode != KernelMode)
+     {
+       if ((ULONG_PTR)Buffer + NumberOfBytesToRead - 1 < (ULONG_PTR)Buffer ||
+           (ULONG_PTR)Buffer + NumberOfBytesToRead - 1 >= MmUserProbeAddress)
+         {
+           /* Don't allow to write into kernel space */
+           return STATUS_ACCESS_VIOLATION;
+         }
+     }
+
+   Status = ObReferenceObjectByHandle(ProcessHandle,
+                                      PROCESS_VM_READ,
+                                      NULL,
+                                      PreviousMode,
+                                      (PVOID*)(&Process),
+                                      NULL);
+   if (!NT_SUCCESS(Status))
+   {
+      return(Status);
+   }
+
+   CurrentProcess = PsGetCurrentProcess();
+
    if(PreviousMode != KernelMode)
    {
      _SEH_TRY
      {
-       ProbeForWrite(Buffer,
-                     NumberOfBytesToRead,
-                     1);
        if(NumberOfBytesRead != NULL)
        {
-         ProbeForWrite(NumberOfBytesRead,
-                       sizeof(ULONG),
-                       sizeof(ULONG));
+         ProbeForWriteUlong(NumberOfBytesRead);
        }
      }
      _SEH_HANDLE
@@ -499,22 +618,6 @@ NtReadVirtualMemory(IN HANDLE ProcessHandle,
      }
    }
 
-   DPRINT("NtReadVirtualMemory(ProcessHandle %x, BaseAddress %x, "
-          "Buffer %x, NumberOfBytesToRead %d)\n",ProcessHandle,BaseAddress,
-          Buffer,NumberOfBytesToRead);
-
-   Status = ObReferenceObjectByHandle(ProcessHandle,
-                                      PROCESS_VM_READ,
-                                      NULL,
-                                      PreviousMode,
-                                      (PVOID*)(&Process),
-                                      NULL);
-   if (!NT_SUCCESS(Status))
-   {
-      return(Status);
-   }
-
-   CurrentProcess = PsGetCurrentProcess();
 
    if (Process == CurrentProcess)
    {
@@ -558,7 +661,6 @@ NtReadVirtualMemory(IN HANDLE ProcessHandle,
 
           Status = STATUS_SUCCESS;
           _SEH_TRY {
-              ProbeForRead(BaseAddress, NumberOfBytesToRead, 1);
               Status = STATUS_PARTIAL_COPY;
               RtlCopyMemory(SystemAddress, BaseAddress, NumberOfBytesToRead);
               Status = STATUS_SUCCESS;
@@ -666,18 +768,52 @@ NtWriteVirtualMemory(IN HANDLE ProcessHandle,
                      IN ULONG NumberOfBytesToWrite,
                      OUT PULONG NumberOfBytesWritten  OPTIONAL)
 {
-   NTSTATUS Status;
    PMDL Mdl;
    PVOID SystemAddress;
    PEPROCESS Process;
-   ULONG OldProtection = 0;
-   PVOID ProtectBaseAddress;
-   ULONG ProtectNumberOfBytes;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS CopyStatus, Status = STATUS_SUCCESS;
 
    DPRINT("NtWriteVirtualMemory(ProcessHandle %x, BaseAddress %x, "
           "Buffer %x, NumberOfBytesToWrite %d)\n",ProcessHandle,BaseAddress,
           Buffer,NumberOfBytesToWrite);
 
+   if ((ULONG_PTR)BaseAddress + NumberOfBytesToWrite - 1 < (ULONG_PTR)BaseAddress ||
+       (ULONG_PTR)BaseAddress + NumberOfBytesToWrite - 1 >= MmUserProbeAddress)
+     {
+       /* Don't allow to write into kernel space */
+       return STATUS_ACCESS_VIOLATION;
+     }
+
+   PreviousMode = ExGetPreviousMode();
+   
+   if (PreviousMode != KernelMode)
+     {
+       if ((ULONG_PTR)Buffer + NumberOfBytesToWrite - 1 < (ULONG_PTR)Buffer ||
+           (ULONG_PTR)Buffer + NumberOfBytesToWrite - 1 >= MmUserProbeAddress)
+         {
+           /* Don't allow to read from kernel space */
+           return STATUS_ACCESS_VIOLATION;
+         }
+       if (NumberOfBytesWritten != NULL)
+         {
+           _SEH_TRY
+             {
+               ProbeForWriteUlong(NumberOfBytesWritten);
+             }
+           _SEH_HANDLE
+             {
+               Status = _SEH_GetExceptionCode();
+             }
+           _SEH_END;
+       
+           if (!NT_SUCCESS(Status))
+             {
+               return Status;
+             }
+          }
+     }
+
    Status = ObReferenceObjectByHandle(ProcessHandle,
                                       PROCESS_VM_WRITE,
                                       NULL,
@@ -689,30 +825,27 @@ NtWriteVirtualMemory(IN HANDLE ProcessHandle,
       return(Status);
    }
 
-   /* We have to make sure the target memory is writable.
-    *
-    * I am not sure if it is correct to do this in any case, but it has to be
-    * done at least in some cases because you can use WriteProcessMemory to
-    * write into the .text section of a module where memcpy() would crash.
-    *  -blight (2005/01/09)
-    */
-   ProtectBaseAddress = BaseAddress;
-   ProtectNumberOfBytes = NumberOfBytesToWrite;
+   CopyStatus = STATUS_SUCCESS;
 
    /* Write memory */
    if (Process == PsGetCurrentProcess())
    {
-      Status = MiProtectVirtualMemory(Process,
-                                      &ProtectBaseAddress,
-                                      &ProtectNumberOfBytes,
-                                      PAGE_READWRITE,
-                                      &OldProtection);
-      if (!NT_SUCCESS(Status))
-      {
-         ObDereferenceObject(Process);
-         return Status;
-      }
-      memcpy(BaseAddress, Buffer, NumberOfBytesToWrite);
+      if (PreviousMode != KernelMode)
+        {
+         _SEH_TRY
+            {
+              memcpy(BaseAddress, Buffer, NumberOfBytesToWrite);
+            }
+          _SEH_HANDLE
+            {
+              CopyStatus = _SEH_GetExceptionCode();
+            }
+          _SEH_END;
+        }
+      else
+        {
+          memcpy(BaseAddress, Buffer, NumberOfBytesToWrite);
+        }
    }
    else
    {
@@ -721,36 +854,48 @@ NtWriteVirtualMemory(IN HANDLE ProcessHandle,
                         Buffer,
                         NumberOfBytesToWrite);
       if(Mdl == NULL)
-      {
-         ObDereferenceObject(Process);
-         return(STATUS_NO_MEMORY);
-      }
-
-      /* Make the target area writable. */
-      Status = MiProtectVirtualMemory(Process,
-                                      &ProtectBaseAddress,
-                                      &ProtectNumberOfBytes,
-                                      PAGE_READWRITE,
-                                      &OldProtection);
-      if (!NT_SUCCESS(Status))
-      {
-         ObDereferenceObject(Process);
-         ExFreePool(Mdl);
-         return Status;
-      }
-
-      /* Map the MDL. */
-      MmProbeAndLockPages(Mdl,
-                          UserMode,
-                          IoReadAccess);
+        {
+          ObDereferenceObject(Process);
+          return(STATUS_NO_MEMORY);
+        }
+      _SEH_TRY
+        {
+          /* Map the MDL. */
+          MmProbeAndLockPages(Mdl,
+                              UserMode,
+                              IoReadAccess);
+       }
+      _SEH_HANDLE
+        {
+         CopyStatus = _SEH_GetExceptionCode();
+        }
+      _SEH_END;
 
-      /* Copy memory from the mapped MDL into the target buffer. */
-      KeAttachProcess(&Process->Pcb);
+      if (NT_SUCCESS(CopyStatus))
+        {
+          /* Copy memory from the mapped MDL into the target buffer. */
+          KeAttachProcess(&Process->Pcb);
 
-      SystemAddress = MmGetSystemAddressForMdl(Mdl);
-      memcpy(BaseAddress, SystemAddress, NumberOfBytesToWrite);
+          SystemAddress = MmGetSystemAddressForMdl(Mdl);
+          if (PreviousMode != KernelMode)
+            {
+              _SEH_TRY
+                {
+                  memcpy(BaseAddress, SystemAddress, NumberOfBytesToWrite);
+                }
+              _SEH_HANDLE
+                {
+                  CopyStatus = _SEH_GetExceptionCode();
+                }
+              _SEH_END;
+            }
+          else
+            {
+              memcpy(BaseAddress, SystemAddress, NumberOfBytesToWrite);
+            }
 
-      KeDetachProcess();
+          KeDetachProcess();
+       }
 
       /* Free the MDL. */
       if (Mdl->MappedSystemVa != NULL)
@@ -760,25 +905,29 @@ NtWriteVirtualMemory(IN HANDLE ProcessHandle,
       MmUnlockPages(Mdl);
       ExFreePool(Mdl);
    }
-
-   /* Reset the protection of the target memory. */
-   Status = MiProtectVirtualMemory(Process,
-                                   &ProtectBaseAddress,
-                                   &ProtectNumberOfBytes,
-                                   OldProtection,
-                                   &OldProtection);
-   if (!NT_SUCCESS(Status))
-   {
-      DPRINT1("Failed to reset protection of the target memory! (Status 0x%x)\n", Status);
-      /* FIXME: Should we bugcheck here? */
-   }
-
    ObDereferenceObject(Process);
 
-   if (NumberOfBytesWritten != NULL)
-      MmCopyToCaller(NumberOfBytesWritten, &NumberOfBytesToWrite, sizeof(ULONG));
+   if (NT_SUCCESS(CopyStatus) && NumberOfBytesWritten != NULL)
+     {
+       if (PreviousMode != KernelMode)
+         {
+           _SEH_TRY
+             {
+               *NumberOfBytesWritten = NumberOfBytesToWrite;
+             }
+           _SEH_HANDLE
+             {
+               Status = _SEH_GetExceptionCode();
+             }
+           _SEH_END;
+         }
+       else
+         {
+           *NumberOfBytesWritten = NumberOfBytesToWrite;
+         }
+     }
 
-   return(STATUS_SUCCESS);
+   return(NT_SUCCESS(CopyStatus) ? Status : CopyStatus);
 }
 
 /*
@@ -850,7 +999,7 @@ ProbeForRead (IN CONST VOID *Address,
       ExRaiseStatus (STATUS_DATATYPE_MISALIGNMENT);
    }
    else if ((ULONG_PTR)Address + Length - 1 < (ULONG_PTR)Address ||
-            (ULONG_PTR)Address + Length - 1 > (ULONG_PTR)MmUserProbeAddress)
+            (ULONG_PTR)Address + Length - 1 >= (ULONG_PTR)MmUserProbeAddress)
    {
       ExRaiseStatus (STATUS_ACCESS_VIOLATION);
    }
@@ -878,20 +1027,22 @@ ProbeForWrite (IN CONST VOID *Address,
       ExRaiseStatus (STATUS_DATATYPE_MISALIGNMENT);
    }
 
-   Last = (PCHAR)((ULONG_PTR)Address + Length - 1);
+   Last = (CHAR*)((ULONG_PTR)Address + Length - 1);
    if ((ULONG_PTR)Last < (ULONG_PTR)Address ||
-       (ULONG_PTR)Last > (ULONG_PTR)MmUserProbeAddress)
+       (ULONG_PTR)Last >= (ULONG_PTR)MmUserProbeAddress)
    {
       ExRaiseStatus (STATUS_ACCESS_VIOLATION);
    }
 
    /* Check for accessible pages */
    Current = (CHAR*)Address;
-   do
+   *Current = *Current;
+   Current = (PCHAR)((ULONG_PTR)PAGE_ROUND_DOWN(Current) + PAGE_SIZE);
+   while (Current <= Last)
    {
      *Current = *Current;
      Current = (CHAR*)((ULONG_PTR)Current + PAGE_SIZE);
-   } while (Current <= Last);
+   } 
 }
 
 /* EOF */