[NTOS:MM]
authorThomas Faber <thomas.faber@reactos.org>
Wed, 22 Oct 2014 13:04:57 +0000 (13:04 +0000)
committerThomas Faber <thomas.faber@reactos.org>
Wed, 22 Oct 2014 13:04:57 +0000 (13:04 +0000)
Make special pool usable:
- Invalidate PTEs on free to catch use-after-free situations (and not confuse Mm)
- Fix pattern check not to look for more than 8 bits in a byte
- Enable POOL_FLAG_SPECIAL_POOL if special pool has been initialized
- Implement MmExpandSpecialPool
- Issue the correct SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION bugcheck when problems are detected
- Magic values--
To enable special pool for a single tag, set a value for MmSpecialPoolTag in ntoskrnl/mm/ARM3/pool.c.
To enable it for more than one tag, set MmSpecialPoolTag and modify MmUseSpecialPool in ntoskrnl/mm/ARM3/special.c (e.g. to return TRUE independent of Tag).
CORE-8680 #resolve

svn path=/trunk/; revision=64886

reactos/ntoskrnl/include/internal/mm.h
reactos/ntoskrnl/mm/ARM3/special.c

index 219c0b4..e6f8e24 100644 (file)
@@ -719,7 +719,7 @@ MmPageFault(
 
 VOID
 NTAPI
-MiInitializeSpecialPool();
+MiInitializeSpecialPool(VOID);
 
 BOOLEAN
 NTAPI
index 49e6055..35116d7 100644 (file)
@@ -20,6 +20,7 @@
 #define MODULE_INVOLVED_IN_ARM3
 #include "../ARM3/miarm.h"
 
+extern ULONG ExpPoolFlags;
 extern PMMPTE MmSystemPteBase;
 
 PMMPTE
@@ -73,22 +74,20 @@ MmUseSpecialPool(SIZE_T NumberOfBytes, ULONG Tag)
     if (NumberOfBytes > (PAGE_SIZE - sizeof(POOL_HEADER)))
         return FALSE;
 
-    // FIXME
-    //return TRUE;
-    return FALSE;
+    return Tag == MmSpecialPoolTag;
 }
 
 BOOLEAN
 NTAPI
 MmIsSpecialPoolAddress(PVOID P)
 {
-    return((P >= MmSpecialPoolStart) &&
-           (P <= MmSpecialPoolEnd));
+    return ((P >= MmSpecialPoolStart) &&
+            (P <= MmSpecialPoolEnd));
 }
 
 VOID
 NTAPI
-MiInitializeSpecialPool()
+MiInitializeSpecialPool(VOID)
 {
     ULONG SpecialPoolPtes, i;
     PMMPTE PointerPte;
@@ -98,7 +97,7 @@ MiInitializeSpecialPool()
         (MmSpecialPoolTag == -1)) return;
 
     /* Calculate number of system PTEs for the special pool */
-    if ( MmNumberOfSystemPtes >= 0x3000 )
+    if (MmNumberOfSystemPtes >= 0x3000)
         SpecialPoolPtes = MmNumberOfSystemPtes / 3;
     else
         SpecialPoolPtes = MmNumberOfSystemPtes / 6;
@@ -119,7 +118,7 @@ MiInitializeSpecialPool()
 
         /* Reserving didn't work, so try to reduce the requested size */
         ASSERT(SpecialPoolPtes >= PTE_PER_PAGE);
-        SpecialPoolPtes -= 1024;
+        SpecialPoolPtes -= PTE_PER_PAGE;
     } while (SpecialPoolPtes);
 
     /* Fail if we couldn't reserve them at all */
@@ -132,7 +131,7 @@ MiInitializeSpecialPool()
     MiSpecialPoolFirstPte = PointerPte;
     MmSpecialPoolStart = MiPteToAddress(PointerPte);
 
-    for (i = 0; i<512; i++)
+    for (i = 0; i < PTE_PER_PAGE / 2; i++)
     {
         /* Point it to the next entry */
         PointerPte->u.List.NextEntry = &PointerPte[2] - MmSystemPteBase;
@@ -143,7 +142,7 @@ MiInitializeSpecialPool()
 
     /* Save extra values */
     MiSpecialPoolExtra = PointerPte;
-    MiSpecialPoolExtraCount = SpecialPoolPtes - 1024;
+    MiSpecialPoolExtraCount = SpecialPoolPtes - PTE_PER_PAGE;
 
     /* Mark the previous PTE as the last one */
     MiSpecialPoolLastPte = PointerPte - 2;
@@ -160,10 +159,52 @@ MiInitializeSpecialPool()
         MiSpecialPagesNonPagedMaximum = MmResidentAvailablePages >> 3;
 
     DPRINT1("Special pool start %p - end %p\n", MmSpecialPoolStart, MmSpecialPoolEnd);
+    ExpPoolFlags |= POOL_FLAG_SPECIAL_POOL;
 
     //MiTestSpecialPool();
 }
 
+NTSTATUS
+NTAPI
+MmExpandSpecialPool(VOID)
+{
+    ULONG i;
+    PMMPTE PointerPte;
+
+    ASSERT(KeGetCurrentIrql() == DISPATCH_LEVEL);
+
+    if (MiSpecialPoolExtraCount == 0)
+        return STATUS_INSUFFICIENT_RESOURCES;
+
+    PointerPte = MiSpecialPoolExtra;
+    ASSERT(MiSpecialPoolFirstPte == MiSpecialPoolLastPte);
+    ASSERT(MiSpecialPoolFirstPte->u.List.NextEntry == MM_EMPTY_PTE_LIST);
+    MiSpecialPoolFirstPte->u.List.NextEntry = PointerPte - MmSystemPteBase;
+
+    ASSERT(MiSpecialPoolExtraCount >= PTE_PER_PAGE);
+    for (i = 0; i < PTE_PER_PAGE / 2; i++)
+    {
+        /* Point it to the next entry */
+        PointerPte->u.List.NextEntry = &PointerPte[2] - MmSystemPteBase;
+
+        /* Move to the next pair */
+        PointerPte += 2;
+    }
+
+    /* Save remaining extra values */
+    MiSpecialPoolExtra = PointerPte;
+    MiSpecialPoolExtraCount -= PTE_PER_PAGE;
+
+    /* Mark the previous PTE as the last one */
+    MiSpecialPoolLastPte = PointerPte - 2;
+    MiSpecialPoolLastPte->u.List.NextEntry = MM_EMPTY_PTE_LIST;
+
+    /* Save new end address of the special pool */
+    MmSpecialPoolEnd = MiPteToAddress(MiSpecialPoolLastPte + 1);
+
+    return STATUS_SUCCESS;
+}
+
 PVOID
 NTAPI
 MmAllocateSpecialPool(SIZE_T NumberOfBytes, ULONG Tag, POOL_TYPE PoolType, ULONG SpecialType)
@@ -176,7 +217,7 @@ MmAllocateSpecialPool(SIZE_T NumberOfBytes, ULONG Tag, POOL_TYPE PoolType, ULONG
     PVOID Entry;
     PPOOL_HEADER Header;
 
-    DPRINT1("MmAllocateSpecialPool(%x %x %x %x)\n", NumberOfBytes, Tag, PoolType, SpecialType);
+    DPRINT("MmAllocateSpecialPool(%x %x %x %x)\n", NumberOfBytes, Tag, PoolType, SpecialType);
 
     /* Check if the pool is initialized and quit if it's not */
     if (!MiSpecialPoolFirstPte) return NULL;
@@ -191,7 +232,11 @@ MmAllocateSpecialPool(SIZE_T NumberOfBytes, ULONG Tag, POOL_TYPE PoolType, ULONG
         ((PoolType != PagedPool) && (Irql > DISPATCH_LEVEL)))
     {
         /* Bad caller */
-        KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION, Irql, PoolType, NumberOfBytes, 0x30);
+        KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                     Irql,
+                     PoolType,
+                     NumberOfBytes,
+                     0x30);
     }
 
     /* TODO: Take into account various limitations */
@@ -210,14 +255,19 @@ MmAllocateSpecialPool(SIZE_T NumberOfBytes, ULONG Tag, POOL_TYPE PoolType, ULONG
         return NULL;
     }
 
-    /* Reject allocation if special pool PTE list is exhausted */
+    /* Check if special pool PTE list is exhausted */
     if (MiSpecialPoolFirstPte->u.List.NextEntry == MM_EMPTY_PTE_LIST)
     {
-        /* Release the PFN database lock */
-        KeReleaseQueuedSpinLock(LockQueuePfnLock, Irql);
-        DPRINT1("Special pool: No PTEs left!\n");
-        /* TODO: Expand the special pool */
-        return NULL;
+        /* Try to expand it */
+        if (!NT_SUCCESS(MmExpandSpecialPool()))
+        {
+            /* No reserves left, reject this allocation */
+            static int once;
+            KeReleaseQueuedSpinLock(LockQueuePfnLock, Irql);
+            if (!once++) DPRINT1("Special pool: No PTEs left!\n");
+            return NULL;
+        }
+        ASSERT(MiSpecialPoolFirstPte->u.List.NextEntry != MM_EMPTY_PTE_LIST);
     }
 
     /* Save allocation time */
@@ -286,8 +336,8 @@ MmAllocateSpecialPool(SIZE_T NumberOfBytes, ULONG Tag, POOL_TYPE PoolType, ULONG
        That time will be used to check memory consistency within the allocated
        page. */
     Header->PoolTag = Tag;
-    Header->BlockSize = (USHORT)TickCount.LowPart;
-    DPRINT1("%p\n", Entry);
+    Header->BlockSize = (UCHAR)TickCount.LowPart;
+    DPRINT("%p\n", Entry);
     return Entry;
 }
 
@@ -300,6 +350,7 @@ MiSpecialPoolCheckPattern(PUCHAR P, PPOOL_HEADER Header)
 
     /* Get amount of bytes user requested to be allocated by clearing out the paged mask */
     BytesRequested = (Header->Ulong1 & ~SPECIAL_POOL_PAGED) & 0xFFFF;
+    ASSERT(BytesRequested <= PAGE_SIZE - sizeof(POOL_HEADER));
 
     /* Get a pointer to the end of user's area */
     Ptr = P + BytesRequested;
@@ -320,7 +371,11 @@ MiSpecialPoolCheckPattern(PUCHAR P, PPOOL_HEADER Header)
         /* Bugcheck if bytes don't match */
         if (Ptr[Index] != Header->BlockSize)
         {
-            KeBugCheckEx(BAD_POOL_HEADER, (ULONG_PTR)P, (ULONG_PTR)&Ptr[Index], Header->BlockSize, 0x24);
+            KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                         (ULONG_PTR)P,
+                         (ULONG_PTR)&Ptr[Index],
+                         Header->BlockSize,
+                         0x24);
         }
     }
 }
@@ -341,7 +396,7 @@ MmFreeSpecialPool(PVOID P)
     LARGE_INTEGER TickCount;
     PMMPFN Pfn;
 
-    DPRINT1("MmFreeSpecialPool(%p)\n", P);
+    DPRINT("MmFreeSpecialPool(%p)\n", P);
 
     /* Get the PTE */
     PointerPte = MiAddressToPte(P);
@@ -353,7 +408,11 @@ MmFreeSpecialPool(PVOID P)
         if (PointerPte->u.Soft.Protection == MM_NOACCESS ||
             !PointerPte->u.Soft.Protection)
         {
-            KeBugCheckEx(BAD_POOL_HEADER, (ULONG_PTR)P, (ULONG_PTR)PointerPte, 0, 0x20);
+            KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                         (ULONG_PTR)P,
+                         (ULONG_PTR)PointerPte,
+                         0,
+                         0x20);
         }
     }
 
@@ -375,28 +434,35 @@ MmFreeSpecialPool(PVOID P)
     if ((Header->Ulong1 & SPECIAL_POOL_PAGED) == 0)
     {
         /* Non-paged allocation, ensure that IRQ is not higher that DISPATCH */
-        ASSERT((PointerPte + 1)->u.Soft.PageFileHigh == SPECIAL_POOL_NONPAGED_PTE);
+        PoolType = NonPagedPool;
+        ASSERT(PointerPte[1].u.Soft.PageFileHigh == SPECIAL_POOL_NONPAGED_PTE);
         if (Irql > DISPATCH_LEVEL)
         {
-            KeBugCheckEx(BAD_POOL_HEADER, Irql, (ULONG_PTR)P, 0, 0x31);
+            KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                         Irql,
+                         PoolType,
+                         (ULONG_PTR)P,
+                         0x31);
         }
-
-        PoolType = NonPagedPool;
     }
     else
     {
         /* Paged allocation, ensure */
-        ASSERT((PointerPte + 1)->u.Soft.PageFileHigh == SPECIAL_POOL_PAGED_PTE);
-        if (Irql > DISPATCH_LEVEL)
+        PoolType = PagedPool;
+        ASSERT(PointerPte[1].u.Soft.PageFileHigh == SPECIAL_POOL_PAGED_PTE);
+        if (Irql > APC_LEVEL)
         {
-            KeBugCheckEx(BAD_POOL_HEADER, Irql, (ULONG_PTR)P, 1, 0x31);
+            KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                         Irql,
+                         PoolType,
+                         (ULONG_PTR)P,
+                         0x31);
         }
-
-        PoolType = PagedPool;
     }
 
     /* Get amount of bytes user requested to be allocated by clearing out the paged mask */
     BytesRequested = (Header->Ulong1 & ~SPECIAL_POOL_PAGED) & 0xFFFF;
+    ASSERT(BytesRequested <= PAGE_SIZE - sizeof(POOL_HEADER));
 
     /* Check memory before the allocated user buffer in case of overruns detection */
     if (Overruns)
@@ -407,21 +473,33 @@ MmFreeSpecialPool(PVOID P)
         /* If they mismatch, it's unrecoverable */
         if (BytesRequested > BytesReal)
         {
-            KeBugCheckEx(BAD_POOL_HEADER, (ULONG_PTR)P, BytesRequested, BytesReal, 0x21);
+            KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                         (ULONG_PTR)P,
+                         BytesRequested,
+                         BytesReal,
+                         0x21);
         }
 
         if (BytesRequested + sizeof(POOL_HEADER) < BytesReal)
         {
-            KeBugCheckEx(BAD_POOL_HEADER, (ULONG_PTR)P, BytesRequested, BytesReal, 0x22);
+            KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                         (ULONG_PTR)P,
+                         BytesRequested,
+                         BytesReal,
+                         0x22);
         }
 
         /* Actually check the memory pattern */
         for (b = (PUCHAR)(Header + 1); b < (PUCHAR)P; b++)
         {
-            if (Header->BlockSize != b[0])
+            if (*b != Header->BlockSize)
             {
                 /* Bytes mismatch */
-                KeBugCheckEx(BAD_POOL_HEADER, (ULONG_PTR)P, (ULONG_PTR)b, Header->BlockSize, 0x23);
+                KeBugCheckEx(SPECIAL_POOL_DETECTED_MEMORY_CORRUPTION,
+                             (ULONG_PTR)P,
+                             (ULONG_PTR)b,
+                             Header->BlockSize,
+                             0x23);
             }
         }
     }
@@ -448,7 +526,6 @@ MmFreeSpecialPool(PVOID P)
         Pfn = MI_PFN_ELEMENT(PointerPte->u.Hard.PageFrameNumber);
 
         /* Lock PFN database */
-        ASSERT(KeGetCurrentIrql() <= DISPATCH_LEVEL);
         Irql = KeAcquireQueuedSpinLock(LockQueuePfnLock);
 
         /* Delete this PFN */
@@ -457,6 +534,8 @@ MmFreeSpecialPool(PVOID P)
         /* Decrement share count of this PFN */
         MiDecrementShareCount(Pfn, PointerPte->u.Hard.PageFrameNumber);
 
+        MI_ERASE_PTE(PointerPte);
+
         /* Flush the TLB */
         //FIXME: Use KeFlushSingleTb() instead
         KeFlushEntireTb(TRUE, TRUE);
@@ -467,12 +546,11 @@ MmFreeSpecialPool(PVOID P)
         MiDeleteSystemPageableVm(PointerPte, 1, 0, NULL);
 
         /* Lock PFN database */
-        ASSERT(KeGetCurrentIrql() <= DISPATCH_LEVEL);
         Irql = KeAcquireQueuedSpinLock(LockQueuePfnLock);
     }
 
     /* Mark next PTE as invalid */
-    PointerPte[1].u.Long = 0; //|= 8000;
+    MI_ERASE_PTE(PointerPte + 1);
 
     /* Make sure that the last entry is really the last one */
     ASSERT(MiSpecialPoolLastPte->u.List.NextEntry == MM_EMPTY_PTE_LIST);