X-Git-Url: https://git.reactos.org/?p=reactos.git;a=blobdiff_plain;f=reactos%2Fntoskrnl%2Fio%2Fiomgr%2Fdriver.c;h=499e42f8d183b5b988e46d5bfeff2b9a36970042;hp=da913b1dd4b26a9081458990e6bdce5fd2c0e434;hb=552f1fd24d098e0e5fcd157d6fc62aa4508ec5a8;hpb=6ea50da3a222662dd214231352b6e2f42a880729 diff --git a/reactos/ntoskrnl/io/iomgr/driver.c b/reactos/ntoskrnl/io/iomgr/driver.c index da913b1dd4b..499e42f8d18 100644 --- a/reactos/ntoskrnl/io/iomgr/driver.c +++ b/reactos/ntoskrnl/io/iomgr/driver.c @@ -16,6 +16,8 @@ /* GLOBALS ********************************************************************/ +ERESOURCE IopDriverLoadResource; + LIST_ENTRY DriverReinitListHead; KSPIN_LOCK DriverReinitListLock; PLIST_ENTRY DriverReinitTailEntry; @@ -39,15 +41,16 @@ PLIST_ENTRY IopGroupTable; /* PRIVATE FUNCTIONS **********************************************************/ -NTSTATUS NTAPI +NTSTATUS +NTAPI IopInvalidDeviceRequest( - PDEVICE_OBJECT DeviceObject, - PIRP Irp) + PDEVICE_OBJECT DeviceObject, + PIRP Irp) { - Irp->IoStatus.Status = STATUS_INVALID_DEVICE_REQUEST; - Irp->IoStatus.Information = 0; - IoCompleteRequest(Irp, IO_NO_INCREMENT); - return STATUS_INVALID_DEVICE_REQUEST; + Irp->IoStatus.Status = STATUS_INVALID_DEVICE_REQUEST; + Irp->IoStatus.Information = 0; + IoCompleteRequest(Irp, IO_NO_INCREMENT); + return STATUS_INVALID_DEVICE_REQUEST; } VOID @@ -64,8 +67,7 @@ IopDeleteDriver(IN PVOID ObjectBody) ASSERT(!DriverObject->DeviceObject); /* Get the extension and loop them */ - DriverExtension = IoGetDrvObjExtension(DriverObject)-> - ClientDriverExtension; + DriverExtension = IoGetDrvObjExtension(DriverObject)->ClientDriverExtension; while (DriverExtension) { /* Get the next one */ @@ -90,71 +92,69 @@ IopDeleteDriver(IN PVOID ObjectBody) ExFreePool(DriverObject->DriverName.Buffer); } -#if 0 /* See a bit of hack in IopCreateDriver */ /* Check if it has a service key name */ if (DriverObject->DriverExtension->ServiceKeyName.Buffer) { /* Free it */ ExFreePool(DriverObject->DriverExtension->ServiceKeyName.Buffer); } -#endif } -NTSTATUS FASTCALL +NTSTATUS +FASTCALL IopGetDriverObject( - PDRIVER_OBJECT *DriverObject, - PUNICODE_STRING ServiceName, - BOOLEAN FileSystem) + PDRIVER_OBJECT *DriverObject, + PUNICODE_STRING ServiceName, + BOOLEAN FileSystem) { - PDRIVER_OBJECT Object; - WCHAR NameBuffer[MAX_PATH]; - UNICODE_STRING DriverName; - NTSTATUS Status; - - DPRINT("IopGetDriverObject(%p '%wZ' %x)\n", - DriverObject, ServiceName, FileSystem); - - *DriverObject = NULL; - - /* Create ModuleName string */ - if (ServiceName == NULL || ServiceName->Buffer == NULL) - /* We don't know which DriverObject we have to open */ - return STATUS_INVALID_PARAMETER_2; - - DriverName.Buffer = NameBuffer; - DriverName.Length = 0; - DriverName.MaximumLength = sizeof(NameBuffer); - - if (FileSystem == TRUE) - RtlAppendUnicodeToString(&DriverName, FILESYSTEM_ROOT_NAME); - else - RtlAppendUnicodeToString(&DriverName, DRIVER_ROOT_NAME); - RtlAppendUnicodeStringToString(&DriverName, ServiceName); - - DPRINT("Driver name: '%wZ'\n", &DriverName); - - /* Open driver object */ - Status = ObReferenceObjectByName( - &DriverName, - OBJ_OPENIF | OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE, /* Attributes */ - NULL, /* PassedAccessState */ - 0, /* DesiredAccess */ - IoDriverObjectType, - KernelMode, - NULL, /* ParseContext */ - (PVOID*)&Object); - - if (!NT_SUCCESS(Status)) - { - DPRINT("Failed to reference driver object, status=0x%08x\n", Status); - return Status; - } - - *DriverObject = Object; - - DPRINT("Driver Object: %p\n", Object); - - return STATUS_SUCCESS; + PDRIVER_OBJECT Object; + WCHAR NameBuffer[MAX_PATH]; + UNICODE_STRING DriverName; + NTSTATUS Status; + + DPRINT("IopGetDriverObject(%p '%wZ' %x)\n", + DriverObject, ServiceName, FileSystem); + + ASSERT(ExIsResourceAcquiredExclusiveLite(&IopDriverLoadResource)); + *DriverObject = NULL; + + /* Create ModuleName string */ + if (ServiceName == NULL || ServiceName->Buffer == NULL) + /* We don't know which DriverObject we have to open */ + return STATUS_INVALID_PARAMETER_2; + + DriverName.Buffer = NameBuffer; + DriverName.Length = 0; + DriverName.MaximumLength = sizeof(NameBuffer); + + if (FileSystem != FALSE) + RtlAppendUnicodeToString(&DriverName, FILESYSTEM_ROOT_NAME); + else + RtlAppendUnicodeToString(&DriverName, DRIVER_ROOT_NAME); + RtlAppendUnicodeStringToString(&DriverName, ServiceName); + + DPRINT("Driver name: '%wZ'\n", &DriverName); + + /* Open driver object */ + Status = ObReferenceObjectByName(&DriverName, + OBJ_OPENIF | OBJ_KERNEL_HANDLE | OBJ_CASE_INSENSITIVE, /* Attributes */ + NULL, /* PassedAccessState */ + 0, /* DesiredAccess */ + IoDriverObjectType, + KernelMode, + NULL, /* ParseContext */ + (PVOID*)&Object); + if (!NT_SUCCESS(Status)) + { + DPRINT("Failed to reference driver object, status=0x%08x\n", Status); + return Status; + } + + *DriverObject = Object; + + DPRINT("Driver Object: %p\n", Object); + + return STATUS_SUCCESS; } /* @@ -195,10 +195,8 @@ IopSuffixUnicodeString( * * Display 'Loading XXX...' message. */ - VOID FASTCALL -INIT_FUNCTION IopDisplayLoadingMessage(PUNICODE_STRING ServiceName) { CHAR TextBuffer[256]; @@ -236,55 +234,60 @@ IopDisplayLoadingMessage(PUNICODE_STRING ServiceName) * Remarks * The input image path isn't freed on error. */ - NTSTATUS FASTCALL IopNormalizeImagePath( - _Inout_ _When_(return>=0, _At_(ImagePath->Buffer, _Post_notnull_ __drv_allocatesMem(Mem))) - PUNICODE_STRING ImagePath, - _In_ PUNICODE_STRING ServiceName) + _Inout_ _When_(return>=0, _At_(ImagePath->Buffer, _Post_notnull_ __drv_allocatesMem(Mem))) + PUNICODE_STRING ImagePath, + _In_ PUNICODE_STRING ServiceName) { - UNICODE_STRING InputImagePath; - - DPRINT("Normalizing image path '%wZ' for service '%wZ'\n", ImagePath, ServiceName); - - RtlCopyMemory( - &InputImagePath, - ImagePath, - sizeof(UNICODE_STRING)); - - if (InputImagePath.Length == 0) - { - ImagePath->Length = 0; - ImagePath->MaximumLength = - (33 * sizeof(WCHAR)) + ServiceName->Length + sizeof(UNICODE_NULL); - ImagePath->Buffer = ExAllocatePool(NonPagedPool, ImagePath->MaximumLength); - if (ImagePath->Buffer == NULL) - return STATUS_NO_MEMORY; - - RtlAppendUnicodeToString(ImagePath, L"\\SystemRoot\\system32\\drivers\\"); - RtlAppendUnicodeStringToString(ImagePath, ServiceName); - RtlAppendUnicodeToString(ImagePath, L".sys"); - } else - if (InputImagePath.Buffer[0] != L'\\') - { - ImagePath->Length = 0; - ImagePath->MaximumLength = - 12 * sizeof(WCHAR) + InputImagePath.Length + sizeof(UNICODE_NULL); - ImagePath->Buffer = ExAllocatePool(NonPagedPool, ImagePath->MaximumLength); - if (ImagePath->Buffer == NULL) - return STATUS_NO_MEMORY; - - RtlAppendUnicodeToString(ImagePath, L"\\SystemRoot\\"); - RtlAppendUnicodeStringToString(ImagePath, &InputImagePath); - - /* Free caller's string */ - ExFreePoolWithTag(InputImagePath.Buffer, TAG_RTLREGISTRY); - } - - DPRINT("Normalized image path is '%wZ' for service '%wZ'\n", ImagePath, ServiceName); - - return STATUS_SUCCESS; + UNICODE_STRING SystemRootString = RTL_CONSTANT_STRING(L"\\SystemRoot\\"); + UNICODE_STRING DriversPathString = RTL_CONSTANT_STRING(L"\\SystemRoot\\System32\\drivers\\"); + UNICODE_STRING DotSysString = RTL_CONSTANT_STRING(L".sys"); + UNICODE_STRING InputImagePath; + + DPRINT("Normalizing image path '%wZ' for service '%wZ'\n", ImagePath, ServiceName); + + InputImagePath = *ImagePath; + if (InputImagePath.Length == 0) + { + ImagePath->Length = 0; + ImagePath->MaximumLength = DriversPathString.Length + + ServiceName->Length + + DotSysString.Length + + sizeof(UNICODE_NULL); + ImagePath->Buffer = ExAllocatePoolWithTag(NonPagedPool, + ImagePath->MaximumLength, + TAG_IO); + if (ImagePath->Buffer == NULL) + return STATUS_NO_MEMORY; + + RtlCopyUnicodeString(ImagePath, &DriversPathString); + RtlAppendUnicodeStringToString(ImagePath, ServiceName); + RtlAppendUnicodeStringToString(ImagePath, &DotSysString); + } + else if (InputImagePath.Buffer[0] != L'\\') + { + ImagePath->Length = 0; + ImagePath->MaximumLength = SystemRootString.Length + + InputImagePath.Length + + sizeof(UNICODE_NULL); + ImagePath->Buffer = ExAllocatePoolWithTag(NonPagedPool, + ImagePath->MaximumLength, + TAG_IO); + if (ImagePath->Buffer == NULL) + return STATUS_NO_MEMORY; + + RtlCopyUnicodeString(ImagePath, &SystemRootString); + RtlAppendUnicodeStringToString(ImagePath, &InputImagePath); + + /* Free caller's string */ + ExFreePoolWithTag(InputImagePath.Buffer, TAG_RTLREGISTRY); + } + + DPRINT("Normalized image path is '%wZ' for service '%wZ'\n", ImagePath, ServiceName); + + return STATUS_SUCCESS; } /* @@ -299,134 +302,129 @@ IopNormalizeImagePath( * Return Value * Status */ - -NTSTATUS FASTCALL +NTSTATUS +FASTCALL IopLoadServiceModule( - IN PUNICODE_STRING ServiceName, - OUT PLDR_DATA_TABLE_ENTRY *ModuleObject) + IN PUNICODE_STRING ServiceName, + OUT PLDR_DATA_TABLE_ENTRY *ModuleObject) { - RTL_QUERY_REGISTRY_TABLE QueryTable[3]; - ULONG ServiceStart; - UNICODE_STRING ServiceImagePath, CCSName; - NTSTATUS Status; - HANDLE CCSKey, ServiceKey; - PVOID BaseAddress; - - DPRINT("IopLoadServiceModule(%wZ, 0x%p)\n", ServiceName, ModuleObject); - - /* FIXME: This check may be removed once the bug is fixed */ - if (ServiceName->Buffer == NULL) - { - DPRINT1("If you see this, please report to Fireball or hpoussin!\n"); - return STATUS_UNSUCCESSFUL; - } - - if (ExpInTextModeSetup) - { - /* We have no registry, but luckily we know where all the drivers are */ - - /* ServiceStart < 4 is all that matters */ - ServiceStart = 0; - - /* IopNormalizeImagePath will do all of the work for us if we give it an empty string */ - ServiceImagePath.Length = ServiceImagePath.MaximumLength = 0; - ServiceImagePath.Buffer = NULL; - } - else - { - /* Open CurrentControlSet */ - RtlInitUnicodeString(&CCSName, - L"\\Registry\\Machine\\SYSTEM\\CurrentControlSet\\Services"); - Status = IopOpenRegistryKeyEx(&CCSKey, NULL, &CCSName, KEY_READ); - if (!NT_SUCCESS(Status)) - { - DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); - return Status; - } - - /* Open service key */ - Status = IopOpenRegistryKeyEx(&ServiceKey, CCSKey, ServiceName, KEY_READ); - if (!NT_SUCCESS(Status)) - { - DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); - ZwClose(CCSKey); - return Status; - } - - /* - * Get information about the service. - */ - - RtlZeroMemory(QueryTable, sizeof(QueryTable)); - - RtlInitUnicodeString(&ServiceImagePath, NULL); - - QueryTable[0].Name = L"Start"; - QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT; - QueryTable[0].EntryContext = &ServiceStart; - - QueryTable[1].Name = L"ImagePath"; - QueryTable[1].Flags = RTL_QUERY_REGISTRY_DIRECT; - QueryTable[1].EntryContext = &ServiceImagePath; - - Status = RtlQueryRegistryValues(RTL_REGISTRY_HANDLE, - (PWSTR)ServiceKey, QueryTable, NULL, NULL); - - ZwClose(ServiceKey); - ZwClose(CCSKey); - - if (!NT_SUCCESS(Status)) - { - DPRINT1("RtlQueryRegistryValues() failed (Status %x)\n", Status); - return Status; - } - } - - /* - * Normalize the image path for all later processing. - */ - - Status = IopNormalizeImagePath(&ServiceImagePath, ServiceName); - - if (!NT_SUCCESS(Status)) - { - DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status); - return Status; - } - - /* - * Case for disabled drivers - */ - - if (ServiceStart >= 4) - { - /* We can't load this */ - Status = STATUS_DRIVER_UNABLE_TO_LOAD; - } - else - { - DPRINT("Loading module from %wZ\n", &ServiceImagePath); - Status = MmLoadSystemImage(&ServiceImagePath, NULL, NULL, 0, (PVOID)ModuleObject, &BaseAddress); - if (NT_SUCCESS(Status)) - { - IopDisplayLoadingMessage(ServiceName); - } - } - - ExFreePool(ServiceImagePath.Buffer); - - /* - * Now check if the module was loaded successfully. - */ - - if (!NT_SUCCESS(Status)) - { - DPRINT("Module loading failed (Status %x)\n", Status); - } - - DPRINT("Module loading (Status %x)\n", Status); - - return Status; + RTL_QUERY_REGISTRY_TABLE QueryTable[3]; + ULONG ServiceStart; + UNICODE_STRING ServiceImagePath, CCSName; + NTSTATUS Status; + HANDLE CCSKey, ServiceKey; + PVOID BaseAddress; + + ASSERT(ExIsResourceAcquiredExclusiveLite(&IopDriverLoadResource)); + ASSERT(ServiceName->Length); + DPRINT("IopLoadServiceModule(%wZ, 0x%p)\n", ServiceName, ModuleObject); + + if (ExpInTextModeSetup) + { + /* We have no registry, but luckily we know where all the drivers are */ + + /* ServiceStart < 4 is all that matters */ + ServiceStart = 0; + + /* IopNormalizeImagePath will do all of the work for us if we give it an empty string */ + RtlInitEmptyUnicodeString(&ServiceImagePath, NULL, 0); + } + else + { + /* Open CurrentControlSet */ + RtlInitUnicodeString(&CCSName, + L"\\Registry\\Machine\\SYSTEM\\CurrentControlSet\\Services"); + Status = IopOpenRegistryKeyEx(&CCSKey, NULL, &CCSName, KEY_READ); + if (!NT_SUCCESS(Status)) + { + DPRINT1("IopOpenRegistryKeyEx() failed for '%wZ' with Status %08X\n", + &CCSName, Status); + return Status; + } + + /* Open service key */ + Status = IopOpenRegistryKeyEx(&ServiceKey, CCSKey, ServiceName, KEY_READ); + if (!NT_SUCCESS(Status)) + { + DPRINT1("IopOpenRegistryKeyEx() failed for '%wZ' with Status %08X\n", + ServiceName, Status); + ZwClose(CCSKey); + return Status; + } + + /* + * Get information about the service. + */ + RtlZeroMemory(QueryTable, sizeof(QueryTable)); + + RtlInitUnicodeString(&ServiceImagePath, NULL); + + QueryTable[0].Name = L"Start"; + QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT; + QueryTable[0].EntryContext = &ServiceStart; + + QueryTable[1].Name = L"ImagePath"; + QueryTable[1].Flags = RTL_QUERY_REGISTRY_DIRECT; + QueryTable[1].EntryContext = &ServiceImagePath; + + Status = RtlQueryRegistryValues(RTL_REGISTRY_HANDLE, + (PWSTR)ServiceKey, + QueryTable, + NULL, + NULL); + + ZwClose(ServiceKey); + ZwClose(CCSKey); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("RtlQueryRegistryValues() failed (Status %x)\n", Status); + return Status; + } + } + + /* + * Normalize the image path for all later processing. + */ + Status = IopNormalizeImagePath(&ServiceImagePath, ServiceName); + + if (!NT_SUCCESS(Status)) + { + DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status); + return Status; + } + + /* + * Case for disabled drivers + */ + if (ServiceStart >= 4) + { + /* We can't load this */ + Status = STATUS_DRIVER_UNABLE_TO_LOAD; + } + else + { + DPRINT("Loading module from %wZ\n", &ServiceImagePath); + Status = MmLoadSystemImage(&ServiceImagePath, NULL, NULL, 0, (PVOID)ModuleObject, &BaseAddress); + if (NT_SUCCESS(Status)) + { + IopDisplayLoadingMessage(ServiceName); + } + } + + ExFreePool(ServiceImagePath.Buffer); + + /* + * Now check if the module was loaded successfully. + */ + if (!NT_SUCCESS(Status)) + { + DPRINT("Module loading failed (Status %x)\n", Status); + } + + DPRINT("Module loading (Status %x)\n", Status); + + return Status; } VOID @@ -456,83 +454,83 @@ MmFreeDriverInitialization(IN PLDR_DATA_TABLE_ENTRY LdrEntry); * On successful return this contains the driver object representing * the loaded driver. */ - -NTSTATUS FASTCALL +NTSTATUS +FASTCALL IopInitializeDriverModule( - IN PDEVICE_NODE DeviceNode, - IN PLDR_DATA_TABLE_ENTRY ModuleObject, - IN PUNICODE_STRING ServiceName, - IN BOOLEAN FileSystemDriver, - OUT PDRIVER_OBJECT *DriverObject) + IN PDEVICE_NODE DeviceNode, + IN PLDR_DATA_TABLE_ENTRY ModuleObject, + IN PUNICODE_STRING ServiceName, + IN BOOLEAN FileSystemDriver, + OUT PDRIVER_OBJECT *DriverObject) { - const WCHAR ServicesKeyName[] = L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\"; - WCHAR NameBuffer[MAX_PATH]; - UNICODE_STRING DriverName; - UNICODE_STRING RegistryKey; - PDRIVER_INITIALIZE DriverEntry; - PDRIVER_OBJECT Driver; - NTSTATUS Status; - - DriverEntry = ModuleObject->EntryPoint; - - if (ServiceName != NULL && ServiceName->Length != 0) - { - RegistryKey.Length = 0; - RegistryKey.MaximumLength = sizeof(ServicesKeyName) + ServiceName->Length; - RegistryKey.Buffer = ExAllocatePool(PagedPool, RegistryKey.MaximumLength); - if (RegistryKey.Buffer == NULL) - { - return STATUS_INSUFFICIENT_RESOURCES; - } - RtlAppendUnicodeToString(&RegistryKey, ServicesKeyName); - RtlAppendUnicodeStringToString(&RegistryKey, ServiceName); - } - else - { - RtlInitUnicodeString(&RegistryKey, NULL); - } - - /* Create ModuleName string */ - if (ServiceName && ServiceName->Length > 0) - { - if (FileSystemDriver == TRUE) - wcscpy(NameBuffer, FILESYSTEM_ROOT_NAME); - else - wcscpy(NameBuffer, DRIVER_ROOT_NAME); - - RtlInitUnicodeString(&DriverName, NameBuffer); - DriverName.MaximumLength = sizeof(NameBuffer); - - RtlAppendUnicodeStringToString(&DriverName, ServiceName); - - DPRINT("Driver name: '%wZ'\n", &DriverName); - } - else - DriverName.Length = 0; - - Status = IopCreateDriver( - DriverName.Length > 0 ? &DriverName : NULL, - DriverEntry, - &RegistryKey, - ModuleObject, - &Driver); - RtlFreeUnicodeString(&RegistryKey); - - *DriverObject = Driver; - if (!NT_SUCCESS(Status)) - { - DPRINT("IopCreateDriver() failed (Status 0x%08lx)\n", Status); - return Status; - } - - MmFreeDriverInitialization((PLDR_DATA_TABLE_ENTRY)Driver->DriverSection); - - /* Set the driver as initialized */ - IopReadyDeviceObjects(Driver); - - if (PnpSystemInit) IopReinitializeDrivers(); - - return STATUS_SUCCESS; + const WCHAR ServicesKeyName[] = L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\"; + WCHAR NameBuffer[MAX_PATH]; + UNICODE_STRING DriverName; + UNICODE_STRING RegistryKey; + PDRIVER_INITIALIZE DriverEntry; + PDRIVER_OBJECT Driver; + NTSTATUS Status; + + DriverEntry = ModuleObject->EntryPoint; + + if (ServiceName != NULL && ServiceName->Length != 0) + { + RegistryKey.Length = 0; + RegistryKey.MaximumLength = sizeof(ServicesKeyName) + ServiceName->Length; + RegistryKey.Buffer = ExAllocatePool(PagedPool, RegistryKey.MaximumLength); + if (RegistryKey.Buffer == NULL) + { + return STATUS_INSUFFICIENT_RESOURCES; + } + RtlAppendUnicodeToString(&RegistryKey, ServicesKeyName); + RtlAppendUnicodeStringToString(&RegistryKey, ServiceName); + } + else + { + RtlInitUnicodeString(&RegistryKey, NULL); + } + + /* Create ModuleName string */ + if (ServiceName && ServiceName->Length > 0) + { + if (FileSystemDriver != FALSE) + wcscpy(NameBuffer, FILESYSTEM_ROOT_NAME); + else + wcscpy(NameBuffer, DRIVER_ROOT_NAME); + + RtlInitUnicodeString(&DriverName, NameBuffer); + DriverName.MaximumLength = sizeof(NameBuffer); + + RtlAppendUnicodeStringToString(&DriverName, ServiceName); + + DPRINT("Driver name: '%wZ'\n", &DriverName); + } + else + DriverName.Length = 0; + + Status = IopCreateDriver(DriverName.Length > 0 ? &DriverName : NULL, + DriverEntry, + &RegistryKey, + ServiceName, + ModuleObject, + &Driver); + RtlFreeUnicodeString(&RegistryKey); + + *DriverObject = Driver; + if (!NT_SUCCESS(Status)) + { + DPRINT("IopCreateDriver() failed (Status 0x%08lx)\n", Status); + return Status; + } + + MmFreeDriverInitialization((PLDR_DATA_TABLE_ENTRY)Driver->DriverSection); + + /* Set the driver as initialized */ + IopReadyDeviceObjects(Driver); + + if (PnpSystemInit) IopReinitializeDrivers(); + + return STATUS_SUCCESS; } /* @@ -540,57 +538,80 @@ IopInitializeDriverModule( * * Internal routine used by IopAttachFilterDrivers. */ - -NTSTATUS NTAPI +NTSTATUS +NTAPI IopAttachFilterDriversCallback( - PWSTR ValueName, - ULONG ValueType, - PVOID ValueData, - ULONG ValueLength, - PVOID Context, - PVOID EntryContext) + PWSTR ValueName, + ULONG ValueType, + PVOID ValueData, + ULONG ValueLength, + PVOID Context, + PVOID EntryContext) { - PDEVICE_NODE DeviceNode = Context; - UNICODE_STRING ServiceName; - PWCHAR Filters; - PLDR_DATA_TABLE_ENTRY ModuleObject; - PDRIVER_OBJECT DriverObject; - NTSTATUS Status; - - for (Filters = ValueData; - ((ULONG_PTR)Filters - (ULONG_PTR)ValueData) < ValueLength && - *Filters != 0; - Filters += (ServiceName.Length / sizeof(WCHAR)) + 1) - { - DPRINT("Filter Driver: %S (%wZ)\n", Filters, &DeviceNode->InstancePath); - - ServiceName.Buffer = Filters; - ServiceName.MaximumLength = - ServiceName.Length = (USHORT)wcslen(Filters) * sizeof(WCHAR); - - Status = IopGetDriverObject(&DriverObject, - &ServiceName, - FALSE); - if (!NT_SUCCESS(Status)) - { - /* Load and initialize the filter driver */ - Status = IopLoadServiceModule(&ServiceName, &ModuleObject); - if (!NT_SUCCESS(Status)) - continue; - - Status = IopInitializeDriverModule(DeviceNode, ModuleObject, &ServiceName, - FALSE, &DriverObject); - if (!NT_SUCCESS(Status)) - continue; - } - - Status = IopInitializeDevice(DeviceNode, DriverObject); - - /* Remove extra reference */ - ObDereferenceObject(DriverObject); - } - - return STATUS_SUCCESS; + PDEVICE_NODE DeviceNode = Context; + UNICODE_STRING ServiceName; + PWCHAR Filters; + PLDR_DATA_TABLE_ENTRY ModuleObject; + PDRIVER_OBJECT DriverObject; + NTSTATUS Status; + + /* No filter value present */ + if (ValueType == REG_NONE) + return STATUS_SUCCESS; + + for (Filters = ValueData; + ((ULONG_PTR)Filters - (ULONG_PTR)ValueData) < ValueLength && + *Filters != 0; + Filters += (ServiceName.Length / sizeof(WCHAR)) + 1) + { + DPRINT("Filter Driver: %S (%wZ)\n", Filters, &DeviceNode->InstancePath); + + ServiceName.Buffer = Filters; + ServiceName.MaximumLength = + ServiceName.Length = (USHORT)wcslen(Filters) * sizeof(WCHAR); + + KeEnterCriticalRegion(); + ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE); + Status = IopGetDriverObject(&DriverObject, + &ServiceName, + FALSE); + if (!NT_SUCCESS(Status)) + { + /* Load and initialize the filter driver */ + Status = IopLoadServiceModule(&ServiceName, &ModuleObject); + if (!NT_SUCCESS(Status)) + { + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + return Status; + } + + Status = IopInitializeDriverModule(DeviceNode, + ModuleObject, + &ServiceName, + FALSE, + &DriverObject); + if (!NT_SUCCESS(Status)) + { + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + return Status; + } + } + + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + + Status = IopInitializeDevice(DeviceNode, DriverObject); + + /* Remove extra reference */ + ObDereferenceObject(DriverObject); + + if (!NT_SUCCESS(Status)) + return Status; + } + + return STATUS_SUCCESS; } /* @@ -603,123 +624,148 @@ IopAttachFilterDriversCallback( * Set to TRUE for loading lower level filters or FALSE for upper * level filters. */ - -NTSTATUS FASTCALL +NTSTATUS +FASTCALL IopAttachFilterDrivers( - PDEVICE_NODE DeviceNode, - BOOLEAN Lower) + PDEVICE_NODE DeviceNode, + BOOLEAN Lower) { - RTL_QUERY_REGISTRY_TABLE QueryTable[2] = { { NULL, 0, NULL, NULL, 0, NULL, 0 }, }; - UNICODE_STRING Class; - WCHAR ClassBuffer[40]; - UNICODE_STRING EnumRoot = RTL_CONSTANT_STRING(ENUM_ROOT); - HANDLE EnumRootKey, SubKey; - NTSTATUS Status; - - /* Open enumeration root key */ - Status = IopOpenRegistryKeyEx(&EnumRootKey, NULL, - &EnumRoot, KEY_READ); - if (!NT_SUCCESS(Status)) - { - DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); - return Status; - } - - /* Open subkey */ - Status = IopOpenRegistryKeyEx(&SubKey, EnumRootKey, - &DeviceNode->InstancePath, KEY_READ); - if (!NT_SUCCESS(Status)) - { - DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); - ZwClose(EnumRootKey); - return Status; - } - - /* - * First load the device filters - */ - QueryTable[0].QueryRoutine = IopAttachFilterDriversCallback; - if (Lower) - QueryTable[0].Name = L"LowerFilters"; - else - QueryTable[0].Name = L"UpperFilters"; - QueryTable[0].Flags = RTL_QUERY_REGISTRY_REQUIRED; - - RtlQueryRegistryValues( - RTL_REGISTRY_HANDLE, - (PWSTR)SubKey, - QueryTable, - DeviceNode, - NULL); - - /* - * Now get the class GUID - */ - Class.Length = 0; - Class.MaximumLength = 40 * sizeof(WCHAR); - Class.Buffer = ClassBuffer; - QueryTable[0].QueryRoutine = NULL; - QueryTable[0].Name = L"ClassGUID"; - QueryTable[0].EntryContext = &Class; - QueryTable[0].Flags = RTL_QUERY_REGISTRY_REQUIRED | RTL_QUERY_REGISTRY_DIRECT; - - Status = RtlQueryRegistryValues( - RTL_REGISTRY_HANDLE, - (PWSTR)SubKey, - QueryTable, - DeviceNode, - NULL); - - /* Close handles */ - ZwClose(SubKey); - ZwClose(EnumRootKey); - - /* - * Load the class filter driver - */ - if (NT_SUCCESS(Status)) - { - UNICODE_STRING ControlClass = RTL_CONSTANT_STRING(L"\\Registry\\Machine\\System\\CurrentControlSet\\Control\\Class"); - - Status = IopOpenRegistryKeyEx(&EnumRootKey, NULL, - &ControlClass, KEY_READ); - if (!NT_SUCCESS(Status)) - { - DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); - return Status; - } - - /* Open subkey */ - Status = IopOpenRegistryKeyEx(&SubKey, EnumRootKey, - &Class, KEY_READ); - if (!NT_SUCCESS(Status)) - { - DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); - ZwClose(EnumRootKey); - return Status; - } - - QueryTable[0].QueryRoutine = IopAttachFilterDriversCallback; - if (Lower) - QueryTable[0].Name = L"LowerFilters"; - else - QueryTable[0].Name = L"UpperFilters"; - QueryTable[0].EntryContext = NULL; - QueryTable[0].Flags = RTL_QUERY_REGISTRY_REQUIRED; - - RtlQueryRegistryValues( - RTL_REGISTRY_HANDLE, - (PWSTR)SubKey, - QueryTable, - DeviceNode, - NULL); - - /* Clean up */ - ZwClose(SubKey); - ZwClose(EnumRootKey); - } - - return STATUS_SUCCESS; + RTL_QUERY_REGISTRY_TABLE QueryTable[2] = { { NULL, 0, NULL, NULL, 0, NULL, 0 }, }; + UNICODE_STRING Class; + WCHAR ClassBuffer[40]; + UNICODE_STRING EnumRoot = RTL_CONSTANT_STRING(ENUM_ROOT); + HANDLE EnumRootKey, SubKey; + NTSTATUS Status; + + /* Open enumeration root key */ + Status = IopOpenRegistryKeyEx(&EnumRootKey, + NULL, + &EnumRoot, + KEY_READ); + if (!NT_SUCCESS(Status)) + { + DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); + return Status; + } + + /* Open subkey */ + Status = IopOpenRegistryKeyEx(&SubKey, + EnumRootKey, + &DeviceNode->InstancePath, + KEY_READ); + if (!NT_SUCCESS(Status)) + { + DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); + ZwClose(EnumRootKey); + return Status; + } + + /* + * First load the device filters + */ + QueryTable[0].QueryRoutine = IopAttachFilterDriversCallback; + if (Lower) + QueryTable[0].Name = L"LowerFilters"; + else + QueryTable[0].Name = L"UpperFilters"; + QueryTable[0].Flags = 0; + QueryTable[0].DefaultType = REG_NONE; + + Status = RtlQueryRegistryValues(RTL_REGISTRY_HANDLE, + (PWSTR)SubKey, + QueryTable, + DeviceNode, + NULL); + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to load device %s filters: %08X\n", + Lower ? "lower" : "upper", Status); + ZwClose(SubKey); + ZwClose(EnumRootKey); + return Status; + } + + /* + * Now get the class GUID + */ + Class.Length = 0; + Class.MaximumLength = 40 * sizeof(WCHAR); + Class.Buffer = ClassBuffer; + QueryTable[0].QueryRoutine = NULL; + QueryTable[0].Name = L"ClassGUID"; + QueryTable[0].EntryContext = &Class; + QueryTable[0].Flags = RTL_QUERY_REGISTRY_REQUIRED | RTL_QUERY_REGISTRY_DIRECT; + + Status = RtlQueryRegistryValues(RTL_REGISTRY_HANDLE, + (PWSTR)SubKey, + QueryTable, + DeviceNode, + NULL); + + /* Close handles */ + ZwClose(SubKey); + ZwClose(EnumRootKey); + + /* + * Load the class filter driver + */ + if (NT_SUCCESS(Status)) + { + UNICODE_STRING ControlClass = RTL_CONSTANT_STRING(L"\\Registry\\Machine\\System\\CurrentControlSet\\Control\\Class"); + + Status = IopOpenRegistryKeyEx(&EnumRootKey, + NULL, + &ControlClass, + KEY_READ); + if (!NT_SUCCESS(Status)) + { + DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); + return Status; + } + + /* Open subkey */ + Status = IopOpenRegistryKeyEx(&SubKey, + EnumRootKey, + &Class, + KEY_READ); + if (!NT_SUCCESS(Status)) + { + /* It's okay if there's no class key */ + DPRINT1("ZwOpenKey() failed with Status %08X\n", Status); + ZwClose(EnumRootKey); + return STATUS_SUCCESS; + } + + QueryTable[0].QueryRoutine = IopAttachFilterDriversCallback; + if (Lower) + QueryTable[0].Name = L"LowerFilters"; + else + QueryTable[0].Name = L"UpperFilters"; + QueryTable[0].EntryContext = NULL; + QueryTable[0].Flags = 0; + QueryTable[0].DefaultType = REG_NONE; + + Status = RtlQueryRegistryValues(RTL_REGISTRY_HANDLE, + (PWSTR)SubKey, + QueryTable, + DeviceNode, + NULL); + + /* Clean up */ + ZwClose(SubKey); + ZwClose(EnumRootKey); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to load class %s filters: %08X\n", + Lower ? "lower" : "upper", Status); + ZwClose(SubKey); + ZwClose(EnumRootKey); + return Status; + } + } + + return STATUS_SUCCESS; } NTSTATUS @@ -814,7 +860,6 @@ LdrProcessDriverModule(PLDR_DATA_TABLE_ENTRY LdrEntry, * * Initialize a driver that is already loaded in memory. */ - NTSTATUS NTAPI INIT_FUNCTION @@ -829,86 +874,97 @@ IopInitializeBuiltinDriver(IN PLDR_DATA_TABLE_ENTRY BootLdrEntry) PLDR_DATA_TABLE_ENTRY LdrEntry; PLIST_ENTRY NextEntry; UNICODE_STRING ServiceName; + BOOLEAN Success; - /* - * Display 'Loading XXX...' message - */ - IopDisplayLoadingMessage(ModuleName); - InbvIndicateProgress(); - - /* - * Generate filename without path (not needed by freeldr) - */ - FileNameWithoutPath = wcsrchr(ModuleName->Buffer, L'\\'); - if (FileNameWithoutPath == NULL) - { - FileNameWithoutPath = ModuleName->Buffer; - } - else - { - FileNameWithoutPath++; - } - - /* - * Strip the file extension from ServiceName - */ - RtlCreateUnicodeString(&ServiceName, FileNameWithoutPath); - FileExtension = wcsrchr(ServiceName.Buffer, '.'); - if (FileExtension != NULL) - { - ServiceName.Length -= (USHORT)wcslen(FileExtension) * sizeof(WCHAR); - FileExtension[0] = 0; - } - - /* - * Determine the right device object - */ - /* Use IopRootDeviceNode for now */ - Status = IopCreateDeviceNode(IopRootDeviceNode, NULL, &ServiceName, &DeviceNode); - if (!NT_SUCCESS(Status)) - { - DPRINT1("Driver '%wZ' load failed, status (%x)\n", ModuleName, Status); - return(Status); - } - - /* Lookup the new Ldr entry in PsLoadedModuleList */ - NextEntry = PsLoadedModuleList.Flink; - while (NextEntry != &PsLoadedModuleList) - { - LdrEntry = CONTAINING_RECORD(NextEntry, - LDR_DATA_TABLE_ENTRY, - InLoadOrderLinks); - if (RtlEqualUnicodeString(ModuleName, &LdrEntry->BaseDllName, TRUE)) - { - break; - } - - NextEntry = NextEntry->Flink; - } - NT_ASSERT(NextEntry != &PsLoadedModuleList); - - /* - * Initialize the driver - */ - Status = IopInitializeDriverModule(DeviceNode, LdrEntry, - &DeviceNode->ServiceName, FALSE, &DriverObject); - - if (!NT_SUCCESS(Status)) - { - IopFreeDeviceNode(DeviceNode); - return Status; - } - - Status = IopInitializeDevice(DeviceNode, DriverObject); - if (NT_SUCCESS(Status)) - { - Status = IopStartDevice(DeviceNode); - } - - /* Remove extra reference from IopInitializeDriverModule */ - ObDereferenceObject(DriverObject); - - return Status; + /* + * Display 'Loading XXX...' message + */ + IopDisplayLoadingMessage(ModuleName); + InbvIndicateProgress(); + + /* + * Generate filename without path (not needed by freeldr) + */ + FileNameWithoutPath = wcsrchr(ModuleName->Buffer, L'\\'); + if (FileNameWithoutPath == NULL) + { + FileNameWithoutPath = ModuleName->Buffer; + } + else + { + FileNameWithoutPath++; + } + + /* + * Strip the file extension from ServiceName + */ + Success = RtlCreateUnicodeString(&ServiceName, FileNameWithoutPath); + if (!Success) + { + return STATUS_INSUFFICIENT_RESOURCES; + } + + FileExtension = wcsrchr(ServiceName.Buffer, '.'); + if (FileExtension != NULL) + { + ServiceName.Length -= (USHORT)wcslen(FileExtension) * sizeof(WCHAR); + FileExtension[0] = 0; + } + + /* + * Determine the right device object + */ + /* Use IopRootDeviceNode for now */ + Status = IopCreateDeviceNode(IopRootDeviceNode, + NULL, + &ServiceName, + &DeviceNode); + if (!NT_SUCCESS(Status)) + { + DPRINT1("Driver '%wZ' load failed, status (%x)\n", ModuleName, Status); + return(Status); + } + + /* Lookup the new Ldr entry in PsLoadedModuleList */ + NextEntry = PsLoadedModuleList.Flink; + while (NextEntry != &PsLoadedModuleList) + { + LdrEntry = CONTAINING_RECORD(NextEntry, + LDR_DATA_TABLE_ENTRY, + InLoadOrderLinks); + if (RtlEqualUnicodeString(ModuleName, &LdrEntry->BaseDllName, TRUE)) + { + break; + } + + NextEntry = NextEntry->Flink; + } + ASSERT(NextEntry != &PsLoadedModuleList); + + /* + * Initialize the driver + */ + Status = IopInitializeDriverModule(DeviceNode, + LdrEntry, + &DeviceNode->ServiceName, + FALSE, + &DriverObject); + + if (!NT_SUCCESS(Status)) + { + return Status; + } + + Status = IopInitializeDevice(DeviceNode, DriverObject); + if (NT_SUCCESS(Status)) + { + Status = IopStartDevice(DeviceNode); + } + + /* Remove extra reference from IopInitializeDriverModule */ + ObDereferenceObject(DriverObject); + + return Status; } /* @@ -959,7 +1015,6 @@ IopInitializeBootDrivers(VOID) if (!NT_SUCCESS(Status)) { /* Fail */ - IopFreeDeviceNode(DeviceNode); return; } @@ -968,7 +1023,6 @@ IopInitializeBootDrivers(VOID) if (!NT_SUCCESS(Status)) { /* Fail */ - IopFreeDeviceNode(DeviceNode); ObDereferenceObject(DriverObject); return; } @@ -978,7 +1032,6 @@ IopInitializeBootDrivers(VOID) if (!NT_SUCCESS(Status)) { /* Fail */ - IopFreeDeviceNode(DeviceNode); ObDereferenceObject(DriverObject); return; } @@ -1173,204 +1226,180 @@ IopInitializeSystemDrivers(VOID) NTSTATUS NTAPI IopUnloadDriver(PUNICODE_STRING DriverServiceName, BOOLEAN UnloadPnpDrivers) { - RTL_QUERY_REGISTRY_TABLE QueryTable[2]; - UNICODE_STRING ImagePath; - UNICODE_STRING ServiceName; - UNICODE_STRING ObjectName; - PDRIVER_OBJECT DriverObject; - PDEVICE_OBJECT DeviceObject; - PEXTENDED_DEVOBJ_EXTENSION DeviceExtension; - LOAD_UNLOAD_PARAMS LoadParams; - NTSTATUS Status; - LPWSTR Start; - BOOLEAN SafeToUnload = TRUE; - - DPRINT("IopUnloadDriver('%wZ', %u)\n", DriverServiceName, UnloadPnpDrivers); - - PAGED_CODE(); - - /* - * Get the service name from the registry key name - */ - - Start = wcsrchr(DriverServiceName->Buffer, L'\\'); - if (Start == NULL) - Start = DriverServiceName->Buffer; - else - Start++; - - RtlInitUnicodeString(&ServiceName, Start); - - /* - * Construct the driver object name - */ - - ObjectName.Length = ((USHORT)wcslen(Start) + 8) * sizeof(WCHAR); - ObjectName.MaximumLength = ObjectName.Length + sizeof(WCHAR); - ObjectName.Buffer = ExAllocatePool(PagedPool, ObjectName.MaximumLength); - if (!ObjectName.Buffer) return STATUS_INSUFFICIENT_RESOURCES; - wcscpy(ObjectName.Buffer, DRIVER_ROOT_NAME); - memcpy(ObjectName.Buffer + 8, Start, ObjectName.Length - 8 * sizeof(WCHAR)); - ObjectName.Buffer[ObjectName.Length/sizeof(WCHAR)] = 0; - - /* - * Find the driver object - */ - Status = ObReferenceObjectByName(&ObjectName, - 0, - 0, - 0, - IoDriverObjectType, - KernelMode, - 0, - (PVOID*)&DriverObject); - - if (!NT_SUCCESS(Status)) - { - DPRINT1("Can't locate driver object for %wZ\n", &ObjectName); - ExFreePool(ObjectName.Buffer); - return Status; - } - - /* - * Free the buffer for driver object name - */ - ExFreePool(ObjectName.Buffer); - - /* Check that driver is not already unloading */ - if (DriverObject->Flags & DRVO_UNLOAD_INVOKED) - { - DPRINT1("Driver deletion pending\n"); - ObDereferenceObject(DriverObject); - return STATUS_DELETE_PENDING; - } - - /* - * Get path of service... - */ - - RtlZeroMemory(QueryTable, sizeof(QueryTable)); - - RtlInitUnicodeString(&ImagePath, NULL); - - QueryTable[0].Name = L"ImagePath"; - QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT; - QueryTable[0].EntryContext = &ImagePath; - - Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE, - DriverServiceName->Buffer, QueryTable, NULL, NULL); - - if (!NT_SUCCESS(Status)) - { - DPRINT1("RtlQueryRegistryValues() failed (Status %x)\n", Status); - ObDereferenceObject(DriverObject); - return Status; - } - - /* - * Normalize the image path for all later processing. - */ - - Status = IopNormalizeImagePath(&ImagePath, &ServiceName); - - if (!NT_SUCCESS(Status)) - { - DPRINT1("IopNormalizeImagePath() failed (Status %x)\n", Status); - ObDereferenceObject(DriverObject); - return Status; - } - - /* - * Free the service path - */ - - ExFreePool(ImagePath.Buffer); - - /* - * Unload the module and release the references to the device object - */ + RTL_QUERY_REGISTRY_TABLE QueryTable[2]; + UNICODE_STRING ImagePath; + UNICODE_STRING ServiceName; + UNICODE_STRING ObjectName; + PDRIVER_OBJECT DriverObject; + PDEVICE_OBJECT DeviceObject; + PEXTENDED_DEVOBJ_EXTENSION DeviceExtension; + NTSTATUS Status; + LPWSTR Start; + BOOLEAN SafeToUnload = TRUE; + + DPRINT("IopUnloadDriver('%wZ', %u)\n", DriverServiceName, UnloadPnpDrivers); + + PAGED_CODE(); + + /* + * Get the service name from the registry key name + */ + + Start = wcsrchr(DriverServiceName->Buffer, L'\\'); + if (Start == NULL) + Start = DriverServiceName->Buffer; + else + Start++; + + RtlInitUnicodeString(&ServiceName, Start); + + /* + * Construct the driver object name + */ + + ObjectName.Length = ((USHORT)wcslen(Start) + 8) * sizeof(WCHAR); + ObjectName.MaximumLength = ObjectName.Length + sizeof(WCHAR); + ObjectName.Buffer = ExAllocatePool(PagedPool, ObjectName.MaximumLength); + if (!ObjectName.Buffer) return STATUS_INSUFFICIENT_RESOURCES; + wcscpy(ObjectName.Buffer, DRIVER_ROOT_NAME); + memcpy(ObjectName.Buffer + 8, Start, ObjectName.Length - 8 * sizeof(WCHAR)); + ObjectName.Buffer[ObjectName.Length/sizeof(WCHAR)] = 0; + + /* + * Find the driver object + */ + Status = ObReferenceObjectByName(&ObjectName, + 0, + 0, + 0, + IoDriverObjectType, + KernelMode, + 0, + (PVOID*)&DriverObject); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("Can't locate driver object for %wZ\n", &ObjectName); + ExFreePool(ObjectName.Buffer); + return Status; + } + + /* + * Free the buffer for driver object name + */ + ExFreePool(ObjectName.Buffer); + + /* Check that driver is not already unloading */ + if (DriverObject->Flags & DRVO_UNLOAD_INVOKED) + { + DPRINT1("Driver deletion pending\n"); + ObDereferenceObject(DriverObject); + return STATUS_DELETE_PENDING; + } + + /* + * Get path of service... + */ + RtlZeroMemory(QueryTable, sizeof(QueryTable)); + + RtlInitUnicodeString(&ImagePath, NULL); + + QueryTable[0].Name = L"ImagePath"; + QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT; + QueryTable[0].EntryContext = &ImagePath; + + Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE, + DriverServiceName->Buffer, + QueryTable, + NULL, + NULL); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("RtlQueryRegistryValues() failed (Status %x)\n", Status); + ObDereferenceObject(DriverObject); + return Status; + } + + /* + * Normalize the image path for all later processing. + */ + Status = IopNormalizeImagePath(&ImagePath, &ServiceName); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("IopNormalizeImagePath() failed (Status %x)\n", Status); + ObDereferenceObject(DriverObject); + return Status; + } + + /* + * Free the service path + */ + ExFreePool(ImagePath.Buffer); + + /* + * Unload the module and release the references to the device object + */ /* Call the load/unload routine, depending on current process */ - if (DriverObject->DriverUnload && DriverObject->DriverSection && - (UnloadPnpDrivers || (DriverObject->Flags & DRVO_LEGACY_DRIVER))) - { - /* Loop through each device object of the driver - and set DOE_UNLOAD_PENDING flag */ - DeviceObject = DriverObject->DeviceObject; - while (DeviceObject) - { - /* Set the unload pending flag for the device */ - DeviceExtension = IoGetDevObjExtension(DeviceObject); - DeviceExtension->ExtensionFlags |= DOE_UNLOAD_PENDING; - - /* Make sure there are no attached devices or no reference counts */ - if ((DeviceObject->ReferenceCount) || (DeviceObject->AttachedDevice)) - { - /* Not safe to unload */ - DPRINT1("Drivers device object is referenced or has attached devices\n"); - - SafeToUnload = FALSE; - } - - DeviceObject = DeviceObject->NextDevice; - } - - /* If not safe to unload, then return success */ - if (!SafeToUnload) - { - ObDereferenceObject(DriverObject); - return STATUS_SUCCESS; - } - - DPRINT1("Unloading driver '%wZ' (manual)\n", &DriverObject->DriverName); - - /* Set the unload invoked flag */ - DriverObject->Flags |= DRVO_UNLOAD_INVOKED; - - if (PsGetCurrentProcess() == PsInitialSystemProcess) - { - /* Just call right away */ - (*DriverObject->DriverUnload)(DriverObject); - } - else - { - /* Load/Unload must be called from system process */ - - /* Prepare parameters block */ - LoadParams.DriverObject = DriverObject; - KeInitializeEvent(&LoadParams.Event, NotificationEvent, FALSE); - - ExInitializeWorkItem(&LoadParams.WorkItem, - (PWORKER_THREAD_ROUTINE)IopLoadUnloadDriver, - (PVOID)&LoadParams); - - /* Queue it */ - ExQueueWorkItem(&LoadParams.WorkItem, DelayedWorkQueue); - - /* And wait when it completes */ - KeWaitForSingleObject(&LoadParams.Event, UserRequest, KernelMode, - FALSE, NULL); - } - - /* Mark the driver object temporary, so it could be deleted later */ - ObMakeTemporaryObject(DriverObject); - - /* Dereference it 2 times */ - ObDereferenceObject(DriverObject); - ObDereferenceObject(DriverObject); - - return STATUS_SUCCESS; - } - else - { - DPRINT1("No DriverUnload function! '%wZ' will not be unloaded!\n", &DriverObject->DriverName); - - /* Dereference one time (refd inside this function) */ - ObDereferenceObject(DriverObject); - - /* Return unloading failure */ - return STATUS_INVALID_DEVICE_REQUEST; - } + if (DriverObject->DriverUnload && DriverObject->DriverSection && + (UnloadPnpDrivers || (DriverObject->Flags & DRVO_LEGACY_DRIVER))) + { + /* Loop through each device object of the driver + and set DOE_UNLOAD_PENDING flag */ + DeviceObject = DriverObject->DeviceObject; + while (DeviceObject) + { + /* Set the unload pending flag for the device */ + DeviceExtension = IoGetDevObjExtension(DeviceObject); + DeviceExtension->ExtensionFlags |= DOE_UNLOAD_PENDING; + + /* Make sure there are no attached devices or no reference counts */ + if ((DeviceObject->ReferenceCount) || (DeviceObject->AttachedDevice)) + { + /* Not safe to unload */ + DPRINT1("Drivers device object is referenced or has attached devices\n"); + + SafeToUnload = FALSE; + } + + DeviceObject = DeviceObject->NextDevice; + } + + /* If not safe to unload, then return success */ + if (!SafeToUnload) + { + ObDereferenceObject(DriverObject); + return STATUS_SUCCESS; + } + + DPRINT1("Unloading driver '%wZ' (manual)\n", &DriverObject->DriverName); + + /* Set the unload invoked flag and call the unload routine */ + DriverObject->Flags |= DRVO_UNLOAD_INVOKED; + Status = IopLoadUnloadDriver(NULL, &DriverObject); + ASSERT(Status == STATUS_SUCCESS); + + /* Mark the driver object temporary, so it could be deleted later */ + ObMakeTemporaryObject(DriverObject); + + /* Dereference it 2 times */ + ObDereferenceObject(DriverObject); + ObDereferenceObject(DriverObject); + + return Status; + } + else + { + DPRINT1("No DriverUnload function! '%wZ' will not be unloaded!\n", &DriverObject->DriverName); + + /* Dereference one time (refd inside this function) */ + ObDereferenceObject(DriverObject); + + /* Return unloading failure */ + return STATUS_INVALID_DEVICE_REQUEST; + } } VOID @@ -1450,6 +1479,7 @@ NTAPI IopCreateDriver(IN PUNICODE_STRING DriverName OPTIONAL, IN PDRIVER_INITIALIZE InitializationFunction, IN PUNICODE_STRING RegistryPath, + IN PCUNICODE_STRING ServiceName, PLDR_DATA_TABLE_ENTRY ModuleObject, OUT PDRIVER_OBJECT *pDriverObject) { @@ -1471,7 +1501,7 @@ try_again: /* Create a random name and set up the string*/ NameLength = (USHORT)swprintf(NameBuffer, DRIVER_ROOT_NAME L"%08u", - KeTickCount); + KeTickCount.LowPart); LocalDriverName.Length = NameLength * sizeof(WCHAR); LocalDriverName.MaximumLength = LocalDriverName.Length + sizeof(UNICODE_NULL); LocalDriverName.Buffer = NameBuffer; @@ -1486,7 +1516,7 @@ try_again: ObjectSize = sizeof(DRIVER_OBJECT) + sizeof(EXTENDED_DRIVER_EXTENSION); InitializeObjectAttributes(&ObjectAttributes, &LocalDriverName, - OBJ_PERMANENT | OBJ_CASE_INSENSITIVE, + OBJ_PERMANENT | OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE, NULL, NULL); @@ -1521,9 +1551,9 @@ try_again: } /* Set up the service key name buffer */ - ServiceKeyName.Buffer = ExAllocatePoolWithTag(PagedPool, - LocalDriverName.Length + - sizeof(WCHAR), + ServiceKeyName.MaximumLength = ServiceName->Length + sizeof(UNICODE_NULL); + ServiceKeyName.Buffer = ExAllocatePoolWithTag(NonPagedPool, + ServiceKeyName.MaximumLength, TAG_IO); if (!ServiceKeyName.Buffer) { @@ -1533,21 +1563,26 @@ try_again: return STATUS_INSUFFICIENT_RESOURCES; } - /* Fill out the key data and copy the buffer */ - ServiceKeyName.Length = LocalDriverName.Length; - ServiceKeyName.MaximumLength = LocalDriverName.MaximumLength; - RtlCopyMemory(ServiceKeyName.Buffer, - LocalDriverName.Buffer, - LocalDriverName.Length); - - /* Null-terminate it and set it */ - ServiceKeyName.Buffer[ServiceKeyName.Length / sizeof(WCHAR)] = UNICODE_NULL; + /* Copy the name and set it in the driver extension */ + RtlCopyUnicodeString(&ServiceKeyName, + ServiceName); DriverObject->DriverExtension->ServiceKeyName = ServiceKeyName; - /* Also store it in the Driver Object. This is a bit of a hack. */ - RtlCopyMemory(&DriverObject->DriverName, - &ServiceKeyName, - sizeof(UNICODE_STRING)); + /* Make a copy of the driver name to store in the driver object */ + DriverObject->DriverName.MaximumLength = LocalDriverName.Length; + DriverObject->DriverName.Buffer = ExAllocatePoolWithTag(PagedPool, + DriverObject->DriverName.MaximumLength, + TAG_IO); + if (!DriverObject->DriverName.Buffer) + { + /* Fail */ + ObMakeTemporaryObject(DriverObject); + ObDereferenceObject(DriverObject); + return STATUS_INSUFFICIENT_RESOURCES; + } + + RtlCopyUnicodeString(&DriverObject->DriverName, + &LocalDriverName); /* Add the Object and get its handle */ Status = ObInsertObject(DriverObject, @@ -1648,8 +1683,8 @@ NTAPI IoCreateDriver(IN PUNICODE_STRING DriverName OPTIONAL, IN PDRIVER_INITIALIZE InitializationFunction) { - PDRIVER_OBJECT DriverObject; - return IopCreateDriver(DriverName, InitializationFunction, NULL, NULL, &DriverObject); + PDRIVER_OBJECT DriverObject; + return IopCreateDriver(DriverName, InitializationFunction, NULL, DriverName, NULL, &DriverObject); } /* @@ -1837,8 +1872,24 @@ IoGetDriverObjectExtension(IN PDRIVER_OBJECT DriverObject, return DriverExtensions + 1; } -VOID NTAPI -IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams) +VOID +NTAPI +IopLoadUnloadDriverWorker( + _Inout_ PVOID Parameter) +{ + PLOAD_UNLOAD_PARAMS LoadParams = Parameter; + + ASSERT(PsGetCurrentProcess() == PsInitialSystemProcess); + LoadParams->Status = IopLoadUnloadDriver(LoadParams->RegistryPath, + &LoadParams->DriverObject); + KeSetEvent(&LoadParams->Event, 0, FALSE); +} + +NTSTATUS +NTAPI +IopLoadUnloadDriver( + _In_opt_ PCUNICODE_STRING RegistryPath, + _Inout_ PDRIVER_OBJECT *DriverObject) { RTL_QUERY_REGISTRY_TABLE QueryTable[3]; UNICODE_STRING ImagePath; @@ -1846,20 +1897,40 @@ IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams) NTSTATUS Status; ULONG Type; PDEVICE_NODE DeviceNode; - PDRIVER_OBJECT DriverObject; PLDR_DATA_TABLE_ENTRY ModuleObject; PVOID BaseAddress; WCHAR *cur; - /* Check if it's an unload request */ - if (LoadParams->DriverObject) + /* Load/Unload must be called from system process */ + if (PsGetCurrentProcess() != PsInitialSystemProcess) { - (*LoadParams->DriverObject->DriverUnload)(LoadParams->DriverObject); + LOAD_UNLOAD_PARAMS LoadParams; - /* Return success and signal the event */ - LoadParams->Status = STATUS_SUCCESS; - KeSetEvent(&LoadParams->Event, 0, FALSE); - return; + /* Prepare parameters block */ + LoadParams.RegistryPath = RegistryPath; + LoadParams.DriverObject = *DriverObject; + KeInitializeEvent(&LoadParams.Event, NotificationEvent, FALSE); + + /* Initialize and queue a work item */ + ExInitializeWorkItem(&LoadParams.WorkItem, + IopLoadUnloadDriverWorker, + &LoadParams); + ExQueueWorkItem(&LoadParams.WorkItem, DelayedWorkQueue); + + /* And wait till it completes */ + KeWaitForSingleObject(&LoadParams.Event, + UserRequest, + KernelMode, + FALSE, + NULL); + return LoadParams.Status; + } + + /* Check if it's an unload request */ + if (*DriverObject) + { + (*DriverObject)->DriverUnload(*DriverObject); + return STATUS_SUCCESS; } RtlInitUnicodeString(&ImagePath, NULL); @@ -1867,28 +1938,26 @@ IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams) /* * Get the service name from the registry key name. */ - ASSERT(LoadParams->ServiceName->Length >= sizeof(WCHAR)); + ASSERT(RegistryPath->Length >= sizeof(WCHAR)); - ServiceName = *LoadParams->ServiceName; - cur = LoadParams->ServiceName->Buffer + - (LoadParams->ServiceName->Length / sizeof(WCHAR)) - 1; - while (LoadParams->ServiceName->Buffer != cur) + ServiceName = *RegistryPath; + cur = RegistryPath->Buffer + RegistryPath->Length / sizeof(WCHAR) - 1; + while (RegistryPath->Buffer != cur) { if (*cur == L'\\') { ServiceName.Buffer = cur + 1; - ServiceName.Length = LoadParams->ServiceName->Length - + ServiceName.Length = RegistryPath->Length - (USHORT)((ULONG_PTR)ServiceName.Buffer - - (ULONG_PTR)LoadParams->ServiceName->Buffer); + (ULONG_PTR)RegistryPath->Buffer); break; } cur--; } /* - * Get service type. - */ - + * Get service type. + */ RtlZeroMemory(&QueryTable, sizeof(QueryTable)); RtlInitUnicodeString(&ImagePath, NULL); @@ -1902,38 +1971,35 @@ IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams) QueryTable[1].EntryContext = &ImagePath; Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE, - LoadParams->ServiceName->Buffer, + RegistryPath->Buffer, QueryTable, NULL, NULL); if (!NT_SUCCESS(Status)) { DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status); if (ImagePath.Buffer) ExFreePool(ImagePath.Buffer); - LoadParams->Status = Status; - KeSetEvent(&LoadParams->Event, 0, FALSE); - return; + return Status; } /* - * Normalize the image path for all later processing. - */ - + * Normalize the image path for all later processing. + */ Status = IopNormalizeImagePath(&ImagePath, &ServiceName); if (!NT_SUCCESS(Status)) { DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status); - LoadParams->Status = Status; - KeSetEvent(&LoadParams->Event, 0, FALSE); - return; + return Status; } DPRINT("FullImagePath: '%wZ'\n", &ImagePath); DPRINT("Type: %lx\n", Type); + KeEnterCriticalRegion(); + ExAcquireResourceExclusiveLite(&IopDriverLoadResource, TRUE); /* * Get existing DriverObject pointer (in case the driver * has already been loaded and initialized). */ - Status = IopGetDriverObject(&DriverObject, + Status = IopGetDriverObject(DriverObject, &ServiceName, (Type == 2 /* SERVICE_FILE_SYSTEM_DRIVER */ || Type == 8 /* SERVICE_RECOGNIZER_DRIVER */)); @@ -1943,15 +2009,14 @@ IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams) /* * Load the driver module */ - DPRINT("Loading module from %wZ\n", &ImagePath); Status = MmLoadSystemImage(&ImagePath, NULL, NULL, 0, (PVOID)&ModuleObject, &BaseAddress); if (!NT_SUCCESS(Status)) { DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status); - LoadParams->Status = Status; - KeSetEvent(&LoadParams->Event, 0, FALSE); - return; + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + return Status; } /* @@ -1961,10 +2026,10 @@ IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams) if (!NT_SUCCESS(Status)) { DPRINT1("IopCreateDeviceNode() failed (Status %lx)\n", Status); + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); MmUnloadSystemImage(ModuleObject); - LoadParams->Status = Status; - KeSetEvent(&LoadParams->Event, 0, FALSE); - return; + return Status; } IopDisplayLoadingMessage(&DeviceNode->ServiceName); @@ -1974,33 +2039,36 @@ IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams) &DeviceNode->ServiceName, (Type == 2 /* SERVICE_FILE_SYSTEM_DRIVER */ || Type == 8 /* SERVICE_RECOGNIZER_DRIVER */), - &DriverObject); + DriverObject); if (!NT_SUCCESS(Status)) { DPRINT1("IopInitializeDriverModule() failed (Status %lx)\n", Status); + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); MmUnloadSystemImage(ModuleObject); - IopFreeDeviceNode(DeviceNode); - LoadParams->Status = Status; - KeSetEvent(&LoadParams->Event, 0, FALSE); - return; + return Status; } + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + /* Initialize and start device */ - IopInitializeDevice(DeviceNode, DriverObject); + IopInitializeDevice(DeviceNode, *DriverObject); Status = IopStartDevice(DeviceNode); } else { + ExReleaseResourceLite(&IopDriverLoadResource); + KeLeaveCriticalRegion(); + DPRINT("DriverObject already exist in ObjectManager\n"); Status = STATUS_IMAGE_ALREADY_LOADED; /* IopGetDriverObject references the DriverObject, so dereference it */ - ObDereferenceObject(DriverObject); + ObDereferenceObject(*DriverObject); } - /* Pass status to the caller and signal the event */ - LoadParams->Status = Status; - KeSetEvent(&LoadParams->Event, 0, FALSE); + return Status; } /* @@ -2023,7 +2091,7 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName) { UNICODE_STRING CapturedDriverServiceName = { 0, 0, NULL }; KPROCESSOR_MODE PreviousMode; - LOAD_UNLOAD_PARAMS LoadParams; + PDRIVER_OBJECT DriverObject; NTSTATUS Status; PAGED_CODE(); @@ -2053,35 +2121,14 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName) DPRINT("NtLoadDriver('%wZ')\n", &CapturedDriverServiceName); - LoadParams.ServiceName = &CapturedDriverServiceName; - LoadParams.DriverObject = NULL; - KeInitializeEvent(&LoadParams.Event, NotificationEvent, FALSE); - - /* Call the load/unload routine, depending on current process */ - if (PsGetCurrentProcess() == PsInitialSystemProcess) - { - /* Just call right away */ - IopLoadUnloadDriver(&LoadParams); - } - else - { - /* Load/Unload must be called from system process */ - ExInitializeWorkItem(&LoadParams.WorkItem, - (PWORKER_THREAD_ROUTINE)IopLoadUnloadDriver, - (PVOID)&LoadParams); - - /* Queue it */ - ExQueueWorkItem(&LoadParams.WorkItem, DelayedWorkQueue); - - /* And wait when it completes */ - KeWaitForSingleObject(&LoadParams.Event, UserRequest, KernelMode, - FALSE, NULL); - } + /* Load driver and call its entry point */ + DriverObject = NULL; + Status = IopLoadUnloadDriver(&CapturedDriverServiceName, &DriverObject); ReleaseCapturedUnicodeString(&CapturedDriverServiceName, PreviousMode); - return LoadParams.Status; + return Status; } /* @@ -2103,7 +2150,7 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName) NTSTATUS NTAPI NtUnloadDriver(IN PUNICODE_STRING DriverServiceName) { - return IopUnloadDriver(DriverServiceName, FALSE); + return IopUnloadDriver(DriverServiceName, FALSE); } /* EOF */