[KMTEST] Initial usermode support for testing FS mini-filters (#81)
[reactos.git] / modules / rostests / kmtests / kmtest / filter.c
1 /*
2 * PROJECT: ReactOS kernel-mode tests
3 * LICENSE: GPLv2+ - See COPYING in the top level directory
4 * PURPOSE: File system filter implementation of the original service.c file
5 * PROGRAMMER: Thomas Faber <thomas.faber@reactos.org>
6 * Ged Murphy <gedmurphy@reactos.org>
7 */
8
9 //#include <fltuser.h>
10 #include <kmt_test.h>
11 #include "kmtest.h"
12
13 #include <assert.h>
14
15 #define SERVICE_ACCESS (SERVICE_START | SERVICE_STOP | DELETE)
16
17 /*
18 * We need to call the internal function in the service.c file
19 */
20 DWORD
21 KmtpCreateService(
22 IN PCWSTR ServiceName,
23 IN PCWSTR ServicePath,
24 IN PCWSTR DisplayName OPTIONAL,
25 IN DWORD ServiceType,
26 OUT SC_HANDLE *ServiceHandle);
27
28 static SC_HANDLE ScmHandle;
29
30
31 /**
32 * @name KmtFltCreateService
33 *
34 * Create the specified driver service and return a handle to it
35 *
36 * @param ServiceName
37 * Name of the service to create
38 * @param ServicePath
39 * File name of the driver, relative to the current directory
40 * @param DisplayName
41 * Service display name
42 * @param ServiceHandle
43 * Pointer to a variable to receive the handle to the service
44 *
45 * @return Win32 error code
46 */
47 DWORD
48 KmtFltCreateService(
49 _In_z_ PCWSTR ServiceName,
50 _In_z_ PCWSTR ServicePath,
51 _In_z_ PCWSTR DisplayName OPTIONAL,
52 _Out_ SC_HANDLE *ServiceHandle)
53 {
54 return KmtpCreateService(ServiceName,
55 ServicePath,
56 DisplayName,
57 SERVICE_FILE_SYSTEM_DRIVER,
58 ServiceHandle);
59 }
60
61 /**
62 * @name KmtFltLoad
63 *
64 * Start the specified filter driver by name
65 *
66 * @param ServiceName
67 * The name of the filter to start
68 *
69 * @return Win32 error code
70 */
71 DWORD
72 KmtFltLoad(
73 _In_z_ PCWSTR ServiceName)
74 {
75 HRESULT hResult;
76 DWORD Error = ERROR_SUCCESS;
77
78 assert(ServiceName);
79
80 hResult = FilterLoad(ServiceName);
81 Error = SCODE_CODE(hResult);
82
83 return Error;
84 }
85
86 /**
87 * @name KmtFltCreateAndStartService
88 *
89 * Create and load the specified filter driver and return a handle to it
90 *
91 * @param ServiceName
92 * Name of the service to create
93 * @param ServicePath
94 * File name of the driver, relative to the current directory
95 * @param DisplayName
96 * Service display name
97 * @param ServiceHandle
98 * Pointer to a variable to receive the handle to the service
99 * @param RestartIfRunning
100 * TRUE to stop and restart the service if it is already running
101 *
102 * @return Win32 error code
103 */
104 DWORD
105 KmtFltCreateAndStartService(
106 _In_z_ PCWSTR ServiceName,
107 _In_z_ PCWSTR ServicePath,
108 _In_z_ PCWSTR DisplayName OPTIONAL,
109 _Out_ SC_HANDLE *ServiceHandle,
110 _In_ BOOLEAN RestartIfRunning)
111 {
112 DWORD Error = ERROR_SUCCESS;
113
114 assert(ServiceHandle);
115
116 Error = KmtFltCreateService(ServiceName, ServicePath, DisplayName, ServiceHandle);
117
118 if (Error == ERROR_SERVICE_EXISTS)
119 *ServiceHandle = OpenService(ScmHandle, ServiceName, SERVICE_ACCESS);
120
121 if (Error && Error != ERROR_SERVICE_EXISTS)
122 goto cleanup;
123
124 Error = KmtFltLoad(ServiceName);
125
126 if (Error != ERROR_SERVICE_ALREADY_RUNNING)
127 goto cleanup;
128
129 Error = ERROR_SUCCESS;
130
131 if (!RestartIfRunning)
132 goto cleanup;
133
134 Error = KmtFltUnload(ServiceName);
135 if (Error)
136 goto cleanup;
137
138 Error = KmtFltLoad(ServiceName);
139 if (Error)
140 goto cleanup;
141
142 cleanup:
143 assert(Error);
144 return Error;
145 }
146
147
148 /**
149 * @name KmtFltConnect
150 *
151 * Create a comms connection to the specified filter
152 *
153 * @param ServiceName
154 * Name of the filter to connect to
155 * @param hPort
156 * Handle to the filter's comms port
157 *
158 * @return Win32 error code
159 */
160 DWORD
161 KmtFltConnect(
162 _In_z_ PCWSTR ServiceName,
163 _Out_ HANDLE *hPort)
164 {
165 HRESULT hResult;
166 DWORD Error = ERROR_SUCCESS;
167
168 assert(ServiceName);
169 assert(hPort);
170
171 hResult = FilterConnectCommunicationPort(ServiceName,
172 0,
173 NULL,
174 0,
175 NULL,
176 hPort);
177 Error = SCODE_CODE(hResult);
178
179 return Error;
180 }
181
182 /**
183 * @name KmtFltDisconnect
184 *
185 * Disconenct from the comms port
186 *
187 * @param hPort
188 * Handle to the filter's comms port
189 *
190 * @return Win32 error code
191 */
192 DWORD
193 KmtFltDisconnect(
194 _Out_ HANDLE *hPort)
195 {
196 DWORD Error = ERROR_SUCCESS;
197
198 assert(hPort);
199
200 if (!CloseHandle(hPort))
201 {
202 Error = GetLastError();
203 }
204
205 return Error;
206 }
207
208 /**
209 * @name KmtFltSendMessage
210 *
211 * Sneds a message to a filter driver
212 *
213 * @param hPort
214 * Handle to the filter's comms port
215 * @InBuffer
216 * Pointer to a buffer to send to the filter
217 * @InBufferSize
218 * Size of the buffer pointed to by InBuffer
219 * @OutBuffer
220 * Pointer to a buffer to receive reply data from the filter
221 * @OutBufferSize
222 * Size of the buffer pointed to by OutBuffer
223 * @BytesReturned
224 * Number of bytes written in the reply buffer
225 *
226 * @return Win32 error code
227 */
228 DWORD
229 KmtFltSendMessage(
230 _In_ HANDLE hPort,
231 _In_reads_bytes_(dwInBufferSize) LPVOID InBuffer,
232 _In_ DWORD InBufferSize,
233 _Out_writes_bytes_to_opt_(dutBufferSize, *BytesReturned) LPVOID OutBuffer,
234 _In_ DWORD OutBufferSize,
235 _Out_opt_ LPDWORD BytesReturned)
236 {
237 DWORD BytesRet;
238 HRESULT hResult;
239 DWORD Error;
240
241 assert(hPort);
242 assert(InBuffer);
243 assert(InBufferSize);
244
245 if (BytesReturned) *BytesReturned = 0;
246
247 hResult = FilterSendMessage(hPort,
248 InBuffer,
249 InBufferSize,
250 OutBuffer,
251 OutBufferSize,
252 &BytesRet);
253
254 Error = SCODE_CODE(hResult);
255 if (Error == ERROR_SUCCESS)
256 {
257 if (BytesRet)
258 {
259 *BytesReturned = BytesRet;
260 }
261 }
262
263 return Error;
264 }
265
266 /**
267 * @name KmtFltGetMessage
268 *
269 * Gets a message from a filter driver
270 *
271 * @param hPort
272 * Handle to the filter's comms port
273 * @MessageBuffer
274 * Pointer to a buffer to receive the data from the filter
275 * @MessageBufferSize
276 * Size of the buffer pointed to by MessageBuffer
277 * @Overlapped
278 * Pointer to an overlapped structure
279 *
280 * @return Win32 error code
281 */
282 DWORD
283 KmtFltGetMessage(
284 _In_ HANDLE hPort,
285 _Out_writes_bytes_(MessageBufferSize) PFILTER_MESSAGE_HEADER MessageBuffer,
286 _In_ DWORD MessageBufferSize,
287 _In_opt_ LPOVERLAPPED Overlapped)
288 {
289 HRESULT hResult;
290 DWORD Error;
291
292 assert(hPort);
293 assert(MessageBuffer);
294
295 hResult = FilterGetMessage(hPort,
296 MessageBuffer,
297 MessageBufferSize,
298 Overlapped);
299 Error = SCODE_CODE(hResult);
300 return Error;
301 }
302
303 /**
304 * @name KmtFltReplyMessage
305 *
306 * Replies to a message from a filter driver
307 *
308 * @param hPort
309 * Handle to the filter's comms port
310 * @ReplyBuffer
311 * Pointer to a buffer to return to the filter
312 * @ReplyBufferSize
313 * Size of the buffer pointed to by ReplyBuffer
314 *
315 * @return Win32 error code
316 */
317 DWORD
318 KmtFltReplyMessage(
319 _In_ HANDLE hPort,
320 _In_reads_bytes_(ReplyBufferSize) PFILTER_REPLY_HEADER ReplyBuffer,
321 _In_ DWORD ReplyBufferSize)
322 {
323 HRESULT hResult;
324 DWORD Error;
325
326 hResult = FilterReplyMessage(hPort,
327 ReplyBuffer,
328 ReplyBufferSize);
329 Error = SCODE_CODE(hResult);
330 return Error;
331 }
332
333 /**
334 * @name KmtFltGetMessageResult
335 *
336 * Gets the overlapped result from the IO
337 *
338 * @param hPort
339 * Handle to the filter's comms port
340 * @Overlapped
341 * Pointer to the overlapped structure usdd in the IO
342 * @BytesTransferred
343 * Number of bytes transferred in the IO
344 *
345 * @return Win32 error code
346 */
347 DWORD
348 KmtFltGetMessageResult(
349 _In_ HANDLE hPort,
350 _In_ LPOVERLAPPED Overlapped,
351 _Out_ LPDWORD BytesTransferred)
352 {
353 BOOL Success;
354 DWORD Error = ERROR_SUCCESS;
355
356 *BytesTransferred = 0;
357
358 Success = GetOverlappedResult(hPort, Overlapped, BytesTransferred, TRUE);
359 if (!Success)
360 {
361 Error = GetLastError();
362 }
363
364 return Error;
365 }
366
367 /**
368 * @name KmtFltUnload
369 *
370 * Unload the specified filter driver
371 *
372 * @param ServiceName
373 * The name of the filter to unload
374 *
375 * @return Win32 error code
376 */
377 DWORD
378 KmtFltUnload(
379 _In_z_ PCWSTR ServiceName)
380 {
381 HRESULT hResult;
382 DWORD Error = ERROR_SUCCESS;
383
384 assert(ServiceName);
385
386 hResult = FilterUnload(ServiceName);
387 Error = SCODE_CODE(hResult);
388
389 return Error;
390 }
391
392 /**
393 * @name KmtFltDeleteService
394 *
395 * Delete the specified filter driver
396 *
397 * @param ServiceName
398 * If *ServiceHandle is NULL, name of the service to delete
399 * @param ServiceHandle
400 * Pointer to a variable containing the service handle.
401 * Will be set to NULL on success
402 *
403 * @return Win32 error code
404 */
405 DWORD
406 KmtFltDeleteService(
407 _In_z_ PCWSTR ServiceName OPTIONAL,
408 _Inout_ SC_HANDLE *ServiceHandle)
409 {
410 return KmtDeleteService(ServiceName, ServiceHandle);
411 }
412
413 /**
414 * @name KmtFltCloseService
415 *
416 * Close the specified driver service handle
417 *
418 * @param ServiceHandle
419 * Pointer to a variable containing the service handle.
420 * Will be set to NULL on success
421 *
422 * @return Win32 error code
423 */
424 DWORD KmtFltCloseService(
425 _Inout_ SC_HANDLE *ServiceHandle)
426 {
427 return KmtCloseService(ServiceHandle);
428 }