[FLTMGR] Latest from my branch (#135)
[reactos.git] / modules / rostests / kmtests / kmtest / fltsupport.c
index 772e98a..4127cca 100644 (file)
@@ -7,12 +7,28 @@
 
 #include <kmt_test.h>
 
 
 #include <kmt_test.h>
 
+#define KMT_FLT_USER_MODE
 #include "kmtest.h"
 #include <kmt_public.h>
 
 #include <assert.h>
 #include <debug.h>
 
 #include "kmtest.h"
 #include <kmt_public.h>
 
 #include <assert.h>
 #include <debug.h>
 
+/*
+ * We need to call the internal function in the service.c file
+ */
+DWORD
+KmtpCreateService(
+    IN PCWSTR ServiceName,
+    IN PCWSTR ServicePath,
+    IN PCWSTR DisplayName OPTIONAL,
+    IN DWORD ServiceType,
+    OUT SC_HANDLE *ServiceHandle);
+
+DWORD EnablePrivilegeInCurrentProcess(
+    _In_z_ LPWSTR lpPrivName,
+    _In_ BOOL bEnable);
+
 // move to a shared location
 typedef struct _KMTFLT_MESSAGE_HEADER
 {
 // move to a shared location
 typedef struct _KMTFLT_MESSAGE_HEADER
 {
@@ -26,33 +42,30 @@ extern HANDLE KmtestHandle;
 static WCHAR TestServiceName[MAX_PATH];
 
 
 static WCHAR TestServiceName[MAX_PATH];
 
 
+
 /**
 /**
- * @name KmtFltLoadDriver
+ * @name KmtFltCreateService
  *
  *
- * Load the specified filter driver
- * This routine will create the service entry if it doesn't already exist
+ * Create the specified driver service and return a handle to it
  *
  * @param ServiceName
  *
  * @param ServiceName
- *        Name of the driver service (Kmtest- prefix will be added automatically)
- * @param RestartIfRunning
- *        TRUE to stop and restart the service if it is already running
- * @param ConnectComms
- *        TRUE to create a comms connection to the specified filter
- * @param hPort
- *        Handle to the filter's comms port
+ *        Name of the service to create
+ * @param ServicePath
+ *        File name of the driver, relative to the current directory
+ * @param DisplayName
+ *        Service display name
+ * @param ServiceHandle
+ *        Pointer to a variable to receive the handle to the service
  *
  * @return Win32 error code
  */
 DWORD
  *
  * @return Win32 error code
  */
 DWORD
-KmtFltLoadDriver(
+KmtFltCreateService(
     _In_z_ PCWSTR ServiceName,
     _In_z_ PCWSTR ServiceName,
-    _In_ BOOLEAN RestartIfRunning,
-    _In_ BOOLEAN ConnectComms,
-    _Out_ HANDLE *hPort)
+    _In_z_ PCWSTR DisplayName,
+    _Out_ SC_HANDLE *ServiceHandle)
 {
 {
-    DWORD Error = ERROR_SUCCESS;
     WCHAR ServicePath[MAX_PATH];
     WCHAR ServicePath[MAX_PATH];
-    SC_HANDLE TestServiceHandle;
 
     StringCbCopy(ServicePath, sizeof ServicePath, ServiceName);
     StringCbCat(ServicePath, sizeof ServicePath, L"_drv.sys");
 
     StringCbCopy(ServicePath, sizeof ServicePath, ServiceName);
     StringCbCat(ServicePath, sizeof ServicePath, L"_drv.sys");
@@ -60,11 +73,81 @@ KmtFltLoadDriver(
     StringCbCopy(TestServiceName, sizeof TestServiceName, L"Kmtest-");
     StringCbCat(TestServiceName, sizeof TestServiceName, ServiceName);
 
     StringCbCopy(TestServiceName, sizeof TestServiceName, L"Kmtest-");
     StringCbCat(TestServiceName, sizeof TestServiceName, ServiceName);
 
-    Error = KmtFltCreateAndStartService(TestServiceName, ServicePath, NULL, &TestServiceHandle, TRUE);
+    return KmtpCreateService(TestServiceName,
+                             ServicePath,
+                             DisplayName,
+                             SERVICE_FILE_SYSTEM_DRIVER,
+                             ServiceHandle);
+}
+
+/**
+ * @name KmtFltDeleteService
+ *
+ * Delete the specified filter driver
+ *
+ * @param ServiceName
+ *        If *ServiceHandle is NULL, name of the service to delete
+ * @param ServiceHandle
+ *        Pointer to a variable containing the service handle.
+ *        Will be set to NULL on success
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltDeleteService(
+    _In_opt_z_ PCWSTR ServiceName,
+    _Inout_ SC_HANDLE *ServiceHandle)
+{
+    return KmtDeleteService(ServiceName, ServiceHandle);
+}
 
 
-    if (Error == ERROR_SUCCESS && ConnectComms)
+/**
+ * @name KmtFltLoadDriver
+ *
+ * Delete the specified filter driver
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltLoadDriver(
+    _In_ BOOLEAN EnableDriverLoadPrivlege,
+    _In_ BOOLEAN RestartIfRunning,
+    _In_ BOOLEAN ConnectComms,
+    _Out_ HANDLE *hPort
+)
+{
+    DWORD Error;
+
+    if (EnableDriverLoadPrivlege)
+    {
+        Error = EnablePrivilegeInCurrentProcess(SE_LOAD_DRIVER_NAME , TRUE);
+        if (Error)
+        {
+            return Error;
+        }
+    }
+
+    Error = KmtFltLoad(TestServiceName);
+    if ((Error == ERROR_SERVICE_ALREADY_RUNNING) && RestartIfRunning)
+    {
+        Error = KmtFltUnload(TestServiceName);
+        if (Error)
+        {
+            // TODO
+            __debugbreak();
+        }
+
+        Error = KmtFltLoad(TestServiceName);
+    }
+
+    if (Error)
     {
     {
-        Error = KmtFltConnect(ServiceName, hPort);
+        return Error;
+    }
+
+    if (ConnectComms)
+    {
+        Error = KmtFltConnectComms(hPort);
     }
 
     return Error;
     }
 
     return Error;
@@ -77,7 +160,7 @@ KmtFltLoadDriver(
  *
  * @param hPort
  *        Handle to the filter's comms port
  *
  * @param hPort
  *        Handle to the filter's comms port
- * @param ConnectComms
+ * @param DisonnectComms
  *        TRUE to disconnect the comms connection before unloading
  *
  * @return Win32 error code
  *        TRUE to disconnect the comms connection before unloading
  *
  * @return Win32 error code
@@ -110,6 +193,57 @@ KmtFltUnloadDriver(
     return Error;
 }
 
     return Error;
 }
 
+/**
+ * @name KmtFltConnectComms
+ *
+ * Create a comms connection to the specified filter
+ *
+ * @param hPort
+ *        Handle to the filter's comms port
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltConnectComms(
+    _Out_ HANDLE *hPort)
+{
+    return KmtFltConnect(TestServiceName, hPort);
+}
+
+/**
+ * @name KmtFltDisconnectComms
+ *
+ * Disconenct from the comms port
+ *
+ * @param hPort
+ *        Handle to the filter's comms port
+ *
+ * @return Win32 error code
+ */
+DWORD
+KmtFltDisconnectComms(
+    _In_ HANDLE hPort)
+{
+    return KmtFltDisconnect(hPort);
+}
+
+
+/**
+* @name KmtFltCloseService
+*
+* Close the specified driver service handle
+*
+* @param ServiceHandle
+*        Pointer to a variable containing the service handle.
+*        Will be set to NULL on success
+*
+* @return Win32 error code
+*/
+DWORD KmtFltCloseService(
+    _Inout_ SC_HANDLE *ServiceHandle)
+{
+    return KmtCloseService(ServiceHandle);
+}
 
 /**
 * @name KmtFltRunKernelTest
 
 /**
 * @name KmtFltRunKernelTest
@@ -141,7 +275,7 @@ KmtFltRunKernelTest(
 * @param Message
 *        The message to send to the filter
 *
 * @param Message
 *        The message to send to the filter
 *
-* @return Win32 error code as returned by DeviceIoControl
+* @return Win32 error code
 */
 DWORD
 KmtFltSendToDriver(
 */
 DWORD
 KmtFltSendToDriver(
@@ -165,7 +299,7 @@ KmtFltSendToDriver(
  * @param String
  *        An ANSI string to send to the filter
  *
  * @param String
  *        An ANSI string to send to the filter
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendStringToDriver(
  */
 DWORD
 KmtFltSendStringToDriver(
@@ -190,7 +324,7 @@ KmtFltSendStringToDriver(
  * @param String
  *        An wide string to send to the filter
  *
  * @param String
  *        An wide string to send to the filter
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendWStringToDriver(
  */
 DWORD
 KmtFltSendWStringToDriver(
@@ -213,7 +347,7 @@ KmtFltSendWStringToDriver(
  * @param Value
  *        An 32bit valueng to send to the filter
  *
  * @param Value
  *        An 32bit valueng to send to the filter
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendUlongToDriver(
  */
 DWORD
 KmtFltSendUlongToDriver(
@@ -244,7 +378,7 @@ KmtFltSendUlongToDriver(
  * @param BytesReturned
  *        Number of bytes written in the reply buffer
  *
  * @param BytesReturned
  *        Number of bytes written in the reply buffer
  *
- * @return Win32 error code as returned by DeviceIoControl
+ * @return Win32 error code
  */
 DWORD
 KmtFltSendBufferToDriver(
  */
 DWORD
 KmtFltSendBufferToDriver(
@@ -299,3 +433,154 @@ KmtFltSendBufferToDriver(
 
     return Error;
 }
 
     return Error;
 }
+
+/**
+* @name KmtFltAddAltitude
+*
+* Sets up the mini-filter altitude data in the registry
+*
+* @param hPort
+*        The altitude string to set
+*
+* @return Win32 error code
+*/
+DWORD
+KmtFltAddAltitude(
+    _In_z_ LPWSTR Altitude)
+{
+    WCHAR DefaultInstance[128];
+    WCHAR KeyPath[256];
+    HKEY hKey = NULL;
+    HKEY hSubKey = NULL;
+    DWORD Zero = 0;
+    LONG Error;
+
+    StringCbCopy(KeyPath, sizeof KeyPath, L"SYSTEM\\CurrentControlSet\\Services\\");
+    StringCbCat(KeyPath, sizeof KeyPath, TestServiceName);
+    StringCbCat(KeyPath, sizeof KeyPath, L"\\Instances\\");
+
+    Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE,
+                           KeyPath,
+                           0,
+                           NULL,
+                           REG_OPTION_NON_VOLATILE,
+                           KEY_CREATE_SUB_KEY | KEY_SET_VALUE,
+                           NULL,
+                           &hKey,
+                           NULL);
+    if (Error != ERROR_SUCCESS)
+    {
+        return Error;
+    }
+
+    StringCbCopy(DefaultInstance, sizeof DefaultInstance, TestServiceName);
+    StringCbCat(DefaultInstance, sizeof DefaultInstance, L" Instance");
+
+    Error = RegSetValueExW(hKey,
+                           L"DefaultInstance",
+                           0,
+                           REG_SZ,
+                           (LPBYTE)DefaultInstance,
+                           (wcslen(DefaultInstance) + 1) * sizeof(WCHAR));
+    if (Error != ERROR_SUCCESS)
+    {
+        goto Quit;
+    }
+
+    Error = RegCreateKeyW(hKey, DefaultInstance, &hSubKey);
+    if (Error != ERROR_SUCCESS)
+    {
+        goto Quit;
+    }
+
+    Error = RegSetValueExW(hSubKey,
+                           L"Altitude",
+                           0,
+                           REG_SZ,
+                           (LPBYTE)Altitude,
+                           (wcslen(Altitude) + 1) * sizeof(WCHAR));
+    if (Error != ERROR_SUCCESS)
+    {
+        goto Quit;
+    }
+
+    Error = RegSetValueExW(hSubKey,
+                           L"Flags",
+                           0,
+                           REG_DWORD,
+                           (LPBYTE)&Zero,
+                           sizeof(DWORD));
+
+Quit:
+    if (hSubKey)
+    {
+        RegCloseKey(hSubKey);
+    }
+    if (hKey)
+    {
+        RegCloseKey(hKey);
+    }
+
+    return Error;
+
+}
+
+/*
+* Private functions, not meant for use in kmtests
+*/
+
+DWORD EnablePrivilege(
+    _In_ HANDLE hToken,
+    _In_z_ LPWSTR lpPrivName,
+    _In_ BOOL bEnable)
+{
+    TOKEN_PRIVILEGES TokenPrivileges;
+    LUID luid;
+    BOOL bSuccess;
+    DWORD dwError = ERROR_SUCCESS;
+
+    /* Get the luid for this privilege */
+    if (!LookupPrivilegeValueW(NULL, lpPrivName, &luid))
+        return GetLastError();
+
+    /* Setup the struct with the priv info */
+    TokenPrivileges.PrivilegeCount = 1;
+    TokenPrivileges.Privileges[0].Luid = luid;
+    TokenPrivileges.Privileges[0].Attributes = bEnable ? SE_PRIVILEGE_ENABLED : 0;
+
+    /* Enable the privilege info in the token */
+    bSuccess = AdjustTokenPrivileges(hToken,
+                                     FALSE,
+                                     &TokenPrivileges,
+                                     sizeof(TOKEN_PRIVILEGES),
+                                     NULL,
+                                     NULL);
+    if (bSuccess == FALSE) dwError = GetLastError();
+
+    /* return status */
+    return dwError;
+}
+
+DWORD EnablePrivilegeInCurrentProcess(
+    _In_z_ LPWSTR lpPrivName,
+    _In_ BOOL bEnable)
+{
+    HANDLE hToken;
+    BOOL bSuccess;
+    DWORD dwError = ERROR_SUCCESS;
+
+    /* Get a handle to our token */
+    bSuccess = OpenProcessToken(GetCurrentProcess(),
+                                TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
+                                &hToken);
+    if (bSuccess == FALSE) return GetLastError();
+
+    /* Enable the privilege in the agent token */
+    dwError = EnablePrivilege(hToken, lpPrivName, bEnable);
+
+    /* We're done with this now */
+    CloseHandle(hToken);
+
+    /* return status */
+    return dwError;
+}