[NTOS:MM] Quick fix: use SIZE_T instead of ULONG, because ULONG is 32-bit and on...
[reactos.git] / ntoskrnl / ex / keyedevt.c
index cd0ae10..6bb2769 100644 (file)
@@ -25,16 +25,6 @@ typedef struct _EX_KEYED_EVENT
     } HashTable[NUM_KEY_HASH_BUCKETS];
 } EX_KEYED_EVENT, *PEX_KEYED_EVENT;
 
-NTSTATUS
-NTAPI
-ZwCreateKeyedEvent(
-    _Out_ PHANDLE OutHandle,
-    _In_ ACCESS_MASK AccessMask,
-    _In_ POBJECT_ATTRIBUTES ObjectAttributes,
-    _In_ ULONG Flags);
-
-#define KeGetCurrentProcess() ((PKPROCESS)PsGetCurrentProcess())
-
 /* GLOBALS *******************************************************************/
 
 PEX_KEYED_EVENT ExpCritSecOutOfMemoryEvent;
@@ -43,14 +33,15 @@ POBJECT_TYPE ExKeyedEventObjectType;
 static
 GENERIC_MAPPING ExpKeyedEventMapping =
 {
-    STANDARD_RIGHTS_READ | EVENT_QUERY_STATE,
-    STANDARD_RIGHTS_WRITE | EVENT_MODIFY_STATE,
+    STANDARD_RIGHTS_READ | KEYEDEVENT_WAIT,
+    STANDARD_RIGHTS_WRITE | KEYEDEVENT_WAKE,
     STANDARD_RIGHTS_EXECUTE,
-    EVENT_ALL_ACCESS
+    KEYEDEVENT_ALL_ACCESS
 };
 
 /* FUNCTIONS *****************************************************************/
 
+_IRQL_requires_max_(APC_LEVEL)
 BOOLEAN
 INIT_FUNCTION
 NTAPI
@@ -67,7 +58,7 @@ ExpInitializeKeyedEventImplementation(VOID)
     ObjectTypeInitializer.Length = sizeof(ObjectTypeInitializer);
     ObjectTypeInitializer.GenericMapping = ExpKeyedEventMapping;
     ObjectTypeInitializer.PoolType = PagedPool;
-    ObjectTypeInitializer.ValidAccessMask = EVENT_ALL_ACCESS;
+    ObjectTypeInitializer.ValidAccessMask = KEYEDEVENT_ALL_ACCESS;
     ObjectTypeInitializer.UseDefaultObject = TRUE;
 
     /* Create the keyed event object type */
@@ -80,14 +71,14 @@ ExpInitializeKeyedEventImplementation(VOID)
     /* Create the out of memory event for critical sections */
     InitializeObjectAttributes(&ObjectAttributes, &Name, OBJ_PERMANENT, NULL, NULL);
     Status = ZwCreateKeyedEvent(&EventHandle,
-                                EVENT_ALL_ACCESS,
+                                KEYEDEVENT_ALL_ACCESS,
                                 &ObjectAttributes,
                                 0);
     if (NT_SUCCESS(Status))
     {
         /* Take a reference so we can get rid of the handle */
         Status = ObReferenceObjectByHandle(EventHandle,
-                                           EVENT_ALL_ACCESS,
+                                           KEYEDEVENT_ALL_ACCESS,
                                            ExKeyedEventObjectType,
                                            KernelMode,
                                            (PVOID*)&ExpCritSecOutOfMemoryEvent,
@@ -116,6 +107,7 @@ ExpInitializeKeyedEvent(
     }
 }
 
+_IRQL_requires_max_(APC_LEVEL)
 NTSTATUS
 NTAPI
 ExpReleaseOrWaitForKeyedEvent(
@@ -126,13 +118,14 @@ ExpReleaseOrWaitForKeyedEvent(
     _In_ BOOLEAN Release)
 {
     PETHREAD Thread, CurrentThread;
-    PKPROCESS CurrentProcess;
+    PEPROCESS CurrentProcess;
     PLIST_ENTRY ListEntry, WaitListHead1, WaitListHead2;
     NTSTATUS Status;
     ULONG_PTR HashIndex;
+    PVOID PreviousKeyedWaitValue;
 
     /* Get the current process */
-    CurrentProcess = KeGetCurrentProcess();
+    CurrentProcess = PsGetCurrentProcess();
 
     /* Calculate the hash index */
     HashIndex = (ULONG_PTR)KeyedWaitValue >> 5;
@@ -140,6 +133,7 @@ ExpReleaseOrWaitForKeyedEvent(
     HashIndex %= NUM_KEY_HASH_BUCKETS;
 
     /* Lock the lists */
+    KeEnterCriticalRegion();
     ExAcquirePushLockExclusive(&KeyedEvent->HashTable[HashIndex].Lock);
 
     /* Get the lists for search and wait, depending on whether
@@ -159,20 +153,33 @@ ExpReleaseOrWaitForKeyedEvent(
     ListEntry = WaitListHead1->Flink;
     while (ListEntry != WaitListHead1)
     {
+        /* Get the waiting thread. Note that this thread cannot be terminated
+           as long as we hold the list lock, since it either needs to wait to
+           be signaled by this thread or, when the wait is aborted due to thread
+           termination, then it first needs to acquire the list lock. */
         Thread = CONTAINING_RECORD(ListEntry, ETHREAD, KeyedWaitChain);
+        ListEntry = ListEntry->Flink;
 
         /* Check if this thread is a correct waiter */
-        if ((Thread->Tcb.Process == CurrentProcess) &&
+        if ((Thread->Tcb.Process == &CurrentProcess->Pcb) &&
             (Thread->KeyedWaitValue == KeyedWaitValue))
         {
             /* Remove the thread from the list */
             RemoveEntryList(&Thread->KeyedWaitChain);
 
+            /* Initialize the list entry to show that it was removed */
+            InitializeListHead(&Thread->KeyedWaitChain);
+
             /* Wake the thread */
-            KeReleaseSemaphore(&Thread->KeyedWaitSemaphore, 0, 1, FALSE);
+            KeReleaseSemaphore(&Thread->KeyedWaitSemaphore,
+                               IO_NO_INCREMENT,
+                               1,
+                               FALSE);
+            Thread = NULL;
 
-            /* Unlock the lists */
+            /* Unlock the list. After this it is not safe to access Thread */
             ExReleasePushLockExclusive(&KeyedEvent->HashTable[HashIndex].Lock);
+            KeLeaveCriticalRegion();
 
             return STATUS_SUCCESS;
         }
@@ -181,7 +188,8 @@ ExpReleaseOrWaitForKeyedEvent(
     /* Get the current thread */
     CurrentThread = PsGetCurrentThread();
 
-    /* Set the wait key */
+    /* Set the wait key and remember the old value */
+    PreviousKeyedWaitValue = CurrentThread->KeyedWaitValue;
     CurrentThread->KeyedWaitValue = KeyedWaitValue;
 
     /* Initialize the wait semaphore */
@@ -190,8 +198,9 @@ ExpReleaseOrWaitForKeyedEvent(
     /* Insert the current thread into the secondary wait list */
     InsertTailList(WaitListHead2, &CurrentThread->KeyedWaitChain);
 
-    /* Unlock the lists */
+    /* Unlock the list */
     ExReleasePushLockExclusive(&KeyedEvent->HashTable[HashIndex].Lock);
+    KeLeaveCriticalRegion();
 
     /* Wait for the keyed wait semaphore */
     Status = KeWaitForSingleObject(&CurrentThread->KeyedWaitSemaphore,
@@ -200,9 +209,33 @@ ExpReleaseOrWaitForKeyedEvent(
                                    Alertable,
                                    Timeout);
 
-    return STATUS_SUCCESS;
+    /* Check if the wait was aborted or timed out */
+    if (Status != STATUS_SUCCESS)
+    {
+        /* Lock the lists to make sure no one else messes with the entry */
+        KeEnterCriticalRegion();
+        ExAcquirePushLockExclusive(&KeyedEvent->HashTable[HashIndex].Lock);
+
+        /* Check if the wait list entry is still in the list */
+        if (!IsListEmpty(&CurrentThread->KeyedWaitChain))
+        {
+            /* Remove the thread from the list */
+            RemoveEntryList(&CurrentThread->KeyedWaitChain);
+            InitializeListHead(&CurrentThread->KeyedWaitChain);
+        }
+
+        /* Unlock the list */
+        ExReleasePushLockExclusive(&KeyedEvent->HashTable[HashIndex].Lock);
+        KeLeaveCriticalRegion();
+    }
+
+    /* Restore the previous KeyedWaitValue, since this is a union member */
+    CurrentThread->KeyedWaitValue = PreviousKeyedWaitValue;
+
+    return Status;
 }
 
+_IRQL_requires_max_(APC_LEVEL)
 NTSTATUS
 NTAPI
 ExpWaitForKeyedEvent(
@@ -219,13 +252,14 @@ ExpWaitForKeyedEvent(
                                          FALSE);
 }
 
+_IRQL_requires_max_(APC_LEVEL)
 NTSTATUS
 NTAPI
 ExpReleaseKeyedEvent(
     _Inout_ PEX_KEYED_EVENT KeyedEvent,
     _In_ PVOID KeyedWaitValue,
-       _In_ BOOLEAN Alertable,
-       _In_ PLARGE_INTEGER Timeout)
+    _In_ BOOLEAN Alertable,
+    _In_ PLARGE_INTEGER Timeout)
 {
     /* Call the generic internal function */
     return ExpReleaseOrWaitForKeyedEvent(KeyedEvent,
@@ -235,6 +269,7 @@ ExpReleaseKeyedEvent(
                                          TRUE);
 }
 
+_IRQL_requires_max_(PASSIVE_LEVEL)
 NTSTATUS
 NTAPI
 NtCreateKeyedEvent(
@@ -269,7 +304,7 @@ NtCreateKeyedEvent(
     /* Check for success */
     if (!NT_SUCCESS(Status)) return Status;
 
-    /* Initalize the keyed event */
+    /* Initialize the keyed event */
     ExpInitializeKeyedEvent(KeyedEvent);
 
     /* Insert it */
@@ -311,6 +346,7 @@ NtCreateKeyedEvent(
     return Status;
 }
 
+_IRQL_requires_max_(PASSIVE_LEVEL)
 NTSTATUS
 NTAPI
 NtOpenKeyedEvent(
@@ -347,6 +383,9 @@ NtOpenKeyedEvent(
         {
             /* Get the exception code */
             Status = _SEH2_GetExceptionCode();
+
+            /* Cleanup */
+            ObCloseHandle(KeyedEventHandle, PreviousMode);
         }
         _SEH2_END;
     }
@@ -359,24 +398,48 @@ NtOpenKeyedEvent(
     return Status;
 }
 
+_IRQL_requires_max_(PASSIVE_LEVEL)
 NTSTATUS
 NTAPI
 NtWaitForKeyedEvent(
-    _In_ HANDLE Handle,
+    _In_opt_ HANDLE Handle,
     _In_ PVOID Key,
     _In_ BOOLEAN Alertable,
-    _In_ PLARGE_INTEGER Timeout)
+    _In_opt_ PLARGE_INTEGER Timeout)
 {
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PEX_KEYED_EVENT KeyedEvent;
     NTSTATUS Status;
+    LARGE_INTEGER TimeoutCopy;
+
+    /* Key must always be two-byte aligned */
+    if ((ULONG_PTR)Key & 1)
+    {
+        return STATUS_INVALID_PARAMETER_1;
+    }
+
+    /* Check if the caller passed a timeout value and this is from user mode */
+    if ((Timeout != NULL) && (PreviousMode != KernelMode))
+    {
+        _SEH2_TRY
+        {
+            ProbeForRead(Timeout, sizeof(*Timeout), 1);
+            TimeoutCopy = *Timeout;
+            Timeout = &TimeoutCopy;
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
 
     /* Check if the caller provided a handle */
     if (Handle != NULL)
     {
         /* Get the keyed event object */
         Status = ObReferenceObjectByHandle(Handle,
-                                           EVENT_MODIFY_STATE,
+                                           KEYEDEVENT_WAIT,
                                            ExKeyedEventObjectType,
                                            PreviousMode,
                                            (PVOID*)&KeyedEvent,
@@ -401,24 +464,48 @@ NtWaitForKeyedEvent(
     return Status;
 }
 
+_IRQL_requires_max_(PASSIVE_LEVEL)
 NTSTATUS
 NTAPI
 NtReleaseKeyedEvent(
-    _In_ HANDLE Handle,
+    _In_opt_ HANDLE Handle,
     _In_ PVOID Key,
     _In_ BOOLEAN Alertable,
-    _In_ PLARGE_INTEGER Timeout)
+    _In_opt_ PLARGE_INTEGER Timeout)
 {
     KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
     PEX_KEYED_EVENT KeyedEvent;
     NTSTATUS Status;
+    LARGE_INTEGER TimeoutCopy;
+
+    /* Key must always be two-byte aligned */
+    if ((ULONG_PTR)Key & 1)
+    {
+        return STATUS_INVALID_PARAMETER_1;
+    }
+
+    /* Check if the caller passed a timeout value and this is from user mode */
+    if ((Timeout != NULL) && (PreviousMode != KernelMode))
+    {
+        _SEH2_TRY
+        {
+            ProbeForRead(Timeout, sizeof(*Timeout), 1);
+            TimeoutCopy = *Timeout;
+            Timeout = &TimeoutCopy;
+        }
+        _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
+        {
+            _SEH2_YIELD(return _SEH2_GetExceptionCode());
+        }
+        _SEH2_END;
+    }
 
     /* Check if the caller provided a handle */
     if (Handle != NULL)
     {
         /* Get the keyed event object */
         Status = ObReferenceObjectByHandle(Handle,
-                                           EVENT_MODIFY_STATE,
+                                           KEYEDEVENT_WAKE,
                                            ExKeyedEventObjectType,
                                            PreviousMode,
                                            (PVOID*)&KeyedEvent,