[WIN32K:NTUSER] Detect when the NtUserCreateWindowStation() caller has provided an...
[reactos.git] / win32ss / user / ntuser / winsta.c
index eb3ca30..f373b1c 100644 (file)
@@ -512,6 +512,7 @@ NtUserCreateWindowStation(
     HWINSTA hWinSta;
     OBJECT_ATTRIBUTES LocalObjectAttributes;
     UNICODE_STRING WindowStationName;
+    KPROCESSOR_MODE AccessMode = UserMode;
 
     TRACE("NtUserCreateWindowStation called\n");
 
@@ -520,9 +521,21 @@ NtUserCreateWindowStation(
     {
         ProbeForRead(ObjectAttributes, sizeof(OBJECT_ATTRIBUTES), sizeof(ULONG));
         LocalObjectAttributes = *ObjectAttributes;
-        Status = IntSafeCopyUnicodeStringTerminateNULL(&WindowStationName,
-                                                       LocalObjectAttributes.ObjectName);
-        LocalObjectAttributes.ObjectName = &WindowStationName;
+        if (LocalObjectAttributes.ObjectName ||
+            LocalObjectAttributes.RootDirectory
+            /* &&
+            LocalObjectAttributes.ObjectName->Buffer &&
+            LocalObjectAttributes.ObjectName->Length > 0 */)
+        {
+            Status = IntSafeCopyUnicodeStringTerminateNULL(&WindowStationName,
+                                                           LocalObjectAttributes.ObjectName);
+            LocalObjectAttributes.ObjectName = &WindowStationName;
+        }
+        else
+        {
+            LocalObjectAttributes.ObjectName = NULL;
+            Status = STATUS_SUCCESS;
+        }
     }
     _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
     {
@@ -537,11 +550,64 @@ NtUserCreateWindowStation(
         return NULL;
     }
 
+    /*
+     * If the caller did not provide a window station name, build a new one
+     * based on the logon session identifier for the calling process.
+     */
+    if (!LocalObjectAttributes.ObjectName)
+    {
+        LUID CallerLuid;
+        WCHAR ServiceWinStaName[MAX_PATH];
+
+        /* Retrieve the LUID of the current process */
+        Status = GetProcessLuid(NULL, NULL, &CallerLuid);
+        if (!NT_SUCCESS(Status))
+        {
+            ERR("Failed to retrieve the caller LUID, Status 0x%08lx\n", Status);
+            SetLastNtError(Status);
+            return NULL;
+        }
+
+        /* Build a valid window station name from the LUID */
+        Status = RtlStringCbPrintfW(ServiceWinStaName,
+                                    sizeof(ServiceWinStaName),
+                                    L"%wZ\\Service-0x%x-%x$",
+                                    &gustrWindowStationsDir,
+                                    CallerLuid.HighPart,
+                                    CallerLuid.LowPart);
+        if (!NT_SUCCESS(Status))
+        {
+            ERR("Impossible to build a valid window station name, Status 0x%08lx\n", Status);
+            SetLastNtError(Status);
+            return NULL;
+        }
+
+        WindowStationName.Length = wcslen(ServiceWinStaName) * sizeof(WCHAR);
+        WindowStationName.MaximumLength =
+            WindowStationName.Length + sizeof(UNICODE_NULL);
+        WindowStationName.Buffer =
+            ExAllocatePoolWithTag(PagedPool,
+                                  WindowStationName.MaximumLength,
+                                  TAG_STRING);
+        if (!WindowStationName.Buffer)
+        {
+            Status = STATUS_NO_MEMORY;
+            ERR("Impossible to build a valid window station name, Status 0x%08lx\n", Status);
+            SetLastNtError(Status);
+            return NULL;
+        }
+        RtlStringCbCopyW(WindowStationName.Buffer,
+                         WindowStationName.MaximumLength,
+                         ServiceWinStaName);
+        LocalObjectAttributes.ObjectName = &WindowStationName;
+        AccessMode = KernelMode;
+    }
+
     // TODO: Capture and use the SecurityQualityOfService!
 
     Status = IntCreateWindowStation(&hWinSta,
                                     &LocalObjectAttributes,
-                                    UserMode,
+                                    AccessMode,
                                     dwDesiredAccess,
                                     Unknown2,
                                     Unknown3,