[NTDLL] Use LDR_LOCK_LOADER_LOCK_FLAG_RAISE_ON_ERRORS flag instead of TRUE
authorTimo Kreuzer <timo.kreuzer@reactos.org>
Wed, 23 Oct 2013 19:31:41 +0000 (19:31 +0000)
committerTimo Kreuzer <timo.kreuzer@reactos.org>
Wed, 23 Oct 2013 19:31:41 +0000 (19:31 +0000)
[RTL] Fix RtlAddAce (the version checks were inverted.)
[NTOSKRNL] Make SystemExtendServiceTableInformation case of NtSetSystemInformation behave more like Windows
[NTOSKRNL] Fix PsGetCurrentThreadWin32ThreadAndEnterCriticalRegion

svn path=/trunk/; revision=60737

reactos/dll/ntdll/ldr/ldrapi.c
reactos/lib/rtl/acl.c
reactos/lib/rtl/memstream.c
reactos/ntoskrnl/ex/sysinfo.c
reactos/ntoskrnl/ps/thread.c

index 8c1db23..5107649 100644 (file)
@@ -1610,7 +1610,7 @@ LdrUnloadAlternateResourceModule(IN PVOID BaseAddress)
     ULONG_PTR Cookie;
 
     /* Acquire the loader lock */
-    LdrLockLoaderLock(TRUE, NULL, &Cookie);
+    LdrLockLoaderLock(LDR_LOCK_LOADER_LOCK_FLAG_RAISE_ON_ERRORS, NULL, &Cookie);
 
     /* Check if there's any alternate resources loaded */
     if (AlternateResourceModuleCount)
index 819bf4e..92347ef 100644 (file)
@@ -579,22 +579,23 @@ RtlAddAce(IN PACL Acl,
     /* Bail out if there's no space */
     if (!RtlFirstFreeAce(Acl, &FreeAce)) return STATUS_INVALID_PARAMETER;
 
-    /* Always use the smaller revision */
-    if (Acl->AclRevision <= AclRevision) AclRevision = Acl->AclRevision;
-
     /* Loop over all the ACEs, keeping track of new ACEs as we go along */
     for (Ace = AceList, NewAceCount = 0;
          Ace < (PACE)((ULONG_PTR)AceList + AceListLength);
          NewAceCount++)
     {
-        /* Make sure that the revision of this ACE is valid in this list */
-        if (Ace->Header.AceType <= ACCESS_MAX_MS_V3_ACE_TYPE)
-        {
-            if (AclRevision < ACL_REVISION3) return STATUS_INVALID_PARAMETER;
-        }
-        else if (Ace->Header.AceType <= ACCESS_MAX_MS_V4_ACE_TYPE)
+        /* Make sure that the revision of this ACE is valid in this list.
+           The initial check looks strange, but it is what Windows does. */
+        if (Ace->Header.AceType <= ACCESS_MAX_MS_ACE_TYPE)
         {
-            if (AclRevision < ACL_REVISION4) return STATUS_INVALID_PARAMETER;
+            if (Ace->Header.AceType > ACCESS_MAX_MS_V3_ACE_TYPE)
+            {
+                if (AclRevision < ACL_REVISION4) return STATUS_INVALID_PARAMETER;
+            }
+            else if (Ace->Header.AceType > ACCESS_MAX_MS_V2_ACE_TYPE)
+            {
+                if (AclRevision < ACL_REVISION3) return STATUS_INVALID_PARAMETER;
+            }
         }
 
         /* Move to the next ACE */
@@ -627,9 +628,9 @@ RtlAddAce(IN PACL Acl,
                 Ace,
                 (ULONG_PTR)FreeAce - (ULONG_PTR)Ace);
 
-    /* Fill out the header and return */
-    Acl->AceCount = Acl->AceCount + NewAceCount;
-    Acl->AclRevision = (UCHAR)AclRevision;
+    /* Update the header and return */
+    Acl->AceCount += NewAceCount;
+    Acl->AclRevision = (UCHAR)min(Acl->AclRevision, AclRevision);
     return STATUS_SUCCESS;
 }
 
@@ -846,21 +847,21 @@ RtlValidAcl(IN PACL Acl)
             (Acl->AclRevision > MAX_ACL_REVISION))
         {
             DPRINT1("Invalid ACL revision\n");
-            _SEH2_YIELD(return FALSE);
+            return FALSE;
         }
 
         /* Next, validate that the ACL is USHORT-aligned */
         if (ROUND_DOWN(Acl->AclSize, sizeof(USHORT)) != Acl->AclSize)
         {
             DPRINT1("Invalid ACL size\n");
-            _SEH2_YIELD(return FALSE);
+            return FALSE;
         }
 
         /* And that it's big enough */
         if (Acl->AclSize < sizeof(ACL))
         {
             DPRINT1("Invalid ACL size\n");
-            _SEH2_YIELD(return FALSE);
+            return FALSE;
         }
 
         /* Loop each ACE */
@@ -871,21 +872,21 @@ RtlValidAcl(IN PACL Acl)
             if (((ULONG_PTR)Ace + sizeof(ACE_HEADER)) >= ((ULONG_PTR)Acl + Acl->AclSize))
             {
                 DPRINT1("Invalid ACE size\n");
-                _SEH2_YIELD(return FALSE);
+                return FALSE;
             }
 
             /* Validate the length of this ACE */
             if (ROUND_DOWN(Ace->AceSize, sizeof(USHORT)) != Ace->AceSize)
             {
                 DPRINT1("Invalid ACE size: %lx\n", Ace->AceSize);
-                _SEH2_YIELD(return FALSE);
+                return FALSE;
             }
 
             /* Validate we have space for the entire ACE */
             if (((ULONG_PTR)Ace + Ace->AceSize) > ((ULONG_PTR)Acl + Acl->AclSize))
             {
                 DPRINT1("Invalid ACE size %lx %lx\n", Ace->AceSize, Acl->AclSize);
-                _SEH2_YIELD(return FALSE);
+                return FALSE;
             }
 
             /* Check what kind of ACE this is */
@@ -895,14 +896,14 @@ RtlValidAcl(IN PACL Acl)
                 if (ROUND_DOWN(Ace->AceSize, sizeof(ULONG)) != Ace->AceSize)
                 {
                     DPRINT1("Invalid ACE size\n");
-                    _SEH2_YIELD(return FALSE);
+                    return FALSE;
                 }
 
                 /* The ACE size should at least have enough for the header */
                 if (Ace->AceSize < sizeof(ACE_HEADER))
                 {
                     DPRINT1("Invalid ACE size: %lx %lx\n", Ace->AceSize, sizeof(ACE_HEADER));
-                    _SEH2_YIELD(return FALSE);
+                    return FALSE;
                 }
 
                 /* Check if the SID revision is valid */
@@ -910,21 +911,21 @@ RtlValidAcl(IN PACL Acl)
                 if (Sid->Revision != SID_REVISION)
                 {
                     DPRINT1("Invalid SID\n");
-                    _SEH2_YIELD(return FALSE);
+                    return FALSE;
                 }
 
                 /* Check if the SID is out of bounds */
                 if (Sid->SubAuthorityCount > SID_MAX_SUB_AUTHORITIES)
                 {
                     DPRINT1("Invalid SID\n");
-                    _SEH2_YIELD(return FALSE);
+                    return FALSE;
                 }
 
                 /* The ACE size should at least have enough for the header and SID */
                 if (Ace->AceSize < (sizeof(ACE_HEADER) + RtlLengthSid(Sid)))
                 {
                     DPRINT1("Invalid ACE size\n");
-                    _SEH2_YIELD(return FALSE);
+                    return FALSE;
                 }
             }
             else if (Ace->AceType == ACCESS_ALLOWED_COMPOUND_ACE_TYPE)
@@ -942,7 +943,7 @@ RtlValidAcl(IN PACL Acl)
                 if (Ace->AceSize < sizeof(ACE_HEADER))
                 {
                     DPRINT1("Unknown ACE\n");
-                    _SEH2_YIELD(return FALSE);
+                    return FALSE;
                 }
             }
 
@@ -953,7 +954,7 @@ RtlValidAcl(IN PACL Acl)
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
         /* Something was invalid, fail */
-        _SEH2_YIELD(return FALSE);
+        return FALSE;
     }
     _SEH2_END;
 
index 776fbc6..b94460d 100644 (file)
@@ -303,7 +303,7 @@ RtlCopyMemoryStreamTo(
     TotalSize = Length.QuadPart;
     while (TotalSize)
     {
-        Left = min(TotalSize, sizeof(Buffer));
+        Left = (ULONG)min(TotalSize, sizeof(Buffer));
 
         /* Read */
         Result = IStream_Read(This, Buffer, Left, &Amount);
index fb6b77d..605e3b1 100644 (file)
@@ -1612,28 +1612,51 @@ SSI_DEF(SystemExtendServiceTableInformation)
     /* Check who is calling */
     if (PreviousMode != KernelMode)
     {
+        static const UNICODE_STRING Win32kName = 
+            RTL_CONSTANT_STRING(L"\\SystemRoot\\System32\\win32k.sys");
+
         /* Make sure we can load drivers */
         if (!SeSinglePrivilegeCheck(SeLoadDriverPrivilege, UserMode))
         {
             /* FIXME: We can't, fail */
-            //return STATUS_PRIVILEGE_NOT_HELD;
+            return STATUS_PRIVILEGE_NOT_HELD;
         }
-    }
+        
+        _SEH2_TRY
+        {
+            /* Probe and copy the unicode string */
+            ProbeForRead(Buffer, sizeof(ImageName), 1);
+            ImageName = *(PUNICODE_STRING)Buffer;
 
-    /* Probe and capture the driver name */
-    ProbeAndCaptureUnicodeString(&ImageName, PreviousMode, Buffer);
+            /* Probe the string buffer */
+            ProbeForRead(ImageName.Buffer, ImageName.Length, sizeof(WCHAR));
+
+            /* Check if we have the correct name (nothing else is allowed!) */
+            if (!RtlEqualUnicodeString(&ImageName, &Win32kName, FALSE))
+            {
+                _SEH2_YIELD(return STATUS_PRIVILEGE_NOT_HELD);
+            }
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+        
+        /* Recursively call the function, so that we are from kernel mode */
+        return ZwSetSystemInformation(SystemExtendServiceTableInformation,
+                                      (PVOID)&Win32kName,
+                                      sizeof(Win32kName));
+    }
 
     /* Load the image */
-    Status = MmLoadSystemImage(&ImageName,
+    Status = MmLoadSystemImage((PUNICODE_STRING)Buffer,
                                NULL,
                                NULL,
                                0,
                                (PVOID)&ModuleObject,
                                &ImageBase);
 
-    /* Release String */
-    ReleaseCapturedUnicodeString(&ImageName, PreviousMode);
-
     if (!NT_SUCCESS(Status)) return Status;
 
     /* Get the headers */
@@ -1658,7 +1681,7 @@ SSI_DEF(SystemExtendServiceTableInformation)
     /* Call it */
     Status = (DriverInit)(&Win32k, NULL);
     ASSERT(KeGetCurrentIrql() == PASSIVE_LEVEL);
-
+__debugbreak();__debugbreak();
     /* Unload if we failed */
     if (!NT_SUCCESS(Status)) MmUnloadSystemImage(ModuleObject);
     return Status;
index 07d7108..3bc9469 100644 (file)
@@ -812,10 +812,22 @@ PsGetCurrentThreadWin32Thread(VOID)
  */
 PVOID
 NTAPI
-PsGetCurrentThreadWin32ThreadAndEnterCriticalRegion(VOID)
+PsGetCurrentThreadWin32ThreadAndEnterCriticalRegion(
+    _Out_ HANDLE* OutProcessId)
 {
+    PETHREAD CurrentThread;
+
+    /* Get the current thread */
+    CurrentThread = PsGetCurrentThread();
+
+    /* Return the process id */
+    *OutProcessId = CurrentThread->Cid.UniqueProcess;
+
+    /* Enter critical region */
     KeEnterCriticalRegion();
-    return PsGetCurrentThread()->Tcb.Win32Thread;
+
+    /* Return the win32 thread */
+    return CurrentThread->Tcb.Win32Thread;
 }
 
 /*