Use ULONG_PTR instead of ULONG when doing pointer arithmetics.
[reactos.git] / reactos / ntoskrnl / mm / marea.c
index 69ab18e..e2d8867 100644 (file)
@@ -1,6 +1,6 @@
 /*
  *  ReactOS kernel
- *  Copyright (C) 1998, 1999, 2000, 2001 ReactOS Team
+ *  Copyright (C) 1998-2002 ReactOS Team
  *
  *  This program is free software; you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
@@ -27,6 +27,7 @@
 
 /* INCLUDES *****************************************************************/
 
+#include <roscfg.h>
 #include <ddk/ntddk.h>
 #include <internal/mm.h>
 #include <internal/ps.h>
 
 #define TAG_MAREA   TAG('M', 'A', 'R', 'E')
 
+#ifdef DBG
+PVOID MiMemoryAreaBugCheckAddress = (PVOID) NULL;
+#endif /* DBG */
+
+/* Define to track memory area references */
+//#define TRACK_MEMORY_AREA_REFERENCES
+
 /* FUNCTIONS *****************************************************************/
 
 VOID MmDumpMemoryAreas(PLIST_ENTRY ListHead)
@@ -61,8 +69,94 @@ VOID MmDumpMemoryAreas(PLIST_ENTRY ListHead)
    DbgPrint("Finished MmDumpMemoryAreas()\n");
 }
 
-MEMORY_AREA* MmOpenMemoryAreaByAddress(PMADDRESS_SPACE AddressSpace,
-                                      PVOID Address)
+#ifdef DBG
+
+VOID
+MiValidateMemoryAreaPTEs(IN PMEMORY_AREA  MemoryArea)
+{
+       ULONG PteProtect;
+       ULONG i;
+
+       if (!MiInitialized)
+               return;
+
+       for (i = 0; i <= (MemoryArea->Length / PAGESIZE); i++)
+               {
+                       if (MmIsPagePresent(MemoryArea->Process, MemoryArea->BaseAddress + (i * PAGESIZE)))
+                               {
+                                       PteProtect = MmGetPageProtect(MemoryArea->Process, MemoryArea->BaseAddress + (i * PAGESIZE));
+                                       if (PteProtect != MemoryArea->Attributes)
+                                               {
+                                                       if (MmIsCopyOnWriteMemoryArea(MemoryArea))
+                                                               {
+                                                                       if ((PteProtect != PAGE_READONLY) && (PteProtect != PAGE_EXECUTE_READ))
+                                                                               {
+                                                                                       DPRINT1("COW memory area attributes 0x%.08x\n", MemoryArea->Attributes);
+                                                                                       DbgMmDumpProtection(MemoryArea->Attributes);
+                                                                                       DPRINT1("PTE attributes 0x%.08x\n", PteProtect);
+                                                                                       DbgMmDumpProtection(PteProtect);
+                                                                                       assertmsg(FALSE, ("PTE attributes and memory area protection are different. Area 0x%.08x\n",
+                                                                                               MemoryArea->BaseAddress));
+                                                                               }
+                                                               }
+                                                               else
+                                                               {
+                                                                       DPRINT1("Memory area attributes 0x%.08x\n", MemoryArea->Attributes);
+                                                                       DbgMmDumpProtection(MemoryArea->Attributes);
+                                                                       DPRINT1("PTE attributes 0x%.08x\n", PteProtect);
+                                                                       DbgMmDumpProtection(PteProtect);
+                                                                       assertmsg(FALSE, ("PTE attributes and memory area protection are different. Area 0x%.08x\n",
+                                                                               MemoryArea->BaseAddress));
+                                                               }
+                                               }
+                               }
+               }
+}
+
+
+VOID
+MiValidateMemoryArea(IN PMEMORY_AREA  MemoryArea)
+{
+  assertmsg(MemoryArea != NULL,
+   ("No memory area can exist at 0x%.08x\n", MemoryArea));
+
+  assertmsg(MemoryArea->Magic == TAG_MAREA,
+   ("Bad magic (0x%.08x) for memory area (0x%.08x). It should be 0x%.08x\n",
+     MemoryArea->Magic, MemoryArea, TAG_MAREA));
+
+       /* FIXME: Can cause page faults and deadlock on the address space lock */
+  //MiValidateMemoryAreaPTEs(MemoryArea);
+}
+
+#endif /* DBG */
+
+VOID
+MmApplyMemoryAreaProtection(IN PMEMORY_AREA  MemoryArea)
+{
+       ULONG i;
+       
+       if (!MiInitialized)
+               return;
+
+       for (i = 0; i <= (MemoryArea->Length / PAGESIZE); i++)
+               {
+                       if (MmIsPagePresent(MemoryArea->Process, MemoryArea->BaseAddress + (i * PAGESIZE)))
+                               {
+                                       MmSetPageProtect(MemoryArea->Process,
+                                               MemoryArea->BaseAddress + (i * PAGESIZE),
+                                               MemoryArea->Attributes);
+                               }
+               }
+}
+
+
+/*
+ * NOTE: If the memory area is found, then it is referenced. The caller must
+ *       call MmCloseMemoryArea() after use.
+ */
+PMEMORY_AREA
+MmOpenMemoryAreaByAddress(IN PMADDRESS_SPACE  AddressSpace,
+  IN PVOID  Address)
 {
    PLIST_ENTRY current_entry;
    MEMORY_AREA* current;
@@ -70,9 +164,9 @@ MEMORY_AREA* MmOpenMemoryAreaByAddress(PMADDRESS_SPACE AddressSpace,
 
    DPRINT("MmOpenMemoryAreaByAddress(AddressSpace %x, Address %x)\n",
           AddressSpace, Address);
-   
+       
 //   MmDumpMemoryAreas(&AddressSpace->MAreaListHead);
-   
+
    previous_entry = &AddressSpace->MAreaListHead;
    current_entry = AddressSpace->MAreaListHead.Flink;
    while (current_entry != &AddressSpace->MAreaListHead)
@@ -102,6 +196,7 @@ MEMORY_AREA* MmOpenMemoryAreaByAddress(PMADDRESS_SPACE AddressSpace,
            (current->BaseAddress + current->Length) > Address)
          {
             DPRINT("%s() = %x\n",__FUNCTION__,current);
+       MmReferenceMemoryArea(current);
             return(current);
          }
        if (current->BaseAddress > Address)
@@ -116,6 +211,10 @@ MEMORY_AREA* MmOpenMemoryAreaByAddress(PMADDRESS_SPACE AddressSpace,
    return(NULL);
 }
 
+/*
+ * NOTE: If the memory area is found, then it is referenced. The caller must
+ *       call MmCloseMemoryArea() after use.
+ */
 MEMORY_AREA* MmOpenMemoryAreaByRegion(PMADDRESS_SPACE AddressSpace, 
                                      PVOID Address,
                                      ULONG Length)
@@ -140,6 +239,7 @@ MEMORY_AREA* MmOpenMemoryAreaByRegion(PMADDRESS_SPACE AddressSpace,
          {
             DPRINT("Finished MmOpenMemoryAreaByRegion() = %x\n",
                    current);
+       MmReferenceMemoryArea(current);
             return(current);
          }
        Extent = (ULONG)current->BaseAddress + current->Length;
@@ -148,6 +248,7 @@ MEMORY_AREA* MmOpenMemoryAreaByRegion(PMADDRESS_SPACE AddressSpace,
          {
             DPRINT("Finished MmOpenMemoryAreaByRegion() = %x\n",
                    current);
+       MmReferenceMemoryArea(current);
             return(current);
          }
        if (current->BaseAddress <= Address &&
@@ -155,6 +256,7 @@ MEMORY_AREA* MmOpenMemoryAreaByRegion(PMADDRESS_SPACE AddressSpace,
          {
             DPRINT("Finished MmOpenMemoryAreaByRegion() = %x\n",
                    current);
+       MmReferenceMemoryArea(current);
             return(current);
          }
        if (current->BaseAddress >= (Address+Length))
@@ -168,6 +270,14 @@ MEMORY_AREA* MmOpenMemoryAreaByRegion(PMADDRESS_SPACE AddressSpace,
    return(NULL);
 }
 
+
+VOID
+MmCloseMemoryArea(IN PMEMORY_AREA  MemoryArea)
+{
+  MmDereferenceMemoryArea(MemoryArea);
+}
+
+
 static VOID MmInsertMemoryArea(PMADDRESS_SPACE AddressSpace,
                               MEMORY_AREA* marea)
 {
@@ -176,7 +286,7 @@ static VOID MmInsertMemoryArea(PMADDRESS_SPACE AddressSpace,
    PLIST_ENTRY inserted_entry = &marea->Entry;
    MEMORY_AREA* current;
    MEMORY_AREA* next;   
-   
+
    DPRINT("MmInsertMemoryArea(marea %x)\n", marea);
    DPRINT("marea->BaseAddress %x\n", marea->BaseAddress);
    DPRINT("marea->Length %x\n", marea->Length);
@@ -224,7 +334,7 @@ static VOID MmInsertMemoryArea(PMADDRESS_SPACE AddressSpace,
          }
        if (current->BaseAddress < marea->BaseAddress &&
            next->BaseAddress > marea->BaseAddress)
-         {          
+         {
             DPRINT("Inserting before %x\n", current_entry);
             inserted_entry->Flink = current_entry->Flink;
             inserted_entry->Blink = current_entry;
@@ -290,50 +400,75 @@ NTSTATUS MmInitMemoryAreas(VOID)
    return(STATUS_SUCCESS);
 }
 
-NTSTATUS 
-MmFreeMemoryArea(PMADDRESS_SPACE AddressSpace,
-                PVOID BaseAddress,
-                ULONG Length,
-                VOID (*FreePage)(PVOID Context, MEMORY_AREA* MemoryArea, PVOID Address, 
-                                 ULONG PhysAddr, SWAPENTRY SwapEntry, BOOLEAN Dirty),
-                PVOID FreePageContext)
+/* NOTE: The address space lock must be held when called */
+NTSTATUS
+MmFreeMemoryArea(IN PMADDRESS_SPACE  AddressSpace,
+       IN PVOID  BaseAddress,
+       IN ULONG  Length,
+       IN PFREE_MEMORY_AREA_PAGE_CALLBACK  FreePage,
+       IN PVOID FreePageContext)
 {
-   MEMORY_AREA* MemoryArea;
-   ULONG i;
-   
-   DPRINT("MmFreeMemoryArea(AddressSpace %x, BaseAddress %x, Length %x,"
-          "FreePageContext %d)\n",AddressSpace,BaseAddress,Length,FreePageContext);
-
-   MemoryArea = MmOpenMemoryAreaByAddress(AddressSpace,
-                                         BaseAddress);
-   if (MemoryArea == NULL)
-     {       
-       KeBugCheck(0);
-       return(STATUS_UNSUCCESSFUL);
-     }
+       MEMORY_AREA* MemoryArea;
+       ULONG i;
+       
+       DPRINT("MmFreeMemoryArea(AddressSpace %x, BaseAddress %x, Length %x, "
+         "FreePageContext %d)\n",AddressSpace,BaseAddress,Length,FreePageContext);
+       
+  MemoryArea = MmOpenMemoryAreaByAddress(AddressSpace, BaseAddress);
+  if (MemoryArea == NULL)
+    {
+      assertmsg(FALSE, ("Freeing non-existant memory area at 0x%.08x\n", BaseAddress));
+      return(STATUS_UNSUCCESSFUL);
+    }
+
+  MmCloseMemoryArea(MemoryArea);
+  InterlockedDecrement(&MemoryArea->ReferenceCount);
+#if 0
+       assertmsg(MemoryArea->ReferenceCount == 0,
+    ("Memory area at address 0x%.08x has %d outstanding references\n",
+      BaseAddress, MemoryArea->ReferenceCount));
+#endif
    for (i=0; i<(PAGE_ROUND_UP(MemoryArea->Length)/PAGESIZE); i++)
      {
-       ULONG PhysAddr = 0;
-       BOOL Dirty;
+       ULONG_PTR PhysicalPage = 0;
+       BOOLEAN Dirty = FALSE;
        SWAPENTRY SwapEntry = 0;
+       PVOID VirtualPage = NULL;
 
-       if (MmIsPageSwapEntry(AddressSpace->Process,
-                            MemoryArea->BaseAddress + (i * PAGESIZE)))
+       VirtualPage = MemoryArea->BaseAddress + (i * PAGESIZE);
+
+#ifdef DBG
+       if ((MiMemoryAreaBugCheckAddress != NULL)
+         && ((MiMemoryAreaBugCheckAddress >= VirtualPage)
+         && MiMemoryAreaBugCheckAddress < VirtualPage + PAGESIZE))
+        {
+          assertmsg(FALSE, ("VirtualPage 0x%.08x  MiMemoryAreaBugCheckAddress 0x%.08x \n",
+            VirtualPage));
+        }
+#endif
+
+       if (FreePage != NULL)
+        {
+          FreePage(TRUE, FreePageContext, MemoryArea,
+                  VirtualPage, 0, 0, FALSE);
+        }
+
+       if (MmIsPageSwapEntry(AddressSpace->Process, VirtualPage))
         {
           MmDeletePageFileMapping(AddressSpace->Process,
-                                  MemoryArea->BaseAddress + (i * PAGESIZE),
+                                  VirtualPage,
                                   &SwapEntry);
         }
        else
         {
           MmDeleteVirtualMapping(AddressSpace->Process, 
-                                 MemoryArea->BaseAddress + (i*PAGESIZE),
-                                 FALSE, &Dirty, &PhysAddr);
+                                 VirtualPage,
+                                 FALSE, &Dirty, &PhysicalPage);
         }
        if (FreePage != NULL)
         {
-          FreePage(FreePageContext, MemoryArea,
-                   MemoryArea->BaseAddress + (i * PAGESIZE), PhysAddr, SwapEntry, Dirty);
+          FreePage(FALSE, FreePageContext, MemoryArea,
+                   VirtualPage, PhysicalPage, SwapEntry, Dirty);
         }
      }
    
@@ -356,7 +491,7 @@ PMEMORY_AREA MmSplitMemoryArea(PEPROCESS Process,
    PMEMORY_AREA Result;
    PMEMORY_AREA Split;
    
-   Result = ExAllocatePoolWithTag(NonPagedPool, sizeof(MEMORY_AREA), 
+   Result = ExAllocatePoolWithTag(NonPagedPool, sizeof(MEMORY_AREA),
                                  TAG_MAREA);
    RtlZeroMemory(Result,sizeof(MEMORY_AREA));
    Result->Type = NewType;
@@ -364,8 +499,9 @@ PMEMORY_AREA MmSplitMemoryArea(PEPROCESS Process,
    Result->Length = Length;
    Result->Attributes = NewAttributes;
    Result->LockCount = 0;
+   Result->ReferenceCount = 1;
    Result->Process = Process;
-   
+
    if (BaseAddress == OriginalMemoryArea->BaseAddress)
      {
        OriginalMemoryArea->BaseAddress = BaseAddress + Length;
@@ -381,7 +517,7 @@ PMEMORY_AREA MmSplitMemoryArea(PEPROCESS Process,
 
        return(Result);
      }
-      
+
    Split = ExAllocatePoolWithTag(NonPagedPool, sizeof(MEMORY_AREA),
                                 TAG_MAREA);
    RtlCopyMemory(Split,OriginalMemoryArea,sizeof(MEMORY_AREA));
@@ -394,14 +530,15 @@ PMEMORY_AREA MmSplitMemoryArea(PEPROCESS Process,
    return(Split);
 }
 
-NTSTATUS MmCreateMemoryArea(PEPROCESS Process,
-                           PMADDRESS_SPACE AddressSpace,
-                           ULONG Type,
-                           PVOID* BaseAddress,
-                           ULONG Length,
-                           ULONG Attributes,
-                           MEMORY_AREA** Result,
-                           BOOL FixedAddress)
+NTSTATUS
+MmCreateMemoryArea(IN PEPROCESS  Process,
+       IN PMADDRESS_SPACE  AddressSpace,
+       IN ULONG  Type,
+       IN OUT PVOID*  BaseAddress,
+       IN ULONG  Length,
+       IN ULONG  Attributes,
+       OUT PMEMORY_AREA*  Result,
+       IN BOOLEAN  FixedAddress)
 /*
  * FUNCTION: Create a memory area
  * ARGUMENTS:
@@ -411,10 +548,13 @@ NTSTATUS MmCreateMemoryArea(PEPROCESS Process,
  *     Length = Length to allocate
  *     Attributes = Protection attributes for the memory area
  *     Result = Receives a pointer to the memory area on exit
+ *     FixedAddress = Wether the memory area must be based at BaseAddress or not
  * RETURNS: Status
  * NOTES: Lock the address space before calling this function
  */
 {
+        PMEMORY_AREA  MemoryArea;
+       
    DPRINT("MmCreateMemoryArea(Type %d, BaseAddress %x,"
           "*BaseAddress %x, Length %x, Attributes %x, Result %x)\n",
           Type,BaseAddress,*BaseAddress,Length,Attributes,Result);
@@ -433,27 +573,90 @@ NTSTATUS MmCreateMemoryArea(PEPROCESS Process,
    else
      {
        (*BaseAddress) = (PVOID)PAGE_ROUND_DOWN((*BaseAddress));
-       if (MmOpenMemoryAreaByRegion(AddressSpace,
-                                    *BaseAddress,
-                                    Length)!=NULL)
+       MemoryArea = MmOpenMemoryAreaByRegion(AddressSpace, *BaseAddress, Length);
+       if (MemoryArea)
          {
+                        MmCloseMemoryArea(MemoryArea);
             DPRINT("Memory area already occupied\n");
             return(STATUS_CONFLICTING_ADDRESSES);
          }
      }
-   
+
+       DPRINT("MmCreateMemoryArea(*BaseAddress %x)\n", *BaseAddress);
+
    *Result = ExAllocatePoolWithTag(NonPagedPool, sizeof(MEMORY_AREA),
                                   TAG_MAREA);
    RtlZeroMemory(*Result,sizeof(MEMORY_AREA));
+   SET_MAGIC(*Result, TAG_MAREA)
    (*Result)->Type = Type;
    (*Result)->BaseAddress = *BaseAddress;
    (*Result)->Length = Length;
    (*Result)->Attributes = Attributes;
    (*Result)->LockCount = 0;
+   (*Result)->ReferenceCount = 1;
    (*Result)->Process = Process;
-   
+
+   MmApplyMemoryAreaProtection(*Result);
+        
    MmInsertMemoryArea(AddressSpace, *Result);
    
    DPRINT("MmCreateMemoryArea() succeeded\n");
    return(STATUS_SUCCESS);
 }
+
+#ifdef DBG
+
+VOID
+MiReferenceMemoryArea(IN PMEMORY_AREA  MemoryArea,
+  IN LPSTR  FileName,
+       IN ULONG  LineNumber)
+{
+  VALIDATE_MEMORY_AREA(MemoryArea);
+
+  InterlockedIncrement(&MemoryArea->ReferenceCount);
+
+#ifdef TRACK_MEMORY_AREA_REFERENCES
+       DbgPrint("(0x%.08x)(%s:%d) Referencing memory area 0x%.08x (New ref.count %d)\n",
+               KeGetCurrentThread(), FileName, LineNumber,
+               MemoryArea->BaseAddress,
+               MemoryArea->ReferenceCount);
+#endif /* TRACK_MEMORY_AREA_REFERENCES */
+}
+
+
+VOID
+MiDereferenceMemoryArea(IN PMEMORY_AREA  MemoryArea,
+  IN LPSTR  FileName,
+       IN ULONG  LineNumber)
+{
+  VALIDATE_MEMORY_AREA(MemoryArea);
+
+  InterlockedDecrement(&MemoryArea->ReferenceCount);
+
+#ifdef TRACK_MEMORY_AREA_REFERENCES
+       DbgPrint("(0x%.08x)(%s:%d) Dereferencing memory area 0x%.08x (New ref.count %d)\n",
+               KeGetCurrentThread(), FileName, LineNumber,
+               MemoryArea->BaseAddress,
+               MemoryArea->ReferenceCount);
+#endif /* TRACK_MEMORY_AREA_REFERENCES */
+
+  assertmsg(MemoryArea->ReferenceCount > 0,
+    ("No outstanding references on memory area (0x%.08x)\n", MemoryArea));
+}
+
+#else /* !DBG */
+
+VOID
+MiReferenceMemoryArea(IN PMEMORY_AREA  MemoryArea)
+{
+  InterlockedIncrement(&MemoryArea->ReferenceCount);
+}
+
+
+VOID
+MiDereferenceMemoryArea(IN PMEMORY_AREA  MemoryArea)
+{
+  InterlockedDecrement(&MemoryArea->ReferenceCount);
+}
+
+#endif /* !DBG */