[NTOS/ARM3]
authorAleksey Bragin <aleksey@reactos.org>
Sun, 29 Jul 2012 22:18:23 +0000 (22:18 +0000)
committerAleksey Bragin <aleksey@reactos.org>
Sun, 29 Jul 2012 22:18:23 +0000 (22:18 +0000)
- Implement MiProtectVirtualMemory for VAD based allocator (it's very similar to the already implemented MiSetProtectionOnSection function). However it still doesn't fully work because support in other functions is missing (failed assertions in MiFlushTbAndCapture).
See issue #7216 for more details.

svn path=/trunk/; revision=56986

reactos/ntoskrnl/mm/ARM3/virtual.c

index 18295d8..1be2106 100644 (file)
@@ -26,6 +26,15 @@ MiProtectVirtualMemory(IN PEPROCESS Process,
                        IN ULONG NewAccessProtection,
                        OUT PULONG OldAccessProtection  OPTIONAL);
 
+VOID
+NTAPI
+MiFlushTbAndCapture(IN PMMVAD FoundVad,
+                    IN PMMPTE PointerPte,
+                    IN ULONG ProtectionMask,
+                    IN PMMPFN Pfn1,
+                    IN BOOLEAN CaptureDirtyBit);
+
+
 /* PRIVATE FUNCTIONS **********************************************************/
 
 ULONG
@@ -1719,7 +1728,17 @@ MiProtectVirtualMemory(IN PEPROCESS Process,
                        OUT PULONG OldAccessProtection OPTIONAL)
 {
     PMEMORY_AREA MemoryArea;
+    PMMVAD Vad;
+    PMMSUPPORT AddressSpace;
+    ULONG_PTR StartingAddress, EndingAddress;
+    PMMPTE PointerPde, PointerPte, LastPte;
+    MMPTE PteContents;
+    PUSHORT UsedPageTableEntries;
+    PMMPFN Pfn1;
+    ULONG ProtectionMask;
+    NTSTATUS Status = STATUS_SUCCESS;
 
+    /* Check for ROS specific memory area */
     MemoryArea = MmLocateMemoryAreaByAddress(&Process->Vm, *BaseAddress);
     if ((MemoryArea) && (MemoryArea->Type == MEMORY_AREA_SECTION_VIEW))
     {
@@ -1730,8 +1749,175 @@ MiProtectVirtualMemory(IN PEPROCESS Process,
                                          OldAccessProtection);
     }
 
-    UNIMPLEMENTED;
-    return STATUS_CONFLICTING_ADDRESSES;
+    /* Calcualte base address for the VAD */
+    StartingAddress = (ULONG_PTR)PAGE_ALIGN((*BaseAddress));
+    EndingAddress = (((ULONG_PTR)*BaseAddress + *NumberOfBytesToProtect - 1) | (PAGE_SIZE - 1));
+
+    /* Calculate the protection mask and make sure it's valid */
+    ProtectionMask = MiMakeProtectionMask(NewAccessProtection);
+    if (ProtectionMask == MM_INVALID_PROTECTION)
+    {
+        DPRINT1("Invalid protection mask\n");
+        return STATUS_INVALID_PAGE_PROTECTION;
+    }
+
+    /* Lock the address space and make sure the process isn't already dead */
+    AddressSpace = MmGetCurrentAddressSpace();
+    MmLockAddressSpace(AddressSpace);
+    if (Process->VmDeleted)
+    {
+        DPRINT1("Process is dying\n");
+        Status = STATUS_PROCESS_IS_TERMINATING;
+        goto FailPath;
+    }
+
+    /* Get the VAD for this address range, and make sure it exists */
+    Vad = (PMMVAD)MiCheckForConflictingNode(StartingAddress >> PAGE_SHIFT,
+                                            EndingAddress >> PAGE_SHIFT,
+                                            &Process->VadRoot);
+    if (!Vad)
+    {
+        DPRINT("Could not find a VAD for this allocation\n");
+        Status = STATUS_CONFLICTING_ADDRESSES;
+        goto FailPath;
+    }
+
+    /* Make sure the address is within this VAD's boundaries */
+    if ((((ULONG_PTR)StartingAddress >> PAGE_SHIFT) < Vad->StartingVpn) ||
+        (((ULONG_PTR)EndingAddress >> PAGE_SHIFT) > Vad->EndingVpn))
+    {
+        Status = STATUS_CONFLICTING_ADDRESSES;
+        goto FailPath;
+    }
+
+    /* These kinds of VADs are not supported atm  */
+    if ((Vad->u.VadFlags.VadType == VadAwe) ||
+        (Vad->u.VadFlags.VadType == VadDevicePhysicalMemory) ||
+        (Vad->u.VadFlags.VadType == VadLargePages))
+    {
+        DPRINT1("Illegal VAD for attempting to set protection\n");
+        Status = STATUS_CONFLICTING_ADDRESSES;
+        goto FailPath;
+    }
+
+    /* Check for a VAD whose protection can't be changed */
+    if (Vad->u.VadFlags.NoChange == 1)
+    {
+        DPRINT1("Trying to change protection of a NoChange VAD\n");
+        Status = STATUS_INVALID_PAGE_PROTECTION;
+        goto FailPath;
+    }
+
+    if (Vad->u.VadFlags.PrivateMemory == 0)
+    {
+        /* This is a section, handled by the ROS specific code above */
+        UNIMPLEMENTED;
+    }
+    else
+    {
+        /* Private memory, check protection flags */
+        if ((NewAccessProtection & PAGE_WRITECOPY) ||
+            (NewAccessProtection & PAGE_EXECUTE_WRITECOPY))
+        {
+            Status = STATUS_INVALID_PARAMETER_4;
+            goto FailPath;
+        }
+
+        //MiLockProcessWorkingSet(Thread, Process);
+
+        /* TODO: Check if all pages in this range are committed */
+
+        /* Compute starting and ending PTE and PDE addresses */
+        PointerPde = MiAddressToPde(StartingAddress);
+        PointerPte = MiAddressToPte(StartingAddress);
+        LastPte = MiAddressToPte(EndingAddress);
+
+        /* Make this PDE valid */
+        MiMakePdeExistAndMakeValid(PointerPde, Process, MM_NOIRQL);
+
+        /* Save protection of the first page */
+        if (PointerPte->u.Long != 0)
+        {
+            /* Capture the page protection and make the PDE valid */
+            *OldAccessProtection = MiGetPageProtection(PointerPte);
+            MiMakePdeExistAndMakeValid(PointerPde, Process, MM_NOIRQL);
+        }
+        else
+        {
+            /* Grab the old protection from the VAD itself */
+            *OldAccessProtection = MmProtectToValue[Vad->u.VadFlags.Protection];
+        }
+
+        /* Loop all the PTEs now */
+        while (PointerPte <= LastPte)
+        {
+            /* Check if we've crossed a PDE boundary and make the new PDE valid too */
+            if ((((ULONG_PTR)PointerPte) & (SYSTEM_PD_SIZE - 1)) == 0)
+            {
+                PointerPde = MiAddressToPte(PointerPte);
+                MiMakePdeExistAndMakeValid(PointerPde, Process, MM_NOIRQL);
+            }
+
+            /* Capture the PTE and see what we're dealing with */
+            PteContents = *PointerPte;
+            if (PteContents.u.Long == 0)
+            {
+                /* This used to be a zero PTE and it no longer is, so we must add a
+                   reference to the pagetable. */
+                UsedPageTableEntries = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(MiPteToAddress(PointerPte))];
+                (*UsedPageTableEntries)++;
+                ASSERT((*UsedPageTableEntries) <= PTE_COUNT);
+            }
+            else if (PteContents.u.Hard.Valid == 1)
+            {
+                /* Get the PFN entry */
+                Pfn1 = MiGetPfnEntry(PFN_FROM_PTE(&PteContents));
+
+                /* We don't support this yet */
+                ASSERT(Pfn1->u3.e1.PrototypePte == 0);
+
+                /* Check if the page should not be accessible at all */
+                if ((NewAccessProtection & PAGE_NOACCESS) ||
+                    (NewAccessProtection & PAGE_GUARD))
+                {
+                    /* TODO */
+                    UNIMPLEMENTED;
+                }
+
+                /* Write the protection mask and write it with a TLB flush */
+                Pfn1->OriginalPte.u.Soft.Protection = ProtectionMask;
+                MiFlushTbAndCapture(Vad,
+                                    PointerPte,
+                                    ProtectionMask,
+                                    Pfn1,
+                                    TRUE);
+            }
+            else
+            {
+                /* We don't support these cases yet */
+                ASSERT(PteContents.u.Soft.Prototype == 0);
+                ASSERT(PteContents.u.Soft.Transition == 0);
+
+                /* The PTE is already demand-zero, just update the protection mask */
+                PointerPte->u.Soft.Protection = ProtectionMask;
+            }
+
+            PointerPte++;
+        }
+
+        /* Unlock the working set and update quota charges if needed, then return */
+        //MiUnlockProcessWorkingSet(Thread, Process);
+    }
+
+FailPath:
+    /* Unlock the address space */
+    MmUnlockAddressSpace(AddressSpace);
+
+    /* Return parameters */
+    *NumberOfBytesToProtect = (SIZE_T)((PUCHAR)EndingAddress - (PUCHAR)StartingAddress + 1);
+    *BaseAddress = (PVOID)StartingAddress;
+
+    return Status;
 }
 
 VOID