4127ccad24ff74151d65edb3eacca958856cffc0
[reactos.git] / modules / rostests / kmtests / kmtest / fltsupport.c
/*
 * PROJECT:     ReactOS kernel-mode tests
 * LICENSE:     GPLv2+ - See COPYING in the top level directory
 * PURPOSE:     File system mini-filter support routines
 * PROGRAMMER:  Ged Murphy <gedmurphy@reactos.org>
 */
7
8 #include <kmt_test.h>
9
10 #define KMT_FLT_USER_MODE
11 #include "kmtest.h"
12 #include <kmt_public.h>
13
14 #include <assert.h>
15 #include <debug.h>
16
17 /*
18 * We need to call the internal function in the service.c file
19 */
20 DWORD
21 KmtpCreateService(
22 IN PCWSTR ServiceName,
23 IN PCWSTR ServicePath,
24 IN PCWSTR DisplayName OPTIONAL,
25 IN DWORD ServiceType,
26 OUT SC_HANDLE *ServiceHandle);
27
28 DWORD EnablePrivilegeInCurrentProcess(
29 _In_z_ LPWSTR lpPrivName,
30 _In_ BOOL bEnable);
31
32 // move to a shared location
33 typedef struct _KMTFLT_MESSAGE_HEADER
34 {
35 ULONG Message;
36 PVOID Buffer;
37 ULONG BufferSize;
38
39 } KMTFLT_MESSAGE_HEADER, *PKMTFLT_MESSAGE_HEADER;
40
41 extern HANDLE KmtestHandle;
42 static WCHAR TestServiceName[MAX_PATH];
43
44
45
46 /**
47 * @name KmtFltCreateService
48 *
49 * Create the specified driver service and return a handle to it
50 *
51 * @param ServiceName
52 * Name of the service to create
53 * @param ServicePath
54 * File name of the driver, relative to the current directory
55 * @param DisplayName
56 * Service display name
57 * @param ServiceHandle
58 * Pointer to a variable to receive the handle to the service
59 *
60 * @return Win32 error code
61 */
62 DWORD
63 KmtFltCreateService(
64 _In_z_ PCWSTR ServiceName,
65 _In_z_ PCWSTR DisplayName,
66 _Out_ SC_HANDLE *ServiceHandle)
67 {
68 WCHAR ServicePath[MAX_PATH];
69
70 StringCbCopy(ServicePath, sizeof ServicePath, ServiceName);
71 StringCbCat(ServicePath, sizeof ServicePath, L"_drv.sys");
72
73 StringCbCopy(TestServiceName, sizeof TestServiceName, L"Kmtest-");
74 StringCbCat(TestServiceName, sizeof TestServiceName, ServiceName);
75
76 return KmtpCreateService(TestServiceName,
77 ServicePath,
78 DisplayName,
79 SERVICE_FILE_SYSTEM_DRIVER,
80 ServiceHandle);
81 }
82
83 /**
84 * @name KmtFltDeleteService
85 *
86 * Delete the specified filter driver
87 *
88 * @param ServiceName
89 * If *ServiceHandle is NULL, name of the service to delete
90 * @param ServiceHandle
91 * Pointer to a variable containing the service handle.
92 * Will be set to NULL on success
93 *
94 * @return Win32 error code
95 */
96 DWORD
97 KmtFltDeleteService(
98 _In_opt_z_ PCWSTR ServiceName,
99 _Inout_ SC_HANDLE *ServiceHandle)
100 {
101 return KmtDeleteService(ServiceName, ServiceHandle);
102 }
103
104 /**
105 * @name KmtFltLoadDriver
106 *
107 * Delete the specified filter driver
108 *
109 * @return Win32 error code
110 */
111 DWORD
112 KmtFltLoadDriver(
113 _In_ BOOLEAN EnableDriverLoadPrivlege,
114 _In_ BOOLEAN RestartIfRunning,
115 _In_ BOOLEAN ConnectComms,
116 _Out_ HANDLE *hPort
117 )
118 {
119 DWORD Error;
120
121 if (EnableDriverLoadPrivlege)
122 {
123 Error = EnablePrivilegeInCurrentProcess(SE_LOAD_DRIVER_NAME , TRUE);
124 if (Error)
125 {
126 return Error;
127 }
128 }
129
130 Error = KmtFltLoad(TestServiceName);
131 if ((Error == ERROR_SERVICE_ALREADY_RUNNING) && RestartIfRunning)
132 {
133 Error = KmtFltUnload(TestServiceName);
134 if (Error)
135 {
136 // TODO
137 __debugbreak();
138 }
139
140 Error = KmtFltLoad(TestServiceName);
141 }
142
143 if (Error)
144 {
145 return Error;
146 }
147
148 if (ConnectComms)
149 {
150 Error = KmtFltConnectComms(hPort);
151 }
152
153 return Error;
154 }
155
156 /**
157 * @name KmtFltUnloadDriver
158 *
159 * Unload the specified filter driver
160 *
161 * @param hPort
162 * Handle to the filter's comms port
163 * @param DisonnectComms
164 * TRUE to disconnect the comms connection before unloading
165 *
166 * @return Win32 error code
167 */
168 DWORD
169 KmtFltUnloadDriver(
170 _In_ HANDLE *hPort,
171 _In_ BOOLEAN DisonnectComms)
172 {
173 DWORD Error = ERROR_SUCCESS;
174
175 if (DisonnectComms)
176 {
177 Error = KmtFltDisconnect(hPort);
178
179 if (Error)
180 {
181 return Error;
182 }
183 }
184
185 Error = KmtFltUnload(TestServiceName);
186
187 if (Error)
188 {
189 // TODO
190 __debugbreak();
191 }
192
193 return Error;
194 }
195
196 /**
197 * @name KmtFltConnectComms
198 *
199 * Create a comms connection to the specified filter
200 *
201 * @param hPort
202 * Handle to the filter's comms port
203 *
204 * @return Win32 error code
205 */
206 DWORD
207 KmtFltConnectComms(
208 _Out_ HANDLE *hPort)
209 {
210 return KmtFltConnect(TestServiceName, hPort);
211 }
212
213 /**
214 * @name KmtFltDisconnectComms
215 *
216 * Disconenct from the comms port
217 *
218 * @param hPort
219 * Handle to the filter's comms port
220 *
221 * @return Win32 error code
222 */
223 DWORD
224 KmtFltDisconnectComms(
225 _In_ HANDLE hPort)
226 {
227 return KmtFltDisconnect(hPort);
228 }
229
230
231 /**
232 * @name KmtFltCloseService
233 *
234 * Close the specified driver service handle
235 *
236 * @param ServiceHandle
237 * Pointer to a variable containing the service handle.
238 * Will be set to NULL on success
239 *
240 * @return Win32 error code
241 */
242 DWORD KmtFltCloseService(
243 _Inout_ SC_HANDLE *ServiceHandle)
244 {
245 return KmtCloseService(ServiceHandle);
246 }
247
248 /**
249 * @name KmtFltRunKernelTest
250 *
251 * Run the specified filter test part
252 *
253 * @param hPort
254 * Handle to the filter's comms port
255 * @param TestName
256 * Name of the test to run
257 *
258 * @return Win32 error code
259 */
260 DWORD
261 KmtFltRunKernelTest(
262 _In_ HANDLE hPort,
263 _In_z_ PCSTR TestName)
264 {
265 return KmtFltSendStringToDriver(hPort, KMTFLT_RUN_TEST, TestName);
266 }
267
268 /**
269 * @name KmtFltSendToDriver
270 *
271 * Send an I/O control message with no arguments to the driver opened with KmtOpenDriver
272 *
273 * @param hPort
274 * Handle to the filter's comms port
275 * @param Message
276 * The message to send to the filter
277 *
278 * @return Win32 error code
279 */
280 DWORD
281 KmtFltSendToDriver(
282 _In_ HANDLE hPort,
283 _In_ DWORD Message)
284 {
285 assert(hPort);
286 return KmtFltSendBufferToDriver(hPort, Message, NULL, 0, NULL, 0, NULL);
287 }
288
289 /**
290 * @name KmtFltSendStringToDriver
291 *
292 * Send an I/O control message with a string argument to the driver opened with KmtOpenDriver
293 *
294 *
295 * @param hPort
296 * Handle to the filter's comms port
297 * @param Message
298 * The message associated with the string
299 * @param String
300 * An ANSI string to send to the filter
301 *
302 * @return Win32 error code
303 */
304 DWORD
305 KmtFltSendStringToDriver(
306 _In_ HANDLE hPort,
307 _In_ DWORD Message,
308 _In_ PCSTR String)
309 {
310 assert(hPort);
311 assert(String);
312 return KmtFltSendBufferToDriver(hPort, Message, (PVOID)String, (DWORD)strlen(String), NULL, 0, NULL);
313 }
314
315 /**
316 * @name KmtFltSendWStringToDriver
317 *
318 * Send an I/O control message with a wide string argument to the driver opened with KmtOpenDriver
319 *
320 * @param hPort
321 * Handle to the filter's comms port
322 * @param Message
323 * The message associated with the string
324 * @param String
325 * An wide string to send to the filter
326 *
327 * @return Win32 error code
328 */
329 DWORD
330 KmtFltSendWStringToDriver(
331 _In_ HANDLE hPort,
332 _In_ DWORD Message,
333 _In_ PCWSTR String)
334 {
335 return KmtFltSendBufferToDriver(hPort, Message, (PVOID)String, (DWORD)wcslen(String) * sizeof(WCHAR), NULL, 0, NULL);
336 }
337
338 /**
339 * @name KmtFltSendUlongToDriver
340 *
341 * Send an I/O control message with an integer argument to the driver opened with KmtOpenDriver
342 *
343 * @param hPort
344 * Handle to the filter's comms port
345 * @param Message
346 * The message associated with the value
347 * @param Value
348 * An 32bit valueng to send to the filter
349 *
350 * @return Win32 error code
351 */
352 DWORD
353 KmtFltSendUlongToDriver(
354 _In_ HANDLE hPort,
355 _In_ DWORD Message,
356 _In_ DWORD Value)
357 {
358 return KmtFltSendBufferToDriver(hPort, Message, &Value, sizeof(Value), NULL, 0, NULL);
359 }
360
361 /**
362 * @name KmtSendBufferToDriver
363 *
364 * Send an I/O control message with the specified arguments to the driver opened with KmtOpenDriver
365 *
366 * @param hPort
367 * Handle to the filter's comms port
368 * @param Message
369 * The message associated with the value
370 * @param InBuffer
371 * Pointer to a buffer to send to the filter
372 * @param BufferSize
373 * Size of the buffer pointed to by InBuffer
374 * @param OutBuffer
375 * Pointer to a buffer to receive a response from the filter
376 * @param OutBufferSize
377 * Size of the buffer pointed to by OutBuffer
378 * @param BytesReturned
379 * Number of bytes written in the reply buffer
380 *
381 * @return Win32 error code
382 */
383 DWORD
384 KmtFltSendBufferToDriver(
385 _In_ HANDLE hPort,
386 _In_ DWORD Message,
387 _In_reads_bytes_(BufferSize) LPVOID InBuffer,
388 _In_ DWORD BufferSize,
389 _Out_writes_bytes_to_opt_(OutBufferSize, *BytesReturned) LPVOID OutBuffer,
390 _In_ DWORD OutBufferSize,
391 _Out_opt_ LPDWORD BytesReturned)
392 {
393 PKMTFLT_MESSAGE_HEADER Ptr;
394 KMTFLT_MESSAGE_HEADER Header;
395 BOOLEAN FreeMemory = FALSE;
396 DWORD InBufferSize;
397 DWORD Error;
398
399 assert(hPort);
400
401 if (BufferSize)
402 {
403 assert(InBuffer);
404
405 InBufferSize = sizeof(KMTFLT_MESSAGE_HEADER) + BufferSize;
406 Ptr = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, InBufferSize);
407 if (!Ptr)
408 {
409 return ERROR_NOT_ENOUGH_MEMORY;
410 }
411 FreeMemory = TRUE;
412 }
413 else
414 {
415 InBufferSize = sizeof(KMTFLT_MESSAGE_HEADER);
416 Ptr = &Header;
417 }
418
419 Ptr->Message = Message;
420 if (BufferSize)
421 {
422 Ptr->Buffer = (Ptr + 1);
423 StringCbCopy(Ptr->Buffer, BufferSize, InBuffer);
424 Ptr->BufferSize = BufferSize;
425 }
426
427 Error = KmtFltSendMessage(hPort, Ptr, InBufferSize, OutBuffer, OutBufferSize, BytesReturned);
428
429 if (FreeMemory)
430 {
431 HeapFree(GetProcessHeap(), 0, Ptr);
432 }
433
434 return Error;
435 }
436
437 /**
438 * @name KmtFltAddAltitude
439 *
440 * Sets up the mini-filter altitude data in the registry
441 *
442 * @param hPort
443 * The altitude string to set
444 *
445 * @return Win32 error code
446 */
447 DWORD
448 KmtFltAddAltitude(
449 _In_z_ LPWSTR Altitude)
450 {
451 WCHAR DefaultInstance[128];
452 WCHAR KeyPath[256];
453 HKEY hKey = NULL;
454 HKEY hSubKey = NULL;
455 DWORD Zero = 0;
456 LONG Error;
457
458 StringCbCopy(KeyPath, sizeof KeyPath, L"SYSTEM\\CurrentControlSet\\Services\\");
459 StringCbCat(KeyPath, sizeof KeyPath, TestServiceName);
460 StringCbCat(KeyPath, sizeof KeyPath, L"\\Instances\\");
461
462 Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE,
463 KeyPath,
464 0,
465 NULL,
466 REG_OPTION_NON_VOLATILE,
467 KEY_CREATE_SUB_KEY | KEY_SET_VALUE,
468 NULL,
469 &hKey,
470 NULL);
471 if (Error != ERROR_SUCCESS)
472 {
473 return Error;
474 }
475
476 StringCbCopy(DefaultInstance, sizeof DefaultInstance, TestServiceName);
477 StringCbCat(DefaultInstance, sizeof DefaultInstance, L" Instance");
478
479 Error = RegSetValueExW(hKey,
480 L"DefaultInstance",
481 0,
482 REG_SZ,
483 (LPBYTE)DefaultInstance,
484 (wcslen(DefaultInstance) + 1) * sizeof(WCHAR));
485 if (Error != ERROR_SUCCESS)
486 {
487 goto Quit;
488 }
489
490 Error = RegCreateKeyW(hKey, DefaultInstance, &hSubKey);
491 if (Error != ERROR_SUCCESS)
492 {
493 goto Quit;
494 }
495
496 Error = RegSetValueExW(hSubKey,
497 L"Altitude",
498 0,
499 REG_SZ,
500 (LPBYTE)Altitude,
501 (wcslen(Altitude) + 1) * sizeof(WCHAR));
502 if (Error != ERROR_SUCCESS)
503 {
504 goto Quit;
505 }
506
507 Error = RegSetValueExW(hSubKey,
508 L"Flags",
509 0,
510 REG_DWORD,
511 (LPBYTE)&Zero,
512 sizeof(DWORD));
513
514 Quit:
515 if (hSubKey)
516 {
517 RegCloseKey(hSubKey);
518 }
519 if (hKey)
520 {
521 RegCloseKey(hKey);
522 }
523
524 return Error;
525
526 }
527
/*
 * Private helper functions, not meant to be called directly by kmtests
 */
531
532 DWORD EnablePrivilege(
533 _In_ HANDLE hToken,
534 _In_z_ LPWSTR lpPrivName,
535 _In_ BOOL bEnable)
536 {
537 TOKEN_PRIVILEGES TokenPrivileges;
538 LUID luid;
539 BOOL bSuccess;
540 DWORD dwError = ERROR_SUCCESS;
541
542 /* Get the luid for this privilege */
543 if (!LookupPrivilegeValueW(NULL, lpPrivName, &luid))
544 return GetLastError();
545
546 /* Setup the struct with the priv info */
547 TokenPrivileges.PrivilegeCount = 1;
548 TokenPrivileges.Privileges[0].Luid = luid;
549 TokenPrivileges.Privileges[0].Attributes = bEnable ? SE_PRIVILEGE_ENABLED : 0;
550
551 /* Enable the privilege info in the token */
552 bSuccess = AdjustTokenPrivileges(hToken,
553 FALSE,
554 &TokenPrivileges,
555 sizeof(TOKEN_PRIVILEGES),
556 NULL,
557 NULL);
558 if (bSuccess == FALSE) dwError = GetLastError();
559
560 /* return status */
561 return dwError;
562 }
563
564 DWORD EnablePrivilegeInCurrentProcess(
565 _In_z_ LPWSTR lpPrivName,
566 _In_ BOOL bEnable)
567 {
568 HANDLE hToken;
569 BOOL bSuccess;
570 DWORD dwError = ERROR_SUCCESS;
571
572 /* Get a handle to our token */
573 bSuccess = OpenProcessToken(GetCurrentProcess(),
574 TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
575 &hToken);
576 if (bSuccess == FALSE) return GetLastError();
577
578 /* Enable the privilege in the agent token */
579 dwError = EnablePrivilege(hToken, lpPrivName, bEnable);
580
581 /* We're done with this now */
582 CloseHandle(hToken);
583
584 /* return status */
585 return dwError;
586 }