[KERNEL32] Misc. fixes to DefineDosDeviceW
authorPierre Schweitzer <pierre@reactos.org>
Sun, 5 May 2019 16:31:43 +0000 (18:31 +0200)
committerPierre Schweitzer <pierre@reactos.org>
Sun, 5 May 2019 16:33:15 +0000 (18:33 +0200)
- Add support for LUIDDeviceMapsEnabled;
- Broadcast proper message in case of device removal;
- Use less memory for strings management;
- Make code a bit cleaner.

dll/win32/kernel32/client/dosdev.c

index 67fd608..3c569ec 100644 (file)
@@ -244,38 +244,52 @@ DefineDosDeviceW(
     PCSR_CAPTURE_BUFFER CaptureBuffer;
     UNICODE_STRING NtTargetPathU;
     UNICODE_STRING DeviceNameU;
-    UNICODE_STRING DeviceUpcaseNameU;
     HANDLE hUser32;
     DEV_BROADCAST_VOLUME dbcv;
-    BOOL Result = TRUE;
     DWORD dwRecipients;
     typedef long (WINAPI *BSM_type)(DWORD, LPDWORD, UINT, WPARAM, LPARAM);
     BSM_type BSM_ptr;
+    BOOLEAN LUIDDeviceMapsEnabled;
+    WCHAR Letter;
+    WPARAM wParam;
 
-    if ( (dwFlags & 0xFFFFFFF0) ||
+    /* Get status about local device mapping */
+    LUIDDeviceMapsEnabled = BaseStaticServerData->LUIDDeviceMapsEnabled;
+
+    /* Validate input & flags */
+    if ((dwFlags & 0xFFFFFFE0) ||
         ((dwFlags & DDD_EXACT_MATCH_ON_REMOVE) &&
-        ! (dwFlags & DDD_REMOVE_DEFINITION)) )
+        !(dwFlags & DDD_REMOVE_DEFINITION)) ||
+        (lpTargetPath == NULL && !(dwFlags & (DDD_LUID_BROADCAST_DRIVE | DDD_REMOVE_DEFINITION))) ||
+        ((dwFlags & DDD_LUID_BROADCAST_DRIVE) &&
+         (lpDeviceName == NULL || lpTargetPath != NULL || dwFlags & (DDD_NO_BROADCAST_SYSTEM | DDD_EXACT_MATCH_ON_REMOVE | DDD_RAW_TARGET_PATH) || !LUIDDeviceMapsEnabled)))
     {
         SetLastError(ERROR_INVALID_PARAMETER);
         return FALSE;
     }
 
+    /* Initialize device unicode string to ease its use */
+    RtlInitUnicodeString(&DeviceNameU, lpDeviceName);
+
+    /* The buffer for CSR call will contain it */
+    BufferSize = DeviceNameU.MaximumLength;
     ArgumentCount = 1;
-    BufferSize = 0;
-    if (!lpTargetPath)
+
+    /* If we don't have target path, use empty string */
+    if (lpTargetPath == NULL)
     {
-        RtlInitUnicodeString(&NtTargetPathU,
-                             NULL);
+        RtlInitUnicodeString(&NtTargetPathU, NULL);
     }
     else
     {
+        /* Else, use it raw if asked to */
         if (dwFlags & DDD_RAW_TARGET_PATH)
         {
-            RtlInitUnicodeString(&NtTargetPathU,
-                                 lpTargetPath);
+            RtlInitUnicodeString(&NtTargetPathU, lpTargetPath);
         }
         else
         {
+            /* Otherwise, use it converted */
             if (!RtlDosPathNameToNtPathName_U(lpTargetPath,
                                               &NtTargetPathU,
                                               NULL,
@@ -286,104 +300,119 @@ DefineDosDeviceW(
                 return FALSE;
             }
         }
+
+        /* This target path will be the second arg */
         ArgumentCount = 2;
-        BufferSize += NtTargetPathU.Length;
+        BufferSize += NtTargetPathU.MaximumLength;
     }
 
-    RtlInitUnicodeString(&DeviceNameU,
-                         lpDeviceName);
-    RtlUpcaseUnicodeString(&DeviceUpcaseNameU,
-                           &DeviceNameU,
-                           TRUE);
-    BufferSize += DeviceUpcaseNameU.Length;
-
+    /* Allocate the capture buffer for our strings */
     CaptureBuffer = CsrAllocateCaptureBuffer(ArgumentCount,
                                              BufferSize);
-    if (!CaptureBuffer)
+    if (CaptureBuffer == NULL)
     {
+        if (!(dwFlags & DDD_RAW_TARGET_PATH))
+        {
+            RtlFreeUnicodeString(&NtTargetPathU);
+        }
+
         SetLastError(ERROR_NOT_ENOUGH_MEMORY);
-        Result = FALSE;
+        return FALSE;
     }
-    else
-    {
-        DefineDosDeviceRequest->Flags = dwFlags;
 
-        CsrCaptureMessageBuffer(CaptureBuffer,
-                                DeviceUpcaseNameU.Buffer,
-                                DeviceUpcaseNameU.Length,
-                                (PVOID*)&DefineDosDeviceRequest->DeviceName.Buffer);
+    /* Set the flags */
+    DefineDosDeviceRequest->Flags = dwFlags;
 
-        DefineDosDeviceRequest->DeviceName.Length =
-            DeviceUpcaseNameU.Length;
-        DefineDosDeviceRequest->DeviceName.MaximumLength =
-            DeviceUpcaseNameU.Length;
+    /* Allocate a buffer for the device name */
+    DefineDosDeviceRequest->DeviceName.MaximumLength = CsrAllocateMessagePointer(CaptureBuffer,
+                                                                                 DeviceNameU.MaximumLength,
+                                                                                 (PVOID*)&DefineDosDeviceRequest->DeviceName.Buffer);
+    /* And copy it while upcasing it */
+    RtlUpcaseUnicodeString(&DefineDosDeviceRequest->DeviceName, &DeviceNameU, FALSE);
 
-        if (NtTargetPathU.Buffer)
-        {
-            CsrCaptureMessageBuffer(CaptureBuffer,
-                                    NtTargetPathU.Buffer,
-                                    NtTargetPathU.Length,
-                                    (PVOID*)&DefineDosDeviceRequest->TargetPath.Buffer);
-        }
-        DefineDosDeviceRequest->TargetPath.Length =
-            NtTargetPathU.Length;
-        DefineDosDeviceRequest->TargetPath.MaximumLength =
-            NtTargetPathU.Length;
-
-        CsrClientCallServer((PCSR_API_MESSAGE)&ApiMessage,
-                            CaptureBuffer,
-                            CSR_CREATE_API_NUMBER(BASESRV_SERVERDLL_INDEX, BasepDefineDosDevice),
-                            sizeof(*DefineDosDeviceRequest));
-        CsrFreeCaptureBuffer(CaptureBuffer);
-
-        if (!NT_SUCCESS(ApiMessage.Status))
+    /* If we have a target path, copy it too, and free it if allocated */
+    if (NtTargetPathU.Length != 0)
+    {
+        DefineDosDeviceRequest->TargetPath.MaximumLength = CsrAllocateMessagePointer(CaptureBuffer,
+                                                                                     NtTargetPathU.MaximumLength,
+                                                                                     (PVOID*)&DefineDosDeviceRequest->TargetPath.Buffer);
+        RtlCopyUnicodeString(&DefineDosDeviceRequest->TargetPath, &NtTargetPathU);
+
+        if (!(dwFlags & DDD_RAW_TARGET_PATH))
         {
-            WARN("CsrClientCallServer() failed (Status %lx)\n", ApiMessage.Status);
-            BaseSetLastNTError(ApiMessage.Status);
-            Result = FALSE;
+            RtlFreeUnicodeString(&NtTargetPathU);
         }
-        else
+    }
+    /* Otherwise, null initialize the string */
+    else
+    {
+        RtlInitUnicodeString(&DefineDosDeviceRequest->TargetPath, NULL);
+    }
+
+    /* Finally,  call the server */
+    CsrClientCallServer((PCSR_API_MESSAGE)&ApiMessage,
+                        CaptureBuffer,
+                        CSR_CREATE_API_NUMBER(BASESRV_SERVERDLL_INDEX, BasepDefineDosDevice),
+                        sizeof(*DefineDosDeviceRequest));
+    CsrFreeCaptureBuffer(CaptureBuffer);
+
+    /* Return failure if any */
+    if (!NT_SUCCESS(ApiMessage.Status))
+    {
+        WARN("CsrClientCallServer() failed (Status %lx)\n", ApiMessage.Status);
+        BaseSetLastNTError(ApiMessage.Status);
+        return FALSE;
+    }
+
+    /* Here is the success path, we will always return true */
+
+    /* Should broadcast the event? Only do if not denied and if drive letter */
+    if (!(dwFlags & DDD_NO_BROADCAST_SYSTEM) &&
+        DeviceNameU.Length == 2 * sizeof(WCHAR) &&
+        DeviceNameU.Buffer[1] == L':')
+    {
+        /* Make sure letter is valid and there are no local device mappings */
+        Letter = RtlUpcaseUnicodeChar(DeviceNameU.Buffer[0]) - L'A';
+        if (Letter < 26 && !LUIDDeviceMapsEnabled)
         {
-            if (! (dwFlags & DDD_NO_BROADCAST_SYSTEM) &&
-                DeviceUpcaseNameU.Length == 2 * sizeof(WCHAR) &&
-                DeviceUpcaseNameU.Buffer[1] == L':' &&
-                ( (DeviceUpcaseNameU.Buffer[0] - L'A') < 26 ))
+            /* Rely on user32 for broadcasting */
+            hUser32 = LoadLibraryW(L"user32.dll");
+            if (hUser32 != 0)
             {
-                hUser32 = LoadLibraryA("user32.dll");
-                if (hUser32)
+                /* Get the function pointer */
+                BSM_ptr = (BSM_type)GetProcAddress(hUser32, "BroadcastSystemMessageW");
+                if (BSM_ptr)
                 {
-                    BSM_ptr = (BSM_type)
-                        GetProcAddress(hUser32, "BroadcastSystemMessageW");
-                    if (BSM_ptr)
-                    {
-                        dwRecipients = BSM_APPLICATIONS;
-                        dbcv.dbcv_size = sizeof(DEV_BROADCAST_VOLUME);
-                        dbcv.dbcv_devicetype = DBT_DEVTYP_VOLUME;
-                        dbcv.dbcv_reserved = 0;
-                        dbcv.dbcv_unitmask |=
-                            (1 << (DeviceUpcaseNameU.Buffer[0] - L'A'));
-                        dbcv.dbcv_flags = DBTF_NET;
-                        (void) BSM_ptr(BSF_SENDNOTIFYMESSAGE | BSF_FLUSHDISK,
-                                       &dwRecipients,
-                                       WM_DEVICECHANGE,
-                                       (WPARAM)DBT_DEVICEARRIVAL,
-                                       (LPARAM)&dbcv);
-                    }
-                    FreeLibrary(hUser32);
+                    /* Set our target */
+                    dwRecipients = BSM_APPLICATIONS;
+
+                    /* And initialize our structure */
+                    dbcv.dbcv_size = sizeof(DEV_BROADCAST_VOLUME);
+                    dbcv.dbcv_devicetype = DBT_DEVTYP_VOLUME;
+                    dbcv.dbcv_reserved = 0;
+
+                    /* Set the volume which had the event */
+                    dbcv.dbcv_unitmask = 1 << Letter;
+                    dbcv.dbcv_flags = DBTF_NET;
+
+                    /* And properly set the event (removal or arrival?) */
+                    wParam = (dwFlags & DDD_REMOVE_DEFINITION) ? DBT_DEVICEREMOVECOMPLETE : DBT_DEVICEARRIVAL;
+
+                    /* And broadcast! */
+                    BSM_ptr(BSF_SENDNOTIFYMESSAGE | BSF_FLUSHDISK,
+                            &dwRecipients,
+                            WM_DEVICECHANGE,
+                            wParam,
+                            (LPARAM)&dbcv);
                 }
+
+                /* We're done! */
+                FreeLibrary(hUser32);
             }
         }
     }
 
-    if (NtTargetPathU.Buffer &&
-        NtTargetPathU.Buffer != lpTargetPath)
-    {
-        RtlFreeHeap(RtlGetProcessHeap(),
-                    0,
-                    NtTargetPathU.Buffer);
-    }
-    RtlFreeUnicodeString(&DeviceUpcaseNameU);
-    return Result;
+    return TRUE;
 }