convert DefaultSetInfoBufferCheck and DefaultQueryInfoBufferCheck to inlined functions
[reactos.git] / reactos / ntoskrnl / se / token.c
index 843da88..009fa98 100644 (file)
@@ -68,18 +68,23 @@ static const INFORMATION_CLASS_INFO SeTokenInformationClass[] = {
 
 /* FUNCTIONS *****************************************************************/
 
-VOID SepFreeProxyData(PVOID ProxyData)
+VOID
+NTAPI
+SepFreeProxyData(PVOID ProxyData)
 {
    UNIMPLEMENTED;
 }
 
-NTSTATUS SepCopyProxyData(PVOID* Dest, PVOID Src)
+NTSTATUS
+NTAPI
+SepCopyProxyData(PVOID* Dest, PVOID Src)
 {
    UNIMPLEMENTED;
    return(STATUS_NOT_IMPLEMENTED);
 }
 
 NTSTATUS
+NTAPI
 SeExchangePrimaryToken(PEPROCESS Process,
                        PACCESS_TOKEN NewTokenP,
                        PACCESS_TOKEN* OldTokenP)
@@ -109,6 +114,7 @@ SeExchangePrimaryToken(PEPROCESS Process,
 }
 
 VOID
+NTAPI
 SeDeassignPrimaryToken(PEPROCESS Process)
 {
     PTOKEN OldToken;
@@ -138,6 +144,7 @@ RtlLengthSidAndAttributes(ULONG Count,
 
 
 NTSTATUS
+NTAPI
 SepFindPrimaryGroupAndDefaultOwner(PTOKEN Token,
                                   PSID PrimaryGroup,
                                   PSID DefaultOwner)
@@ -557,7 +564,9 @@ SepDeleteToken(PVOID ObjectBody)
 }
 
 
-VOID INIT_FUNCTION
+VOID
+INIT_FUNCTION
+NTAPI
 SepInitializeTokenImplementation(VOID)
 {
     UNICODE_STRING Name;
@@ -565,15 +574,15 @@ SepInitializeTokenImplementation(VOID)
     
     ExInitializeResource(&SepTokenLock);
     
-    DPRINT1("Creating Token Object Type\n");
+    DPRINT("Creating Token Object Type\n");
   
     /*  Initialize the Token type  */
     RtlZeroMemory(&ObjectTypeInitializer, sizeof(ObjectTypeInitializer));
     RtlInitUnicodeString(&Name, L"Token");
     ObjectTypeInitializer.Length = sizeof(ObjectTypeInitializer);
-    ObjectTypeInitializer.DefaultNonPagedPoolCharge = sizeof(TOKEN);
+    ObjectTypeInitializer.DefaultPagedPoolCharge = sizeof(TOKEN);
     ObjectTypeInitializer.GenericMapping = SepTokenMapping;
-    ObjectTypeInitializer.PoolType = NonPagedPool;
+    ObjectTypeInitializer.PoolType = PagedPool;
     ObjectTypeInitializer.ValidAccessMask = TOKEN_ALL_ACCESS;
     ObjectTypeInitializer.UseDefaultObject = TRUE;
     ObjectTypeInitializer.DeleteProcedure = SepDeleteToken;
@@ -606,13 +615,13 @@ NtQueryInformationToken(IN HANDLE TokenHandle,
   PreviousMode = ExGetPreviousMode();
 
   /* Check buffers and class validity */
-  DefaultQueryInfoBufferCheck(TokenInformationClass,
-                              SeTokenInformationClass,
-                              TokenInformation,
-                              TokenInformationLength,
-                              ReturnLength,
-                              PreviousMode,
-                              &Status);
+  Status = DefaultQueryInfoBufferCheck(TokenInformationClass,
+                                       SeTokenInformationClass,
+                                       sizeof(SeTokenInformationClass) / sizeof(SeTokenInformationClass[0]),
+                                       TokenInformation,
+                                       TokenInformationLength,
+                                       ReturnLength,
+                                       PreviousMode);
 
   if(!NT_SUCCESS(Status))
   {
@@ -1047,9 +1056,49 @@ NtQueryInformationToken(IN HANDLE TokenHandle,
        break;
 
       case TokenRestrictedSids:
-       DPRINT1("NtQueryInformationToken(TokenRestrictedSids) not implemented\n");
-       Status = STATUS_NOT_IMPLEMENTED;
+      {
+        PTOKEN_GROUPS tg = (PTOKEN_GROUPS)TokenInformation;
+
+        DPRINT("NtQueryInformationToken(TokenRestrictedSids)\n");
+        RequiredLength = sizeof(tg->GroupCount) +
+                         RtlLengthSidAndAttributes(Token->RestrictedSidCount, Token->RestrictedSids);
+
+        _SEH_TRY
+        {
+          if(TokenInformationLength >= RequiredLength)
+          {
+            ULONG SidLen = RequiredLength - sizeof(tg->GroupCount) -
+                           (Token->RestrictedSidCount * sizeof(SID_AND_ATTRIBUTES));
+            PSID_AND_ATTRIBUTES Sid = (PSID_AND_ATTRIBUTES)((ULONG_PTR)TokenInformation + sizeof(tg->GroupCount) +
+                                                            (Token->RestrictedSidCount * sizeof(SID_AND_ATTRIBUTES)));
+
+            tg->GroupCount = Token->RestrictedSidCount;
+            Status = RtlCopySidAndAttributesArray(Token->RestrictedSidCount,
+                                                  Token->RestrictedSids,
+                                                  SidLen,
+                                                  &tg->Groups[0],
+                                                  (PSID)Sid,
+                                                  &Unused.Ptr,
+                                                  &Unused.Ulong);
+          }
+          else
+          {
+            Status = STATUS_BUFFER_TOO_SMALL;
+          }
+
+          if(ReturnLength != NULL)
+          {
+            *ReturnLength = RequiredLength;
+          }
+        }
+        _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+        _SEH_END;
+
        break;
+      }
 
       case TokenSandBoxInert:
        DPRINT1("NtQueryInformationToken(TokenSandboxInert) not implemented\n");
@@ -1149,12 +1198,12 @@ NtSetInformationToken(IN HANDLE TokenHandle,
 
   PreviousMode = ExGetPreviousMode();
 
-  DefaultSetInfoBufferCheck(TokenInformationClass,
-                            SeTokenInformationClass,
-                            TokenInformation,
-                            TokenInformationLength,
-                            PreviousMode,
-                            &Status);
+  Status = DefaultSetInfoBufferCheck(TokenInformationClass,
+                                     SeTokenInformationClass,
+                                     sizeof(SeTokenInformationClass) / sizeof(SeTokenInformationClass[0]),
+                                     TokenInformation,
+                                     TokenInformationLength,
+                                     PreviousMode);
 
   if(!NT_SUCCESS(Status))
   {
@@ -1400,9 +1449,7 @@ NtDuplicateToken(IN HANDLE ExistingTokenHandle,
   {
     _SEH_TRY
     {
-      ProbeForWrite(NewTokenHandle,
-                    sizeof(HANDLE),
-                    sizeof(ULONG));
+      ProbeForWriteHandle(NewTokenHandle);
     }
     _SEH_HANDLE
     {
@@ -1780,17 +1827,19 @@ SepCreateSystemProcessToken(VOID)
   NTSTATUS Status;
   ULONG uSize;
   ULONG i;
+  ULONG uLocalSystemLength;
+  ULONG uWorldLength;
+  ULONG uAuthUserLength;
+  ULONG uAdminsLength;
+  PTOKEN AccessToken;
+  PVOID SidArea;
 
   PAGED_CODE();
 
-  ULONG uLocalSystemLength = RtlLengthSid(SeLocalSystemSid);
-  ULONG uWorldLength       = RtlLengthSid(SeWorldSid);
-  ULONG uAuthUserLength    = RtlLengthSid(SeAuthenticatedUserSid);
-  ULONG uAdminsLength      = RtlLengthSid(SeAliasAdminsSid);
-
-  PTOKEN AccessToken;
-
-  PVOID SidArea;
+  uLocalSystemLength = RtlLengthSid(SeLocalSystemSid);
+  uWorldLength       = RtlLengthSid(SeWorldSid);
+  uAuthUserLength    = RtlLengthSid(SeAuthenticatedUserSid);
+  uAdminsLength      = RtlLengthSid(SeAliasAdminsSid);
 
  /*
   * Initialize the token
@@ -1808,6 +1857,12 @@ SepCreateSystemProcessToken(VOID)
     {
       return NULL;
     }
+  Status = ObInsertObject(AccessToken,
+                          NULL,
+                          TOKEN_ALL_ACCESS,
+                          0,
+                          NULL,
+                          NULL);
 
   Status = ExpAllocateLocallyUniqueId(&AccessToken->TokenId);
   if (!NT_SUCCESS(Status))
@@ -1847,7 +1902,7 @@ SepCreateSystemProcessToken(VOID)
   uSize += uAdminsLength;
 
   AccessToken->UserAndGroups =
-    (PSID_AND_ATTRIBUTES)ExAllocatePoolWithTag(NonPagedPool,
+    (PSID_AND_ATTRIBUTES)ExAllocatePoolWithTag(PagedPool,
                                               uSize,
                                               TAG('T', 'O', 'K', 'u'));
   SidArea = &AccessToken->UserAndGroups[AccessToken->UserAndGroupCount];
@@ -1879,7 +1934,7 @@ SepCreateSystemProcessToken(VOID)
 
   uSize = AccessToken->PrivilegeCount * sizeof(LUID_AND_ATTRIBUTES);
   AccessToken->Privileges =
-       (PLUID_AND_ATTRIBUTES)ExAllocatePoolWithTag(NonPagedPool,
+       (PLUID_AND_ATTRIBUTES)ExAllocatePoolWithTag(PagedPool,
                                                    uSize,
                                                    TAG('T', 'O', 'K', 'p'));
 
@@ -1958,7 +2013,7 @@ SepCreateSystemProcessToken(VOID)
   uSize += sizeof(ACE) + uAdminsLength;
   uSize = (uSize & (~3)) + 8;
   AccessToken->DefaultDacl =
-    (PACL) ExAllocatePoolWithTag(NonPagedPool,
+    (PACL) ExAllocatePoolWithTag(PagedPool,
                                 uSize,
                                 TAG('T', 'O', 'K', 'd'));
   Status = RtlCreateAcl(AccessToken->DefaultDacl, uSize, ACL_REVISION);
@@ -2004,6 +2059,8 @@ NtCreateToken(OUT PHANDLE TokenHandle,
   PVOID EndMem;
   ULONG uLength;
   ULONG i;
+  ULONG nTokenPrivileges = 0;
+  LARGE_INTEGER LocalExpirationTime = {};
   KPROCESSOR_MODE PreviousMode;
   NTSTATUS Status = STATUS_SUCCESS;
 
@@ -2015,15 +2072,11 @@ NtCreateToken(OUT PHANDLE TokenHandle,
   {
     _SEH_TRY
     {
-      ProbeForWrite(TokenHandle,
-                    sizeof(HANDLE),
-                    sizeof(ULONG));
+      ProbeForWriteHandle(TokenHandle);
       ProbeForRead(AuthenticationId,
                    sizeof(LUID),
                    sizeof(ULONG));
-      ProbeForRead(ExpirationTime,
-                   sizeof(LARGE_INTEGER),
-                   sizeof(ULONG));
+      LocalExpirationTime = ProbeForReadLargeInteger(ExpirationTime);
       ProbeForRead(TokenUser,
                    sizeof(TOKEN_USER),
                    sizeof(ULONG));
@@ -2045,6 +2098,7 @@ NtCreateToken(OUT PHANDLE TokenHandle,
       ProbeForRead(TokenSource,
                    sizeof(TOKEN_SOURCE),
                    sizeof(ULONG));
+      nTokenPrivileges = TokenPrivileges->PrivilegeCount;
     }
     _SEH_HANDLE
     {
@@ -2057,6 +2111,11 @@ NtCreateToken(OUT PHANDLE TokenHandle,
       return Status;
     }
   }
+  else
+  {
+    nTokenPrivileges = TokenPrivileges->PrivilegeCount;
+    LocalExpirationTime = *ExpirationTime;
+  }
 
   Status = ZwAllocateLocallyUniqueId(&TokenId);
   if (!NT_SUCCESS(Status))
@@ -2155,14 +2214,26 @@ NtCreateToken(OUT PHANDLE TokenHandle,
                                                    uLength,
                                                    TAG('T', 'O', 'K', 'p'));
 
-      for (i = 0; i < TokenPrivileges->PrivilegeCount; i++)
-       {
-         Status = MmCopyFromCaller(&AccessToken->Privileges[i],
-                                   &TokenPrivileges->Privileges[i],
-                                   sizeof(LUID_AND_ATTRIBUTES));
-         if (!NT_SUCCESS(Status))
-           break;
-       }
+      if (PreviousMode != KernelMode)
+        {
+          _SEH_TRY
+            {
+              RtlCopyMemory(AccessToken->Privileges,
+                            TokenPrivileges->Privileges,
+                            nTokenPrivileges * sizeof(LUID_AND_ATTRIBUTES));
+            }
+          _SEH_HANDLE
+            {
+              Status = _SEH_GetExceptionCode();
+            }
+          _SEH_END;
+        }
+      else
+        {
+          RtlCopyMemory(AccessToken->Privileges,
+                        TokenPrivileges->Privileges,
+                        nTokenPrivileges * sizeof(LUID_AND_ATTRIBUTES));
+        }
     }
 
   if (NT_SUCCESS(Status))
@@ -2317,9 +2388,7 @@ NtOpenThreadTokenEx(IN HANDLE ThreadHandle,
   {
     _SEH_TRY
     {
-      ProbeForWrite(TokenHandle,
-                    sizeof(HANDLE),
-                    sizeof(ULONG));
+      ProbeForWriteHandle(TokenHandle);
     }
     _SEH_HANDLE
     {