[USERENV]
authorEric Kohl <eric.kohl@reactos.org>
Sat, 2 Nov 2013 14:45:26 +0000 (14:45 +0000)
committerEric Kohl <eric.kohl@reactos.org>
Sat, 2 Nov 2013 14:45:26 +0000 (14:45 +0000)
Get the user name and domain name from the logon token and set the USERNAME and USERDOMAIN environment variables accordingly.

svn path=/trunk/; revision=60834

reactos/dll/win32/userenv/environment.c
reactos/dll/win32/userenv/internal.h
reactos/dll/win32/userenv/misc.c
reactos/dll/win32/userenv/profile.c

index 918bce4..baf2a80 100644 (file)
@@ -175,8 +175,8 @@ GetCurrentUserKey(HANDLE hToken)
     HKEY hKey;
     LONG Error;
 
-    if (!GetUserSidFromToken(hToken,
-                             &SidString))
+    if (!GetUserSidStringFromToken(hToken,
+                                   &SidString))
     {
         DPRINT1("GetUserSidFromToken() failed\n");
         return NULL;
@@ -201,6 +201,89 @@ GetCurrentUserKey(HANDLE hToken)
 }
 
 
+static
+BOOL
+GetUserAndDomainName(IN HANDLE hToken,
+                     OUT LPWSTR *UserName,
+                     OUT LPWSTR *DomainName)
+{
+    PSID Sid = NULL;
+    LPWSTR lpUserName = NULL;
+    LPWSTR lpDomainName = NULL;
+    DWORD cbUserName = 0;
+    DWORD cbDomainName = 0;
+    SID_NAME_USE SidNameUse;
+    BOOL bRet = TRUE;
+
+    if (!GetUserSidFromToken(hToken,
+                             &Sid))
+    {
+        DPRINT1("GetUserSidFromToken() failed\n");
+        return FALSE;
+    }
+
+    if (!LookupAccountSidW(NULL,
+                           Sid,
+                           NULL,
+                           &cbUserName,
+                           NULL,
+                           &cbDomainName,
+                           &SidNameUse))
+    {
+        if (GetLastError() != ERROR_INSUFFICIENT_BUFFER)
+        {
+            bRet = FALSE;
+            goto done;
+        }
+    }
+
+    lpUserName = LocalAlloc(LPTR,
+                            cbUserName * sizeof(WCHAR));
+    if (lpUserName == NULL)
+    {
+        bRet = FALSE;
+        goto done;
+    }
+
+    lpDomainName = LocalAlloc(LPTR,
+                              cbDomainName * sizeof(WCHAR));
+    if (lpDomainName == NULL)
+    {
+        bRet = FALSE;
+        goto done;
+    }
+
+    if (!LookupAccountSidW(NULL,
+                           Sid,
+                           lpUserName,
+                           &cbUserName,
+                           lpDomainName,
+                           &cbDomainName,
+                           &SidNameUse))
+    {
+        bRet = FALSE;
+        goto done;
+    }
+
+    *UserName = lpUserName;
+    *DomainName = lpDomainName;
+
+done:
+    if (bRet == FALSE)
+    {
+        if (lpUserName != NULL)
+            LocalFree(lpUserName);
+
+        if (lpDomainName != NULL)
+            LocalFree(lpDomainName);
+    }
+
+    LocalFree(Sid);
+
+    return bRet;
+}
+
+
 static
 BOOL
 SetUserEnvironment(LPVOID *lpEnvironment,
@@ -326,6 +409,8 @@ CreateEnvironmentBlock(LPVOID *lpEnvironment,
     DWORD dwType;
     HKEY hKey;
     HKEY hKeyUser;
+    LPWSTR lpUserName = NULL;
+    LPWSTR lpDomainName = NULL;
     NTSTATUS Status;
     LONG lError;
 
@@ -431,15 +516,20 @@ CreateEnvironmentBlock(LPVOID *lpEnvironment,
                                    FALSE);
     }
 
-    /* FIXME: Set 'USERDOMAIN' variable */
-
-    Length = MAX_PATH;
-    if (GetUserNameW(Buffer,
-                     &Length))
+    if (GetUserAndDomainName(hToken,
+                             &lpUserName,
+                             &lpDomainName))
     {
+        /* Set 'USERDOMAIN' variable */
+        SetUserEnvironmentVariable(lpEnvironment,
+                                   L"USERDOMAIN",
+                                   lpDomainName,
+                                   FALSE);
+
+        /* Set 'USERNAME' variable */
         SetUserEnvironmentVariable(lpEnvironment,
                                    L"USERNAME",
-                                   Buffer,
+                                   lpUserName,
                                    FALSE);
     }
 
@@ -455,6 +545,12 @@ CreateEnvironmentBlock(LPVOID *lpEnvironment,
 
     RegCloseKey(hKeyUser);
 
+    if (lpUserName != NULL)
+        LocalFree(lpUserName);
+
+    if (lpDomainName != NULL)
+        LocalFree(lpDomainName);
+
     return TRUE;
 }
 
index 098de53..d22ab20 100644 (file)
@@ -75,7 +75,11 @@ AppendBackslash(LPWSTR String);
 
 BOOL
 GetUserSidFromToken(HANDLE hToken,
-                    PUNICODE_STRING SidString);
+                    PSID *Sid);
+
+BOOL
+GetUserSidStringFromToken(HANDLE hToken,
+                          PUNICODE_STRING SidString);
 
 PSECURITY_DESCRIPTOR
 CreateDefaultSecurityDescriptor(VOID);
index e7bf1dd..1075b33 100644 (file)
@@ -53,56 +53,124 @@ AppendBackslash(LPWSTR String)
 
 BOOL
 GetUserSidFromToken(HANDLE hToken,
-                    PUNICODE_STRING SidString)
+                    PSID *Sid)
 {
-    PSID_AND_ATTRIBUTES SidBuffer, nsb;
+    PTOKEN_USER UserBuffer, nsb;
+    PSID pSid = NULL;
     ULONG Length;
     NTSTATUS Status;
 
     Length = 256;
-    SidBuffer = LocalAlloc(LMEM_FIXED,
-                           Length);
-    if (SidBuffer == NULL)
+    UserBuffer = LocalAlloc(LPTR, Length);
+    if (UserBuffer == NULL)
+    {
+        return FALSE;
+    }
+
+    Status = NtQueryInformationToken(hToken,
+                                     TokenUser,
+                                     (PVOID)UserBuffer,
+                                     Length,
+                                     &Length);
+    if (Status == STATUS_BUFFER_TOO_SMALL)
+    {
+        nsb = LocalReAlloc(UserBuffer, Length, LMEM_MOVEABLE);
+        if (nsb == NULL)
+        {
+            LocalFree(UserBuffer);
+            return FALSE;
+        }
+
+        UserBuffer = nsb;
+        Status = NtQueryInformationToken(hToken,
+                                         TokenUser,
+                                         (PVOID)UserBuffer,
+                                         Length,
+                                         &Length);
+    }
+
+    if (!NT_SUCCESS (Status))
+    {
+        LocalFree(UserBuffer);
+        return FALSE;
+    }
+
+    Length = RtlLengthSid(UserBuffer->User.Sid);
+
+    pSid = LocalAlloc(LPTR, Length);
+    if (pSid == NULL)
+    {
+        LocalFree(UserBuffer);
+        return FALSE;
+    }
+
+    Status = RtlCopySid(Length, pSid, UserBuffer->User.Sid);
+
+    LocalFree(UserBuffer);
+
+    if (!NT_SUCCESS (Status))
+    {
+        LocalFree(pSid);
+        return FALSE;
+    }
+
+    *Sid = pSid;
+
+    return TRUE;
+}
+
+
+BOOL
+GetUserSidStringFromToken(HANDLE hToken,
+                          PUNICODE_STRING SidString)
+{
+    PTOKEN_USER UserBuffer, nsb;
+    ULONG Length;
+    NTSTATUS Status;
+
+    Length = 256;
+    UserBuffer = LocalAlloc(LPTR, Length);
+    if (UserBuffer == NULL)
         return FALSE;
 
     Status = NtQueryInformationToken(hToken,
                                      TokenUser,
-                                     (PVOID)SidBuffer,
+                                     (PVOID)UserBuffer,
                                      Length,
                                      &Length);
     if (Status == STATUS_BUFFER_TOO_SMALL)
     {
-        nsb = LocalReAlloc(SidBuffer,
+        nsb = LocalReAlloc(UserBuffer,
                            Length,
                            LMEM_MOVEABLE);
         if (nsb == NULL)
         {
-            LocalFree((HLOCAL)SidBuffer);
+            LocalFree(UserBuffer);
             return FALSE;
         }
 
-        SidBuffer = nsb;
+        UserBuffer = nsb;
         Status = NtQueryInformationToken(hToken,
                                          TokenUser,
-                                         (PVOID)SidBuffer,
+                                         (PVOID)UserBuffer,
                                          Length,
                                          &Length);
     }
 
     if (!NT_SUCCESS (Status))
     {
-        LocalFree((HLOCAL)SidBuffer);
+        LocalFree(UserBuffer);
         SetLastError(RtlNtStatusToDosError(Status));
         return FALSE;
     }
 
-    DPRINT("SidLength: %lu\n", RtlLengthSid (SidBuffer[0].Sid));
+    DPRINT("SidLength: %lu\n", RtlLengthSid (UserBuffer->User.Sid));
 
     Status = RtlConvertSidToUnicodeString(SidString,
-                                          SidBuffer[0].Sid,
+                                          UserBuffer->User.Sid,
                                           TRUE);
 
-    LocalFree((HLOCAL)SidBuffer);
+    LocalFree(UserBuffer);
 
     if (!NT_SUCCESS(Status))
     {
@@ -117,6 +185,7 @@ GetUserSidFromToken(HANDLE hToken,
     return TRUE;
 }
 
+
 PSECURITY_DESCRIPTOR
 CreateDefaultSecurityDescriptor(VOID)
 {
index 2884e3a..888115b 100644 (file)
@@ -814,8 +814,8 @@ GetUserProfileDirectoryW(HANDLE hToken,
     HKEY hKey;
     LONG Error;
 
-    if (!GetUserSidFromToken(hToken,
-                             &SidString))
+    if (!GetUserSidStringFromToken(hToken,
+                                   &SidString))
     {
         DPRINT1("GetUserSidFromToken() failed\n");
         return FALSE;
@@ -898,8 +898,8 @@ CheckForLoadedProfile(HANDLE hToken)
 
     DPRINT("CheckForLoadedProfile() called\n");
 
-    if (!GetUserSidFromToken(hToken,
-                             &SidString))
+    if (!GetUserSidStringFromToken(hToken,
+                                   &SidString))
     {
         DPRINT1("GetUserSidFromToken() failed\n");
         return FALSE;
@@ -1166,7 +1166,7 @@ LoadUserProfileW(IN HANDLE hToken,
     }
 
     /* Get user SID string */
-    ret = GetUserSidFromToken(hToken, &SidString);
+    ret = GetUserSidStringFromToken(hToken, &SidString);
     if (!ret)
     {
         DPRINT1("GetUserSidFromToken() failed\n");
@@ -1243,8 +1243,8 @@ UnloadUserProfile(HANDLE hToken,
 
     RegCloseKey(hProfile);
 
-    if (!GetUserSidFromToken(hToken,
-                             &SidString))
+    if (!GetUserSidStringFromToken(hToken,
+                                   &SidString))
     {
         DPRINT1("GetUserSidFromToken() failed\n");
         return FALSE;