Added process ids.
[reactos.git] / reactos / ntoskrnl / ps / thread.c
index 31737ff..285c7f2 100644 (file)
 #include <ddk/ntddk.h>
 #include <internal/ke.h>
 #include <internal/ob.h>
+#include <string.h>
 #include <internal/string.h>
 #include <internal/hal.h>
 #include <internal/ps.h>
+#include <internal/ob.h>
 
 #define NDEBUG
 #include <internal/debug.h>
@@ -45,10 +47,12 @@ static KSPIN_LOCK ThreadListLock = {0,};
  */
 static LIST_ENTRY PriorityListHead[NR_THREAD_PRIORITY_LEVELS]={{NULL,NULL},};
 static BOOLEAN DoneInitYet = FALSE;
+ULONG PiNrThreads = 0;
+ULONG PiNrRunnableThreads = 0;
 
 static PETHREAD CurrentThread = NULL;
 
-static ULONG NextThreadUniqueId = 0;
+static ULONG NextUniqueThreadId = 0;
 
 /* FUNCTIONS ***************************************************************/
 
@@ -59,14 +63,42 @@ PKTHREAD KeGetCurrentThread(VOID)
 
 PETHREAD PsGetCurrentThread(VOID)
 {
-   return((PETHREAD)KeGetCurrentThread());
+   return(CurrentThread);
+}
+
+VOID PiTerminateProcessThreads(PEPROCESS Process, NTSTATUS ExitStatus)
+{
+   KIRQL oldlvl;
+   PLIST_ENTRY current_entry;
+   PETHREAD current;
+   ULONG i;
+
+   KeAcquireSpinLock(&ThreadListLock, &oldlvl);
+
+   for (i=0; i<NR_THREAD_PRIORITY_LEVELS; i++)
+   {
+        current_entry = PriorityListHead[i].Flink;
+        while (current_entry != &PriorityListHead[i])
+        {
+             current = CONTAINING_RECORD(current_entry,ETHREAD,Tcb.Entry);
+             if (current->ThreadsProcess == Process &&
+                 current != PsGetCurrentThread())
+             {
+                  PsTerminateOtherThread(current, ExitStatus);
+             }
+             current_entry = current_entry->Flink;
+        }
+   }
+
+   KeReleaseSpinLock(&ThreadListLock, oldlvl);
 }
 
 static VOID PsInsertIntoThreadList(KPRIORITY Priority, PETHREAD Thread)
 {
    KIRQL oldlvl;
    
-   DPRINT("PsInsertIntoThreadList(Priority %d, Thread %x)\n",Priority,Thread);
+   DPRINT("PsInsertIntoThreadList(Priority %x, Thread %x)\n",Priority,
+           Thread);
    
    KeAcquireSpinLock(&ThreadListLock,&oldlvl);
    InsertTailList(&PriorityListHead[THREAD_PRIORITY_MAX+Priority],
@@ -74,6 +106,16 @@ static VOID PsInsertIntoThreadList(KPRIORITY Priority, PETHREAD Thread)
    KeReleaseSpinLock(&ThreadListLock,oldlvl);
 }
 
+VOID PsBeginThread(PKSTART_ROUTINE StartRoutine, PVOID StartContext)
+{
+   NTSTATUS Ret;
+   
+   KeReleaseSpinLock(&ThreadListLock,PASSIVE_LEVEL);
+   Ret = StartRoutine(StartContext);
+   PsTerminateSystemThread(Ret);
+   KeBugCheck(0);
+}
+
 static PETHREAD PsScanThreadList(KPRIORITY Priority)
 {
    PLIST_ENTRY current_entry;
@@ -81,12 +123,19 @@ static PETHREAD PsScanThreadList(KPRIORITY Priority)
    PETHREAD oldest = NULL;
    ULONG oldest_time = 0;
    
-   DPRINT("PsScanThreadList(Priority %d)\n",Priority);
+//   DPRINT("PsScanThreadList(Priority %d)\n",Priority);
    
    current_entry = PriorityListHead[THREAD_PRIORITY_MAX+Priority].Flink;
    while (current_entry != &PriorityListHead[THREAD_PRIORITY_MAX+Priority])
      {
        current = CONTAINING_RECORD(current_entry,ETHREAD,Tcb.Entry);
+
+       if (current->Tcb.ThreadState == THREAD_STATE_TERMINATED &&
+           current != CurrentThread)
+         {
+            PsReleaseThread(current);
+         }
+
        if (current->Tcb.ThreadState == THREAD_STATE_RUNNABLE)
          {
             if (oldest == NULL || oldest_time > current->Tcb.LastTick)
@@ -97,7 +146,7 @@ static PETHREAD PsScanThreadList(KPRIORITY Priority)
          }
        current_entry = current_entry->Flink;
      }
-   DPRINT("PsScanThreadList() = %x\n",oldest);
+//   DPRINT("PsScanThreadList() = %x\n",oldest);
    return(oldest);
 }
 
@@ -129,30 +178,31 @@ VOID PsDispatchThread(VOID)
        Candidate = PsScanThreadList(CurrentPriority);
        if (Candidate == CurrentThread)
          {
-            DPRINT("Scheduling current thread\n");
-            KeQueryTickCount(&TickCount);
-            CurrentThread->Tcb.LastTick = TickCount.LowPart;
+             DPRINT("Scheduling current thread\n");
+             KeQueryTickCount(&TickCount);
+             CurrentThread->Tcb.LastTick = TickCount.LowPart;
             CurrentThread->Tcb.ThreadState = THREAD_STATE_RUNNING;
             KeReleaseSpinLock(&ThreadListLock,irql);
             return;
          }
        if (Candidate != NULL)
          {     
-            DPRINT("Scheduling %x\n",Candidate);
+             DPRINT("Scheduling %x\n",Candidate);
             
             Candidate->Tcb.ThreadState = THREAD_STATE_RUNNING;
             
             KeQueryTickCount(&TickCount);
-            CurrentThread->Tcb.LastTick = TickCount.LowPart;
+             CurrentThread->Tcb.LastTick = TickCount.LowPart;
             
             CurrentThread = Candidate;
             
-            KeReleaseSpinLock(&ThreadListLock,irql);
-            KeLowerIrql(PASSIVE_LEVEL);
             HalTaskSwitch(&CurrentThread->Tcb);
+            KeReleaseSpinLock(&ThreadListLock,irql);
             return;
          }
      }
+   DbgPrint("CRITICAL: No threads are runnable\n");
+   KeBugCheck(0);
 }
 
 NTSTATUS PsInitializeThread(HANDLE ProcessHandle, 
@@ -161,15 +211,15 @@ NTSTATUS PsInitializeThread(HANDLE ProcessHandle,
                            ACCESS_MASK DesiredAccess,
                            POBJECT_ATTRIBUTES ThreadAttributes)
 {
-   ULONG ThreadId;
-   ULONG ProcessId;
    PETHREAD Thread;
    NTSTATUS Status;
    
-   Thread = ObGenericCreateObject(ThreadHandle,
-                                 DesiredAccess,
-                                 ThreadAttributes,
-                                 PsThreadType);
+   PiNrThreads++;
+   
+   Thread = ObCreateObject(ThreadHandle,
+                          DesiredAccess,
+                          ThreadAttributes,
+                          PsThreadType);
    DPRINT("Thread = %x\n",Thread);
    Thread->Tcb.LastTick = 0;
    Thread->Tcb.ThreadState=THREAD_STATE_SUSPENDED;
@@ -177,7 +227,7 @@ NTSTATUS PsInitializeThread(HANDLE ProcessHandle,
    Thread->Tcb.CurrentPriority=THREAD_PRIORITY_NORMAL;
    Thread->Tcb.ApcList=ExAllocatePool(NonPagedPool,sizeof(LIST_ENTRY));
    Thread->Tcb.SuspendCount = 1;
-   if (ProcessHandle!=NULL)
+   if (ProcessHandle != NULL)
      {
        Status = ObReferenceObjectByHandle(ProcessHandle,
                                           PROCESS_CREATE_THREAD,
@@ -187,33 +237,46 @@ NTSTATUS PsInitializeThread(HANDLE ProcessHandle,
                                           NULL);
        if (Status != STATUS_SUCCESS)
          {
+            DPRINT("Failed at %s:%d\n",__FILE__,__LINE__);
             return(Status);
          }
      }
    else
      {
-       Thread->ThreadsProcess=SystemProcess;
+       Thread->ThreadsProcess = SystemProcess;
+       ObReferenceObjectByPointer(Thread->ThreadsProcess,
+                                  PROCESS_CREATE_THREAD,
+                                  PsProcessType,
+                                  UserMode);
      }
+   ObReferenceObjectByPointer(Thread->ThreadsProcess,
+                             PROCESS_CREATE_THREAD,
+                             PsProcessType,
+                             UserMode);
    InitializeListHead(Thread->Tcb.ApcList);
    InitializeListHead(&(Thread->IrpList));
-   Thread->Cid.UniqueThread=NextThreadUniqueId++;
-//   thread->Cid.ThreadId=InterlockedIncrement(&NextThreadUniqueId);
+   Thread->Cid.UniqueThread = (HANDLE)InterlockedIncrement(
+                                                          &NextUniqueThreadId);
+   Thread->Cid.UniqueProcess = (HANDLE)Thread->ThreadsProcess->UniqueProcessId;
+   DbgPrint("Thread->Cid.UniqueThread %d\nThread->Cid.UniqueProcess %d\n",
+            Thread->Cid.UniqueThread, Thread->Cid.UniqueThread);
+   ObReferenceObjectByPointer(Thread,
+                             THREAD_ALL_ACCESS,
+                             PsThreadType,
+                             UserMode);
    PsInsertIntoThreadList(Thread->Tcb.CurrentPriority,Thread);
    
    *ThreadPtr = Thread;
    
+   ObDereferenceObject(Thread->ThreadsProcess);
    return(STATUS_SUCCESS);
 }
 
 VOID PsResumeThread(PETHREAD Thread)
 {
    DPRINT("PsResumeThread(Thread %x)\n",Thread);
-   
    Thread->Tcb.SuspendCount--;
-   DPRINT("Thread->Tcb.SuspendCount %d\n",Thread->Tcb.SuspendCount);
-   DPRINT("Thread->Tcb.ThreadState %d THREAD_STATE_RUNNING %d\n",
-           Thread->Tcb.ThreadState,THREAD_STATE_RUNNING);
-   if (Thread->Tcb.SuspendCount <= 0 && 
+   if (Thread->Tcb.SuspendCount <= 0 &&
        Thread->Tcb.ThreadState != THREAD_STATE_RUNNING)
      {
         DPRINT("Setting thread to runnable\n");
@@ -236,7 +299,12 @@ VOID PsSuspendThread(PETHREAD Thread)
      }
 }
 
-void PsInitThreadManagment(void)
+VOID PiDeleteThread(PVOID ObjectBody)
+{
+   DbgPrint("PiDeleteThread(ObjectBody %x)\n",ObjectBody);
+}
+
+VOID PsInitThreadManagment(VOID)
 /*
  * FUNCTION: Initialize thread managment
  */
@@ -266,7 +334,7 @@ void PsInitThreadManagment(void)
    PsThreadType->Dump = NULL;
    PsThreadType->Open = NULL;
    PsThreadType->Close = NULL;
-   PsThreadType->Delete = NULL;
+   PsThreadType->Delete = PiDeleteThread;
    PsThreadType->Parse = NULL;
    PsThreadType->Security = NULL;
    PsThreadType->QueryName = NULL;
@@ -316,14 +384,21 @@ NTSTATUS ZwCreateThread(PHANDLE ThreadHandle,
    PETHREAD Thread;
    NTSTATUS Status;
    
+   DPRINT("ZwCreateThread(ThreadHandle %x, PCONTEXT %x)\n",
+         ThreadHandle,ThreadContext);
+   
    Status = PsInitializeThread(ProcessHandle,&Thread,ThreadHandle,
                               DesiredAccess,ObjectAttributes);
-   if (Status != STATUS_SUCCESS)
+   if (!NT_SUCCESS(Status))
      {
        return(Status);
      }
    
-   HalInitTaskWithContext(Thread,ThreadContext);
+   Status = HalInitTaskWithContext(Thread,ThreadContext);
+   if (!NT_SUCCESS(Status))
+     {
+       return(Status);
+     }
    Thread->StartAddress=NULL;
 
    if (Client!=NULL)
@@ -372,13 +447,17 @@ NTSTATUS PsCreateSystemThread(PHANDLE ThreadHandle,
    
    Status = PsInitializeThread(ProcessHandle,&Thread,ThreadHandle,
                               DesiredAccess,ObjectAttributes);
-   if (Status != STATUS_SUCCESS)
+   if (!NT_SUCCESS(Status))
      {
        return(Status);
      }
    
    Thread->StartAddress=StartRoutine;
-   HalInitTask(Thread,StartRoutine,StartContext);
+   Status = HalInitTask(Thread,StartRoutine,StartContext);
+   if (!NT_SUCCESS(Status))
+     {
+       return(Status);
+     }
 
    if (ClientId!=NULL)
      {
@@ -500,6 +579,7 @@ NTSTATUS STDCALL ZwResumeThread(IN HANDLE ThreadHandle,
        Thread->Tcb.ThreadState = THREAD_STATE_RUNNABLE;
      }
    
+   ObDereferenceObject(Thread);
    return(STATUS_SUCCESS);
 }
 
@@ -561,6 +641,7 @@ NTSTATUS STDCALL ZwSuspendThread(IN HANDLE ThreadHandle,
          }
      }
    
+   ObDereferenceObject(Thread);
    return(STATUS_SUCCESS);
 }