X-Git-Url: https://git.reactos.org/?p=reactos.git;a=blobdiff_plain;f=modules%2Frostests%2Fkmtests%2Fkmtest%2Ffltsupport.c;h=4127ccad24ff74151d65edb3eacca958856cffc0;hp=772e98a211cd606cd3288e15ef44faa5efe878b2;hb=dfb776380d4ca289a30b999791040d80bd38e0ae;hpb=f884d29c90db2e62ef7f16f908587514ba450357;ds=inline diff --git a/modules/rostests/kmtests/kmtest/fltsupport.c b/modules/rostests/kmtests/kmtest/fltsupport.c index 772e98a211c..4127ccad24f 100644 --- a/modules/rostests/kmtests/kmtest/fltsupport.c +++ b/modules/rostests/kmtests/kmtest/fltsupport.c @@ -7,12 +7,28 @@ #include +#define KMT_FLT_USER_MODE #include "kmtest.h" #include #include #include +/* + * We need to call the internal function in the service.c file + */ +DWORD +KmtpCreateService( + IN PCWSTR ServiceName, + IN PCWSTR ServicePath, + IN PCWSTR DisplayName OPTIONAL, + IN DWORD ServiceType, + OUT SC_HANDLE *ServiceHandle); + +DWORD EnablePrivilegeInCurrentProcess( + _In_z_ LPWSTR lpPrivName, + _In_ BOOL bEnable); + // move to a shared location typedef struct _KMTFLT_MESSAGE_HEADER { @@ -26,33 +42,30 @@ extern HANDLE KmtestHandle; static WCHAR TestServiceName[MAX_PATH]; + /** - * @name KmtFltLoadDriver + * @name KmtFltCreateService * - * Load the specified filter driver - * This routine will create the service entry if it doesn't already exist + * Create the specified driver service and return a handle to it * * @param ServiceName - * Name of the driver service (Kmtest- prefix will be added automatically) - * @param RestartIfRunning - * TRUE to stop and restart the service if it is already running - * @param ConnectComms - * TRUE to create a comms connection to the specified filter - * @param hPort - * Handle to the filter's comms port + * Name of the service to create + * @param ServicePath + * File name of the driver, relative to the current directory + * @param DisplayName + * Service display name + * @param ServiceHandle + * Pointer to a variable to receive the handle to the service * * @return Win32 error code */ DWORD -KmtFltLoadDriver( +KmtFltCreateService( _In_z_ PCWSTR ServiceName, - _In_ BOOLEAN RestartIfRunning, - _In_ BOOLEAN ConnectComms, - _Out_ HANDLE *hPort) + _In_z_ PCWSTR DisplayName, + _Out_ SC_HANDLE *ServiceHandle) { - DWORD Error = ERROR_SUCCESS; WCHAR ServicePath[MAX_PATH]; - SC_HANDLE TestServiceHandle; StringCbCopy(ServicePath, sizeof ServicePath, ServiceName); StringCbCat(ServicePath, sizeof ServicePath, L"_drv.sys"); @@ -60,11 +73,81 @@ KmtFltLoadDriver( StringCbCopy(TestServiceName, sizeof TestServiceName, L"Kmtest-"); StringCbCat(TestServiceName, sizeof TestServiceName, ServiceName); - Error = KmtFltCreateAndStartService(TestServiceName, ServicePath, NULL, &TestServiceHandle, TRUE); + return KmtpCreateService(TestServiceName, + ServicePath, + DisplayName, + SERVICE_FILE_SYSTEM_DRIVER, + ServiceHandle); +} + +/** + * @name KmtFltDeleteService + * + * Delete the specified filter driver + * + * @param ServiceName + * If *ServiceHandle is NULL, name of the service to delete + * @param ServiceHandle + * Pointer to a variable containing the service handle. + * Will be set to NULL on success + * + * @return Win32 error code + */ +DWORD +KmtFltDeleteService( + _In_opt_z_ PCWSTR ServiceName, + _Inout_ SC_HANDLE *ServiceHandle) +{ + return KmtDeleteService(ServiceName, ServiceHandle); +} - if (Error == ERROR_SUCCESS && ConnectComms) +/** + * @name KmtFltLoadDriver + * + * Delete the specified filter driver + * + * @return Win32 error code + */ +DWORD +KmtFltLoadDriver( + _In_ BOOLEAN EnableDriverLoadPrivlege, + _In_ BOOLEAN RestartIfRunning, + _In_ BOOLEAN ConnectComms, + _Out_ HANDLE *hPort +) +{ + DWORD Error; + + if (EnableDriverLoadPrivlege) + { + Error = EnablePrivilegeInCurrentProcess(SE_LOAD_DRIVER_NAME , TRUE); + if (Error) + { + return Error; + } + } + + Error = KmtFltLoad(TestServiceName); + if ((Error == ERROR_SERVICE_ALREADY_RUNNING) && RestartIfRunning) + { + Error = KmtFltUnload(TestServiceName); + if (Error) + { + // TODO + __debugbreak(); + } + + Error = KmtFltLoad(TestServiceName); + } + + if (Error) { - Error = KmtFltConnect(ServiceName, hPort); + return Error; + } + + if (ConnectComms) + { + Error = KmtFltConnectComms(hPort); } return Error; @@ -77,7 +160,7 @@ KmtFltLoadDriver( * * @param hPort * Handle to the filter's comms port - * @param ConnectComms + * @param DisonnectComms * TRUE to disconnect the comms connection before unloading * * @return Win32 error code @@ -110,6 +193,57 @@ KmtFltUnloadDriver( return Error; } +/** + * @name KmtFltConnectComms + * + * Create a comms connection to the specified filter + * + * @param hPort + * Handle to the filter's comms port + * + * @return Win32 error code + */ +DWORD +KmtFltConnectComms( + _Out_ HANDLE *hPort) +{ + return KmtFltConnect(TestServiceName, hPort); +} + +/** + * @name KmtFltDisconnectComms + * + * Disconenct from the comms port + * + * @param hPort + * Handle to the filter's comms port + * + * @return Win32 error code + */ +DWORD +KmtFltDisconnectComms( + _In_ HANDLE hPort) +{ + return KmtFltDisconnect(hPort); +} + + +/** +* @name KmtFltCloseService +* +* Close the specified driver service handle +* +* @param ServiceHandle +* Pointer to a variable containing the service handle. +* Will be set to NULL on success +* +* @return Win32 error code +*/ +DWORD KmtFltCloseService( + _Inout_ SC_HANDLE *ServiceHandle) +{ + return KmtCloseService(ServiceHandle); +} /** * @name KmtFltRunKernelTest @@ -141,7 +275,7 @@ KmtFltRunKernelTest( * @param Message * The message to send to the filter * -* @return Win32 error code as returned by DeviceIoControl +* @return Win32 error code */ DWORD KmtFltSendToDriver( @@ -165,7 +299,7 @@ KmtFltSendToDriver( * @param String * An ANSI string to send to the filter * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendStringToDriver( @@ -190,7 +324,7 @@ KmtFltSendStringToDriver( * @param String * An wide string to send to the filter * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendWStringToDriver( @@ -213,7 +347,7 @@ KmtFltSendWStringToDriver( * @param Value * An 32bit valueng to send to the filter * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendUlongToDriver( @@ -244,7 +378,7 @@ KmtFltSendUlongToDriver( * @param BytesReturned * Number of bytes written in the reply buffer * - * @return Win32 error code as returned by DeviceIoControl + * @return Win32 error code */ DWORD KmtFltSendBufferToDriver( @@ -299,3 +433,154 @@ KmtFltSendBufferToDriver( return Error; } + +/** +* @name KmtFltAddAltitude +* +* Sets up the mini-filter altitude data in the registry +* +* @param hPort +* The altitude string to set +* +* @return Win32 error code +*/ +DWORD +KmtFltAddAltitude( + _In_z_ LPWSTR Altitude) +{ + WCHAR DefaultInstance[128]; + WCHAR KeyPath[256]; + HKEY hKey = NULL; + HKEY hSubKey = NULL; + DWORD Zero = 0; + LONG Error; + + StringCbCopy(KeyPath, sizeof KeyPath, L"SYSTEM\\CurrentControlSet\\Services\\"); + StringCbCat(KeyPath, sizeof KeyPath, TestServiceName); + StringCbCat(KeyPath, sizeof KeyPath, L"\\Instances\\"); + + Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE, + KeyPath, + 0, + NULL, + REG_OPTION_NON_VOLATILE, + KEY_CREATE_SUB_KEY | KEY_SET_VALUE, + NULL, + &hKey, + NULL); + if (Error != ERROR_SUCCESS) + { + return Error; + } + + StringCbCopy(DefaultInstance, sizeof DefaultInstance, TestServiceName); + StringCbCat(DefaultInstance, sizeof DefaultInstance, L" Instance"); + + Error = RegSetValueExW(hKey, + L"DefaultInstance", + 0, + REG_SZ, + (LPBYTE)DefaultInstance, + (wcslen(DefaultInstance) + 1) * sizeof(WCHAR)); + if (Error != ERROR_SUCCESS) + { + goto Quit; + } + + Error = RegCreateKeyW(hKey, DefaultInstance, &hSubKey); + if (Error != ERROR_SUCCESS) + { + goto Quit; + } + + Error = RegSetValueExW(hSubKey, + L"Altitude", + 0, + REG_SZ, + (LPBYTE)Altitude, + (wcslen(Altitude) + 1) * sizeof(WCHAR)); + if (Error != ERROR_SUCCESS) + { + goto Quit; + } + + Error = RegSetValueExW(hSubKey, + L"Flags", + 0, + REG_DWORD, + (LPBYTE)&Zero, + sizeof(DWORD)); + +Quit: + if (hSubKey) + { + RegCloseKey(hSubKey); + } + if (hKey) + { + RegCloseKey(hKey); + } + + return Error; + +} + +/* +* Private functions, not meant for use in kmtests +*/ + +DWORD EnablePrivilege( + _In_ HANDLE hToken, + _In_z_ LPWSTR lpPrivName, + _In_ BOOL bEnable) +{ + TOKEN_PRIVILEGES TokenPrivileges; + LUID luid; + BOOL bSuccess; + DWORD dwError = ERROR_SUCCESS; + + /* Get the luid for this privilege */ + if (!LookupPrivilegeValueW(NULL, lpPrivName, &luid)) + return GetLastError(); + + /* Setup the struct with the priv info */ + TokenPrivileges.PrivilegeCount = 1; + TokenPrivileges.Privileges[0].Luid = luid; + TokenPrivileges.Privileges[0].Attributes = bEnable ? SE_PRIVILEGE_ENABLED : 0; + + /* Enable the privilege info in the token */ + bSuccess = AdjustTokenPrivileges(hToken, + FALSE, + &TokenPrivileges, + sizeof(TOKEN_PRIVILEGES), + NULL, + NULL); + if (bSuccess == FALSE) dwError = GetLastError(); + + /* return status */ + return dwError; +} + +DWORD EnablePrivilegeInCurrentProcess( + _In_z_ LPWSTR lpPrivName, + _In_ BOOL bEnable) +{ + HANDLE hToken; + BOOL bSuccess; + DWORD dwError = ERROR_SUCCESS; + + /* Get a handle to our token */ + bSuccess = OpenProcessToken(GetCurrentProcess(), + TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, + &hToken); + if (bSuccess == FALSE) return GetLastError(); + + /* Enable the privilege in the agent token */ + dwError = EnablePrivilege(hToken, lpPrivName, bEnable); + + /* We're done with this now */ + CloseHandle(hToken); + + /* return status */ + return dwError; +}