70f2444c64557c6c0843a67cd9999db6659361de
[reactos.git] / reactos / lib / msafd / misc / dllmain.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
4 * FILE: misc/dllmain.c
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7 * REVISIONS:
8 * CSH 01/09-2000 Created
9 */
10 #include <string.h>
11 #include <msafd.h>
12 #include <helpers.h>
13 #include <rosrtl/string.h>
14
15 #ifdef DBG
16
17 /* See debug.h for debug/trace constants */
18 DWORD DebugTraceLevel = MIN_TRACE;
19 //DWORD DebugTraceLevel = DEBUG_ULTRA;
20
21 #endif /* DBG */
22
23 /* To make the linker happy */
24 VOID STDCALL KeBugCheck (ULONG BugCheckCode) {}
25
26
27 HANDLE GlobalHeap;
28 WSPUPCALLTABLE Upcalls;
29 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
30 CRITICAL_SECTION InitCriticalSection;
31 DWORD StartupCount = 0;
32 HANDLE CommandChannel;
33
34
35 NTSTATUS OpenSocket(
36 SOCKET *Socket,
37 INT AddressFamily,
38 INT SocketType,
39 INT Protocol,
40 PVOID HelperContext,
41 DWORD NotificationEvents,
42 PUNICODE_STRING TdiDeviceName)
43 /*
44 * FUNCTION: Opens a socket
45 * ARGUMENTS:
46 * Socket = Address of buffer to place socket descriptor
47 * AddressFamily = Address family
48 * SocketType = Type of socket
49 * Protocol = Protocol type
50 * HelperContext = Pointer to context information for helper DLL
51 * NotificationEvents = Events for which helper DLL is to be notified
52 * TdiDeviceName = Pointer to name of TDI device to use
53 * RETURNS:
54 * Status of operation
55 */
56 {
57 OBJECT_ATTRIBUTES ObjectAttributes;
58 PAFD_SOCKET_INFORMATION SocketInfo;
59 PFILE_FULL_EA_INFORMATION EaInfo;
60 UNICODE_STRING DeviceName;
61 IO_STATUS_BLOCK Iosb;
62 HANDLE FileHandle;
63 NTSTATUS Status;
64 ULONG EaLength;
65 ULONG EaShort;
66
67 AFD_DbgPrint(MAX_TRACE, ("Socket (0x%X) TdiDeviceName (%wZ)\n",
68 Socket, TdiDeviceName));
69
70 AFD_DbgPrint(MAX_TRACE, ("Socket2 (0x%X) TdiDeviceName (%S)\n",
71 Socket, TdiDeviceName->Buffer));
72
73 EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
74 AFD_SOCKET_LENGTH +
75 sizeof(AFD_SOCKET_INFORMATION);
76
77 EaLength = EaShort + TdiDeviceName->Length + sizeof(WCHAR);
78
79 EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
80 if (!EaInfo) {
81 return STATUS_INSUFFICIENT_RESOURCES;
82 }
83
84 RtlZeroMemory(EaInfo, EaLength);
85 EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
86 RtlCopyMemory(EaInfo->EaName,
87 AfdSocket,
88 AFD_SOCKET_LENGTH);
89 EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
90
91 SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
92 SocketInfo->CommandChannel = FALSE;
93 SocketInfo->AddressFamily = AddressFamily;
94 SocketInfo->SocketType = SocketType;
95 SocketInfo->Protocol = Protocol;
96 SocketInfo->HelperContext = HelperContext;
97 SocketInfo->NotificationEvents = NotificationEvents;
98 /* Zeroed above so initialized to a wildcard address if a raw socket */
99 SocketInfo->Name.sa_family = AddressFamily;
100
101 /* Store TDI device name last in buffer */
102 SocketInfo->TdiDeviceName.Buffer = (PWCHAR)(EaInfo + EaShort);
103 SocketInfo->TdiDeviceName.MaximumLength = TdiDeviceName->Length + sizeof(WCHAR);
104 RtlCopyUnicodeString(&SocketInfo->TdiDeviceName, TdiDeviceName);
105
106 AFD_DbgPrint(MAX_TRACE, ("EaInfo at (0x%X) EaLength is (%d).\n", (UINT)EaInfo, (INT)EaLength));
107
108 RtlRosInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
109 InitializeObjectAttributes(
110 &ObjectAttributes,
111 &DeviceName,
112 0,
113 NULL,
114 NULL);
115
116 Status = NtCreateFile(
117 &FileHandle,
118 FILE_GENERIC_READ | FILE_GENERIC_WRITE,
119 &ObjectAttributes,
120 &Iosb,
121 NULL,
122 0,
123 0,
124 FILE_OPEN,
125 FILE_SYNCHRONOUS_IO_ALERT,
126 EaInfo,
127 EaLength);
128
129 HeapFree(GlobalHeap, 0, EaInfo);
130
131 if (!NT_SUCCESS(Status)) {
132 AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
133 (UINT)Status));
134 return STATUS_INSUFFICIENT_RESOURCES;
135 }
136
137 *Socket = (SOCKET)FileHandle;
138
139 return STATUS_SUCCESS;
140 }
141
142
143 SOCKET
144 WSPAPI
145 WSPSocket(
146 IN INT af,
147 IN INT type,
148 IN INT protocol,
149 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
150 IN GROUP g,
151 IN DWORD dwFlags,
152 OUT LPINT lpErrno)
153 /*
154 * FUNCTION: Creates a new socket
155 * ARGUMENTS:
156 * af = Address family
157 * type = Socket type
158 * protocol = Protocol type
159 * lpProtocolInfo = Pointer to protocol information
160 * g = Reserved
161 * dwFlags = Socket flags
162 * lpErrno = Address of buffer for error information
163 * RETURNS:
164 * Created socket, or INVALID_SOCKET if it could not be created
165 */
166 {
167 WSAPROTOCOL_INFOW ProtocolInfo;
168 UNICODE_STRING TdiDeviceName;
169 DWORD NotificationEvents;
170 PWSHELPER_DLL HelperDLL;
171 PVOID HelperContext;
172 INT AddressFamily;
173 NTSTATUS NtStatus;
174 INT SocketType;
175 SOCKET Socket2;
176 SOCKET Socket;
177 INT Protocol;
178 INT Status;
179
180 AFD_DbgPrint(MAX_TRACE, ("af (%d) type (%d) protocol (%d).\n",
181 af, type, protocol));
182
183 if (!lpProtocolInfo) {
184 lpProtocolInfo = &ProtocolInfo;
185 ZeroMemory(&ProtocolInfo, sizeof(WSAPROTOCOL_INFOW));
186
187 ProtocolInfo.iAddressFamily = af;
188 ProtocolInfo.iSocketType = type;
189 ProtocolInfo.iProtocol = protocol;
190 }
191
192 HelperDLL = LocateHelperDLL(lpProtocolInfo);
193 if (!HelperDLL) {
194 *lpErrno = WSAEAFNOSUPPORT;
195 return INVALID_SOCKET;
196 }
197
198 AddressFamily = lpProtocolInfo->iAddressFamily;
199 SocketType = lpProtocolInfo->iSocketType;
200 Protocol = lpProtocolInfo->iProtocol;
201
202 /* The OPTIONAL export WSHOpenSocket2 supersedes WSHOpenSocket */
203 if (HelperDLL->EntryTable.lpWSHOpenSocket2)
204 {
205 Status = HelperDLL->EntryTable.lpWSHOpenSocket2(
206 &AddressFamily,
207 &SocketType,
208 &Protocol,
209 0,
210 0,
211 &TdiDeviceName,
212 &HelperContext,
213 &NotificationEvents);
214 }
215 else
216 {
217 Status = HelperDLL->EntryTable.lpWSHOpenSocket(
218 &AddressFamily,
219 &SocketType,
220 &Protocol,
221 &TdiDeviceName,
222 &HelperContext,
223 &NotificationEvents);
224 }
225
226 if (Status != NO_ERROR) {
227 AFD_DbgPrint(MAX_TRACE, ("WinSock Helper DLL failed (0x%X).\n", Status));
228 *lpErrno = Status;
229 return INVALID_SOCKET;
230 }
231
232 NtStatus = OpenSocket(&Socket,
233 AddressFamily,
234 SocketType,
235 Protocol,
236 HelperContext,
237 NotificationEvents,
238 &TdiDeviceName);
239
240 RtlFreeUnicodeString(&TdiDeviceName);
241 if (!NT_SUCCESS(NtStatus)) {
242 *lpErrno = RtlNtStatusToDosError(Status);
243 return INVALID_SOCKET;
244 }
245
246 /* FIXME: Assumes catalog entry id to be 1 */
247 Socket2 = Upcalls.lpWPUModifyIFSHandle(1, Socket, lpErrno);
248
249 if (Socket2 == INVALID_SOCKET) {
250 /* FIXME: Cleanup */
251 AFD_DbgPrint(MIN_TRACE, ("FIXME: Cleanup.\n"));
252 return INVALID_SOCKET;
253 }
254
255 *lpErrno = NO_ERROR;
256
257 AFD_DbgPrint(MID_TRACE, ("Returning socket descriptor (0x%X).\n", Socket2));
258
259 return Socket2;
260 }
261
262
263 INT
264 WSPAPI
265 WSPCloseSocket(
266 IN SOCKET s,
267 OUT LPINT lpErrno)
268 /*
269 * FUNCTION: Closes an open socket
270 * ARGUMENTS:
271 * s = Socket descriptor
272 * lpErrno = Address of buffer for error information
273 * RETURNS:
274 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
275 */
276 {
277 NTSTATUS Status;
278
279 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
280
281 Status = NtClose((HANDLE)s);
282
283 if (NT_SUCCESS(Status)) {
284 *lpErrno = NO_ERROR;
285 return NO_ERROR;
286 }
287
288 *lpErrno = WSAENOTSOCK;
289 return SOCKET_ERROR;
290 }
291
292
293 INT
294 WSPAPI
295 WSPBind(
296 IN SOCKET s,
297 IN CONST LPSOCKADDR name,
298 IN INT namelen,
299 OUT LPINT lpErrno)
300 /*
301 * FUNCTION: Associates a local address with a socket
302 * ARGUMENTS:
303 * s = Socket descriptor
304 * name = Pointer to local address
305 * namelen = Length of name
306 * lpErrno = Address of buffer for error information
307 * RETURNS:
308 * 0, or SOCKET_ERROR if the socket could not be bound
309 */
310 {
311 FILE_REQUEST_BIND Request;
312 FILE_REPLY_BIND Reply;
313 IO_STATUS_BLOCK Iosb;
314 NTSTATUS Status;
315
316 AFD_DbgPrint(MAX_TRACE, ("s (0x%X) name (0x%X) namelen (%d).\n", s, name, namelen));
317
318 RtlCopyMemory(&Request.Name, name, sizeof(SOCKADDR));
319
320 Status = NtDeviceIoControlFile(
321 (HANDLE)s,
322 NULL,
323 NULL,
324 NULL,
325 &Iosb,
326 IOCTL_AFD_BIND,
327 &Request,
328 sizeof(FILE_REQUEST_BIND),
329 &Reply,
330 sizeof(FILE_REPLY_BIND));
331 if (Status == STATUS_PENDING) {
332 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
333 /* FIXME: Wait only for blocking sockets */
334 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
335 }
336
337 if (!NT_SUCCESS(Status)) {
338 *lpErrno = Reply.Status;
339 return SOCKET_ERROR;
340 }
341
342 return 0;
343 }
344
345 INT
346 WSPAPI
347 WSPListen(
348 IN SOCKET s,
349 IN INT backlog,
350 OUT LPINT lpErrno)
351 /*
352 * FUNCTION: Listens for incoming connections
353 * ARGUMENTS:
354 * s = Socket descriptor
355 * backlog = Maximum number of pending connection requests
356 * lpErrno = Address of buffer for error information
357 * RETURNS:
358 * 0, or SOCKET_ERROR if the socket could not be bound
359 */
360 {
361 FILE_REQUEST_LISTEN Request;
362 FILE_REPLY_LISTEN Reply;
363 IO_STATUS_BLOCK Iosb;
364 NTSTATUS Status;
365
366 AFD_DbgPrint(MAX_TRACE, ("s (0x%X) backlog (%d).\n", s, backlog));
367
368 Request.Backlog = backlog;
369
370 Status = NtDeviceIoControlFile(
371 (HANDLE)s,
372 NULL,
373 NULL,
374 NULL,
375 &Iosb,
376 IOCTL_AFD_LISTEN,
377 &Request,
378 sizeof(FILE_REQUEST_LISTEN),
379 &Reply,
380 sizeof(FILE_REPLY_LISTEN));
381 if (Status == STATUS_PENDING) {
382 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
383 /* FIXME: Wait only for blocking sockets */
384 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
385 }
386
387 if (!NT_SUCCESS(Status)) {
388 *lpErrno = Reply.Status;
389 return SOCKET_ERROR;
390 }
391
392 return 0;
393 }
394
395
396 INT
397 WSPAPI
398 WSPSelect(
399 IN INT nfds,
400 IN OUT LPFD_SET readfds,
401 IN OUT LPFD_SET writefds,
402 IN OUT LPFD_SET exceptfds,
403 IN CONST LPTIMEVAL timeout,
404 OUT LPINT lpErrno)
405 /*
406 * FUNCTION: Returns status of one or more sockets
407 * ARGUMENTS:
408 * nfds = Always ignored
409 * readfds = Pointer to socket set to be checked for readability (optional)
410 * writefds = Pointer to socket set to be checked for writability (optional)
411 * exceptfds = Pointer to socket set to be checked for errors (optional)
412 * timeout = Pointer to a TIMEVAL structure indicating maximum wait time
413 * (NULL means wait forever)
414 * lpErrno = Address of buffer for error information
415 * RETURNS:
416 * Number of ready socket descriptors, or SOCKET_ERROR if an error ocurred
417 */
418 {
419 PFILE_REQUEST_SELECT Request;
420 FILE_REPLY_SELECT Reply;
421 IO_STATUS_BLOCK Iosb;
422 NTSTATUS Status;
423 DWORD Size;
424 DWORD ReadSize;
425 DWORD WriteSize;
426 DWORD ExceptSize;
427 PVOID Current;
428
429 AFD_DbgPrint(MAX_TRACE, ("readfds (0x%X) writefds (0x%X) exceptfds (0x%X).\n",
430 readfds, writefds, exceptfds));
431 #if 0
432 /* FIXME: For now, all reads are timed out immediately */
433 if (readfds != NULL) {
434 AFD_DbgPrint(MID_TRACE, ("Timing out read query.\n"));
435 *lpErrno = WSAETIMEDOUT;
436 return SOCKET_ERROR;
437 }
438
439 /* FIXME: For now, always allow write */
440 if (writefds != NULL) {
441 AFD_DbgPrint(MID_TRACE, ("Setting one socket writeable.\n"));
442 *lpErrno = NO_ERROR;
443 return 1;
444 }
445 #endif
446
447 ReadSize = 0;
448 if ((readfds != NULL) && (readfds->fd_count > 0)) {
449 ReadSize = (readfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
450 }
451
452 WriteSize = 0;
453 if ((writefds != NULL) && (writefds->fd_count > 0)) {
454 WriteSize = (writefds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
455 }
456
457 ExceptSize = 0;
458 if ((exceptfds != NULL) && (exceptfds->fd_count > 0)) {
459 ExceptSize = (exceptfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
460 }
461
462 Size = ReadSize + WriteSize + ExceptSize;
463
464 Request = (PFILE_REQUEST_SELECT)HeapAlloc(
465 GlobalHeap, 0, sizeof(FILE_REQUEST_SELECT) + Size);
466 if (!Request) {
467 *lpErrno = WSAENOBUFS;
468 return SOCKET_ERROR;
469 }
470
471 /* Put FD SETs after request structure */
472 Current = (Request + 1);
473
474 if (ReadSize > 0) {
475 Request->ReadFDSet = (LPFD_SET)Current;
476 Current += ReadSize;
477 RtlCopyMemory(Request->ReadFDSet, readfds, ReadSize);
478 } else {
479 Request->ReadFDSet = NULL;
480 }
481
482 if (WriteSize > 0) {
483 Request->WriteFDSet = (LPFD_SET)Current;
484 Current += WriteSize;
485 RtlCopyMemory(Request->WriteFDSet, writefds, WriteSize);
486 } else {
487 Request->WriteFDSet = NULL;
488 }
489
490 if (ExceptSize > 0) {
491 Request->ExceptFDSet = (LPFD_SET)Current;
492 RtlCopyMemory(Request->ExceptFDSet, exceptfds, ExceptSize);
493 } else {
494 Request->ExceptFDSet = NULL;
495 }
496
497 AFD_DbgPrint(MAX_TRACE, ("R1 (0x%X) W1 (0x%X).\n", Request->ReadFDSet, Request->WriteFDSet));
498
499 Status = NtDeviceIoControlFile(
500 CommandChannel,
501 NULL,
502 NULL,
503 NULL,
504 &Iosb,
505 IOCTL_AFD_SELECT,
506 Request,
507 sizeof(FILE_REQUEST_SELECT) + Size,
508 &Reply,
509 sizeof(FILE_REPLY_SELECT));
510
511 HeapFree(GlobalHeap, 0, Request);
512
513 if (Status == STATUS_PENDING) {
514 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
515 /* FIXME: Wait only for blocking sockets */
516 Status = NtWaitForSingleObject(CommandChannel, FALSE, NULL);
517 }
518
519 if (!NT_SUCCESS(Status)) {
520 AFD_DbgPrint(MAX_TRACE, ("Status (0x%X).\n", Status));
521 *lpErrno = WSAENOBUFS;
522 return SOCKET_ERROR;
523 }
524
525 AFD_DbgPrint(MAX_TRACE, ("Select successful. Status (0x%X) Count (0x%X).\n",
526 Reply.Status, Reply.SocketCount));
527
528 *lpErrno = Reply.Status;
529
530 return Reply.SocketCount;
531 }
532
533 SOCKET
534 WSPAPI
535 WSPAccept(
536 IN SOCKET s,
537 OUT LPSOCKADDR addr,
538 IN OUT LPINT addrlen,
539 IN LPCONDITIONPROC lpfnCondition,
540 IN DWORD dwCallbackData,
541 OUT LPINT lpErrno)
542 {
543 FILE_REQUEST_ACCEPT Request;
544 FILE_REPLY_ACCEPT Reply;
545 IO_STATUS_BLOCK Iosb;
546 NTSTATUS Status;
547
548 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
549
550 Request.addr = addr;
551 Request.addrlen = *addrlen;
552 Request.lpfnCondition = lpfnCondition;
553 Request.dwCallbackData = dwCallbackData;
554
555 Status = NtDeviceIoControlFile(
556 (HANDLE)s,
557 NULL,
558 NULL,
559 NULL,
560 &Iosb,
561 IOCTL_AFD_ACCEPT,
562 &Request,
563 sizeof(FILE_REQUEST_ACCEPT),
564 &Reply,
565 sizeof(FILE_REPLY_ACCEPT));
566 if (Status == STATUS_PENDING) {
567 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
568 /* FIXME: Wait only for blocking sockets */
569 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
570 }
571
572 if (!NT_SUCCESS(Status)) {
573 *lpErrno = Reply.Status;
574 return INVALID_SOCKET;
575 }
576
577 *addrlen = Reply.addrlen;
578
579 return Reply.Socket;
580 }
581
582
583 INT
584 WSPAPI
585 WSPConnect(
586 IN SOCKET s,
587 IN CONST LPSOCKADDR name,
588 IN INT namelen,
589 IN LPWSABUF lpCallerData,
590 OUT LPWSABUF lpCalleeData,
591 IN LPQOS lpSQOS,
592 IN LPQOS lpGQOS,
593 OUT LPINT lpErrno)
594 {
595 FILE_REQUEST_CONNECT Request;
596 FILE_REPLY_CONNECT Reply;
597 IO_STATUS_BLOCK Iosb;
598 NTSTATUS Status;
599
600 AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
601
602 Request.name = name;
603 Request.namelen = namelen;
604 Request.lpCallerData = lpCallerData;
605 Request.lpCalleeData = lpCalleeData;
606 Request.lpSQOS = lpSQOS;
607 Request.lpGQOS = lpGQOS;
608
609 Status = NtDeviceIoControlFile(
610 (HANDLE)s,
611 NULL,
612 NULL,
613 NULL,
614 &Iosb,
615 IOCTL_AFD_CONNECT,
616 &Request,
617 sizeof(FILE_REQUEST_CONNECT),
618 &Reply,
619 sizeof(FILE_REPLY_CONNECT));
620 if (Status == STATUS_PENDING) {
621 AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
622 /* FIXME: Wait only for blocking sockets */
623 Status = NtWaitForSingleObject((HANDLE)s, FALSE, NULL);
624 }
625
626 if (!NT_SUCCESS(Status)) {
627 *lpErrno = Reply.Status;
628 return INVALID_SOCKET;
629 }
630
631 return 0;
632 }
633
634
635 NTSTATUS OpenCommandChannel(
636 VOID)
637 /*
638 * FUNCTION: Opens a command channel to afd.sys
639 * ARGUMENTS:
640 * None
641 * RETURNS:
642 * Status of operation
643 */
644 {
645 OBJECT_ATTRIBUTES ObjectAttributes;
646 PAFD_SOCKET_INFORMATION SocketInfo;
647 PFILE_FULL_EA_INFORMATION EaInfo;
648 UNICODE_STRING DeviceName;
649 IO_STATUS_BLOCK Iosb;
650 HANDLE FileHandle;
651 NTSTATUS Status;
652 ULONG EaLength;
653 ULONG EaShort;
654
655 AFD_DbgPrint(MAX_TRACE, ("Called\n"));
656
657 EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
658 AFD_SOCKET_LENGTH +
659 sizeof(AFD_SOCKET_INFORMATION);
660
661 EaLength = EaShort;
662
663 EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
664 if (!EaInfo) {
665 return STATUS_INSUFFICIENT_RESOURCES;
666 }
667
668 RtlZeroMemory(EaInfo, EaLength);
669 EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
670 RtlCopyMemory(EaInfo->EaName,
671 AfdSocket,
672 AFD_SOCKET_LENGTH);
673 EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
674
675 SocketInfo = (PAFD_SOCKET_INFORMATION)((ULONG_PTR)EaInfo->EaName + AFD_SOCKET_LENGTH);
676 SocketInfo->CommandChannel = TRUE;
677
678 RtlRosInitUnicodeStringFromLiteral(&DeviceName, L"\\Device\\Afd");
679 InitializeObjectAttributes(
680 &ObjectAttributes,
681 &DeviceName,
682 0,
683 NULL,
684 NULL);
685
686 Status = NtCreateFile(
687 &FileHandle,
688 FILE_GENERIC_READ | FILE_GENERIC_WRITE,
689 &ObjectAttributes,
690 &Iosb,
691 NULL,
692 0,
693 0,
694 FILE_OPEN,
695 FILE_SYNCHRONOUS_IO_ALERT,
696 EaInfo,
697 EaLength);
698
699 if (!NT_SUCCESS(Status)) {
700 AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
701 (UINT)Status));
702 return Status;
703 }
704
705 CommandChannel = FileHandle;
706
707 return STATUS_SUCCESS;
708 }
709
710
711 NTSTATUS CloseCommandChannel(
712 VOID)
713 /*
714 * FUNCTION: Closes command channel to afd.sys
715 * ARGUMENTS:
716 * None
717 * RETURNS:
718 * Status of operation
719 */
720 {
721 AFD_DbgPrint(MAX_TRACE, ("Called.\n"));
722
723 return NtClose(CommandChannel);
724 }
725
726
727 INT
728 WSPAPI
729 WSPStartup(
730 IN WORD wVersionRequested,
731 OUT LPWSPDATA lpWSPData,
732 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
733 IN WSPUPCALLTABLE UpcallTable,
734 OUT LPWSPPROC_TABLE lpProcTable)
735 /*
736 * FUNCTION: Initialize service provider for a client
737 * ARGUMENTS:
738 * wVersionRequested = Highest WinSock SPI version that the caller can use
739 * lpWSPData = Address of WSPDATA structure to initialize
740 * lpProtocolInfo = Pointer to structure that defines the desired protocol
741 * UpcallTable = Pointer to upcall table of the WinSock DLL
742 * lpProcTable = Address of procedure table to initialize
743 * RETURNS:
744 * Status of operation
745 */
746 {
747 HMODULE hWS2_32;
748 INT Status;
749
750 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
751
752 EnterCriticalSection(&InitCriticalSection);
753
754 Upcalls = UpcallTable;
755
756 if (StartupCount == 0) {
757 /* First time called */
758
759 Status = OpenCommandChannel();
760 if (NT_SUCCESS(Status)) {
761 hWS2_32 = GetModuleHandle(L"ws2_32.dll");
762 if (hWS2_32 != NULL) {
763 lpWPUCompleteOverlappedRequest = (LPWPUCOMPLETEOVERLAPPEDREQUEST)
764 GetProcAddress(hWS2_32, "WPUCompleteOverlappedRequest");
765 if (lpWPUCompleteOverlappedRequest != NULL) {
766 Status = NO_ERROR;
767 StartupCount++;
768 }
769 } else {
770 AFD_DbgPrint(MIN_TRACE, ("GetModuleHandle() failed for ws2_32.dll\n"));
771 }
772 } else {
773 AFD_DbgPrint(MIN_TRACE, ("Cannot open afd.sys\n"));
774 }
775 } else {
776 Status = NO_ERROR;
777 StartupCount++;
778 }
779
780 LeaveCriticalSection(&InitCriticalSection);
781
782 if (Status == NO_ERROR) {
783 lpProcTable->lpWSPAccept = WSPAccept;
784 lpProcTable->lpWSPAddressToString = WSPAddressToString;
785 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
786 lpProcTable->lpWSPBind = WSPBind;
787 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
788 lpProcTable->lpWSPCleanup = WSPCleanup;
789 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
790 lpProcTable->lpWSPConnect = WSPConnect;
791 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
792 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
793 lpProcTable->lpWSPEventSelect = WSPEventSelect;
794 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
795 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
796 lpProcTable->lpWSPGetSockName = WSPGetSockName;
797 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
798 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
799 lpProcTable->lpWSPIoctl = WSPIoctl;
800 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
801 lpProcTable->lpWSPListen = WSPListen;
802 lpProcTable->lpWSPRecv = WSPRecv;
803 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
804 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
805 lpProcTable->lpWSPSelect = WSPSelect;
806 lpProcTable->lpWSPSend = WSPSend;
807 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
808 lpProcTable->lpWSPSendTo = WSPSendTo;
809 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
810 lpProcTable->lpWSPShutdown = WSPShutdown;
811 lpProcTable->lpWSPSocket = WSPSocket;
812 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
813
814 lpWSPData->wVersion = MAKEWORD(2, 2);
815 lpWSPData->wHighVersion = MAKEWORD(2, 2);
816 }
817
818 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
819
820 return Status;
821 }
822
823
824 INT
825 WSPAPI
826 WSPCleanup(
827 OUT LPINT lpErrno)
828 /*
829 * FUNCTION: Cleans up service provider for a client
830 * ARGUMENTS:
831 * lpErrno = Address of buffer for error information
832 * RETURNS:
833 * 0 if successful, or SOCKET_ERROR if not
834 */
835 {
836 AFD_DbgPrint(MAX_TRACE, ("\n"));
837
838 EnterCriticalSection(&InitCriticalSection);
839
840 if (StartupCount > 0) {
841 StartupCount--;
842
843 if (StartupCount == 0) {
844 AFD_DbgPrint(MAX_TRACE, ("Cleaning up msafd.dll.\n"));
845
846 CloseCommandChannel();
847 }
848 }
849
850 LeaveCriticalSection(&InitCriticalSection);
851
852 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
853
854 *lpErrno = NO_ERROR;
855
856 return 0;
857 }
858
859
860 BOOL
861 STDCALL
862 DllMain(HANDLE hInstDll,
863 ULONG dwReason,
864 PVOID Reserved)
865 {
866 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll\n"));
867
868 switch (dwReason) {
869 case DLL_PROCESS_ATTACH:
870 /* Don't need thread attach notifications
871 so disable them to improve performance */
872 DisableThreadLibraryCalls(hInstDll);
873
874 InitializeCriticalSection(&InitCriticalSection);
875
876 GlobalHeap = GetProcessHeap();
877
878 CreateHelperDLLDatabase();
879 break;
880
881 case DLL_THREAD_ATTACH:
882 break;
883
884 case DLL_THREAD_DETACH:
885 break;
886
887 case DLL_PROCESS_DETACH:
888
889 DestroyHelperDLLDatabase();
890
891 DeleteCriticalSection(&InitCriticalSection);
892
893 break;
894 }
895
896 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
897
898 return TRUE;
899 }
900
901 /* EOF */