convert DefaultSetInfoBufferCheck and DefaultQueryInfoBufferCheck to inlined functions
[reactos.git] / reactos / ntoskrnl / ps / query.c
index 5b4f855..9641831 100644 (file)
@@ -143,13 +143,13 @@ NtQueryInformationProcess(IN  HANDLE ProcessHandle,
 
    PreviousMode = ExGetPreviousMode();
 
-   DefaultQueryInfoBufferCheck(ProcessInformationClass,
-                               PsProcessInfoClass,
-                               ProcessInformation,
-                               ProcessInformationLength,
-                               ReturnLength,
-                               PreviousMode,
-                               &Status);
+   Status = DefaultQueryInfoBufferCheck(ProcessInformationClass,
+                                        PsProcessInfoClass,
+                                        sizeof(PsProcessInfoClass) / sizeof(PsProcessInfoClass[0]),
+                                        ProcessInformation,
+                                        ProcessInformationLength,
+                                        ReturnLength,
+                                        PreviousMode);
    if(!NT_SUCCESS(Status))
    {
      DPRINT1("NtQueryInformationProcess() failed, Status: 0x%x\n", Status);
@@ -190,9 +190,9 @@ NtQueryInformationProcess(IN  HANDLE ProcessHandle,
          ProcessBasicInformationP->PebBaseAddress = Process->Peb;
          ProcessBasicInformationP->AffinityMask = Process->Pcb.Affinity;
          ProcessBasicInformationP->UniqueProcessId =
-           Process->UniqueProcessId;
+           (ULONG)Process->UniqueProcessId;
          ProcessBasicInformationP->InheritedFromUniqueProcessId =
-           Process->InheritedFromUniqueProcessId;
+           (ULONG)Process->InheritedFromUniqueProcessId;
          ProcessBasicInformationP->BasePriority =
            Process->Pcb.BasePriority;
 
@@ -654,12 +654,12 @@ NtSetInformationProcess(IN HANDLE ProcessHandle,
 
    PreviousMode = ExGetPreviousMode();
 
-   DefaultSetInfoBufferCheck(ProcessInformationClass,
-                             PsProcessInfoClass,
-                             ProcessInformation,
-                             ProcessInformationLength,
-                             PreviousMode,
-                             &Status);
+   Status = DefaultSetInfoBufferCheck(ProcessInformationClass,
+                                      PsProcessInfoClass,
+                                      sizeof(PsProcessInfoClass) / sizeof(PsProcessInfoClass[0]),
+                                      ProcessInformation,
+                                      ProcessInformationLength,
+                                      PreviousMode);
    if(!NT_SUCCESS(Status))
    {
      DPRINT1("NtSetInformationProcess() %d %x  %x called\n", ProcessInformationClass, ProcessInformation, ProcessInformationLength);
@@ -1135,7 +1135,6 @@ NtSetInformationThread (IN HANDLE ThreadHandle,
                        IN ULONG ThreadInformationLength)
 {
   PETHREAD Thread;
-  NTSTATUS Status;
   union
   {
      KPRIORITY Priority;
@@ -1144,8 +1143,12 @@ NtSetInformationThread (IN HANDLE ThreadHandle,
      HANDLE Handle;
      PVOID Address;
   }u;
+  KPROCESSOR_MODE PreviousMode;
+  NTSTATUS Status = STATUS_SUCCESS;
 
   PAGED_CODE();
+  
+  PreviousMode = ExGetPreviousMode();
 
   if (ThreadInformationClass <= MaxThreadInfoClass &&
       !SetInformationData[ThreadInformationClass].Implemented)
@@ -1162,20 +1165,43 @@ NtSetInformationThread (IN HANDLE ThreadHandle,
       return STATUS_INFO_LENGTH_MISMATCH;
     }
 
+  if (PreviousMode != KernelMode)
+    {
+      _SEH_TRY
+        {
+          ProbeForRead(ThreadInformation,
+                       SetInformationData[ThreadInformationClass].Size,
+                       1);
+          RtlCopyMemory(&u.Priority,
+                        ThreadInformation,
+                        SetInformationData[ThreadInformationClass].Size);
+        }
+      _SEH_HANDLE
+        {
+          Status = _SEH_GetExceptionCode();
+        }
+      _SEH_END;
+      
+      if (!NT_SUCCESS(Status))
+        {
+          return Status;
+        }
+    }
+  else
+    {
+      RtlCopyMemory(&u.Priority,
+                    ThreadInformation,
+                    SetInformationData[ThreadInformationClass].Size);
+    }
+
+  /* FIXME: This is REALLY wrong. Some types don't need THREAD_SET_INFORMATION */
+  /* FIXME: We should also check for certain things before doing the reference */
   Status = ObReferenceObjectByHandle (ThreadHandle,
                                      THREAD_SET_INFORMATION,
                                      PsThreadType,
                                      ExGetPreviousMode (),
                                      (PVOID*)&Thread,
                                      NULL);
-   if (!NT_SUCCESS(Status))
-     {
-       return Status;
-     }
-
-   Status = MmCopyFromCaller(&u.Priority,
-                            ThreadInformation,
-                            SetInformationData[ThreadInformationClass].Size);
    if (NT_SUCCESS(Status))
      {
        switch (ThreadInformationClass)
@@ -1194,7 +1220,19 @@ NtSetInformationThread (IN HANDLE ThreadHandle,
             break;
 
            case ThreadAffinityMask:
-            Status = KeSetAffinityThread(&Thread->Tcb, u.Affinity);
+               
+               /* Check if this is valid */
+               DPRINT1("%lx, %lx\n", Thread->ThreadsProcess->Pcb.Affinity, u.Affinity);
+               if ((Thread->ThreadsProcess->Pcb.Affinity & u.Affinity) !=
+                   u.Affinity)
+               {
+                   DPRINT1("Wrong affinity given\n");
+                   Status = STATUS_INVALID_PARAMETER;
+               }
+               else
+               {
+                       Status = KeSetAffinityThread(&Thread->Tcb, u.Affinity);
+               }
             break;
 
            case ThreadImpersonationToken:
@@ -1209,9 +1247,9 @@ NtSetInformationThread (IN HANDLE ThreadHandle,
             /* Shoult never occure if the data table is correct */
             KEBUGCHECK(0);
         }
+       ObDereferenceObject (Thread);
      }
-  ObDereferenceObject (Thread);
-
+  
   return Status;
 }
 
@@ -1226,7 +1264,6 @@ NtQueryInformationThread (IN      HANDLE          ThreadHandle,
                          OUT   PULONG          ReturnLength  OPTIONAL)
 {
    PETHREAD Thread;
-   NTSTATUS Status;
    union
    {
       THREAD_BASIC_INFORMATION TBI;
@@ -1235,8 +1272,12 @@ NtQueryInformationThread (IN     HANDLE          ThreadHandle,
       LARGE_INTEGER Count;
       BOOLEAN Last;
    }u;
+   KPROCESSOR_MODE PreviousMode;
+   NTSTATUS Status = STATUS_SUCCESS;
 
    PAGED_CODE();
+   
+   PreviousMode = ExGetPreviousMode();
 
    if (ThreadInformationClass <= MaxThreadInfoClass &&
        !QueryInformationData[ThreadInformationClass].Implemented)
@@ -1253,6 +1294,30 @@ NtQueryInformationThread (IN     HANDLE          ThreadHandle,
        return STATUS_INFO_LENGTH_MISMATCH;
      }
 
+   if (PreviousMode != KernelMode)
+     {
+       _SEH_TRY
+         {
+           ProbeForWrite(ThreadInformation,
+                         QueryInformationData[ThreadInformationClass].Size,
+                         1);
+           if (ReturnLength != NULL)
+             {
+               ProbeForWriteUlong(ReturnLength);
+             }
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
+       
+       if (!NT_SUCCESS(Status))
+         {
+           return Status;
+         }
+     }
+
    Status = ObReferenceObjectByHandle(ThreadHandle,
                                      THREAD_QUERY_INFORMATION,
                                      PsThreadType,
@@ -1272,11 +1337,11 @@ NtQueryInformationThread (IN    HANDLE          ThreadHandle,
           * 0. So do the conversion here:
           * -Gunnar     */
          u.TBI.ExitStatus = (Thread->ExitStatus == 0) ? STATUS_PENDING : Thread->ExitStatus;
-        u.TBI.TebBaseAddress = Thread->Tcb.Teb;
+        u.TBI.TebBaseAddress = (PVOID)Thread->Tcb.Teb;
         u.TBI.ClientId = Thread->Cid;
         u.TBI.AffinityMask = Thread->Tcb.Affinity;
         u.TBI.Priority = Thread->Tcb.Priority;
-        u.TBI.BasePriority = Thread->Tcb.BasePriority;
+        u.TBI.BasePriority = KeQueryBasePriorityThread(&Thread->Tcb);
         break;
 
        case ThreadTimes:
@@ -1311,23 +1376,41 @@ NtQueryInformationThread (IN    HANDLE          ThreadHandle,
         /* Shoult never occure if the data table is correct */
         KEBUGCHECK(0);
      }
-   if (QueryInformationData[ThreadInformationClass].Size)
+
+   if (PreviousMode != KernelMode)
      {
-       Status = MmCopyToCaller(ThreadInformation,
-                               &u.TBI,
-                              QueryInformationData[ThreadInformationClass].Size);
+       _SEH_TRY
+         {
+           if (QueryInformationData[ThreadInformationClass].Size)
+             {
+               RtlCopyMemory(ThreadInformation,
+                             &u.TBI,
+                             QueryInformationData[ThreadInformationClass].Size);
+             }
+           if (ReturnLength != NULL)
+             {
+               *ReturnLength = QueryInformationData[ThreadInformationClass].Size;
+             }
+         }
+       _SEH_HANDLE
+         {
+           Status = _SEH_GetExceptionCode();
+         }
+       _SEH_END;
      }
-   if (ReturnLength)
+   else
      {
-       NTSTATUS Status2;
-       static ULONG Null = 0;
-       Status2 = MmCopyToCaller(ReturnLength,
-                               NT_SUCCESS(Status) ? &QueryInformationData[ThreadInformationClass].Size : &Null,
-                               sizeof(ULONG));
-       if (NT_SUCCESS(Status))
+       if (QueryInformationData[ThreadInformationClass].Size)
          {
-          Status = Status2;
-        }
+           RtlCopyMemory(ThreadInformation,
+                         &u.TBI,
+                         QueryInformationData[ThreadInformationClass].Size);
+         }
+
+       if (ReturnLength != NULL)
+         {
+           *ReturnLength = QueryInformationData[ThreadInformationClass].Size;
+         }
      }
 
    ObDereferenceObject(Thread);