Removed some badly formatted checkpoints.
[reactos.git] / reactos / lib / msafd / misc / dllmain.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
4 * FILE: misc/dllmain.c
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7 * Alex Ionescu (alex@relsoft.net)
8 * REVISIONS:
9 * CSH 01/09-2000 Created
10 * Alex 16/07/2004 - Complete Rewrite
11 */
12 #include <roscfg.h>
13 #include <string.h>
14 #include <msafd.h>
15 #include <helpers.h>
16 #include <rosrtl/string.h>
17
18 #ifdef DBG
19 //DWORD DebugTraceLevel = DEBUG_ULTRA;
20 DWORD DebugTraceLevel = 0;
21 #endif /* DBG */
22
23 HANDLE GlobalHeap;
24 WSPUPCALLTABLE Upcalls;
25 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
26 ULONG SocketCount;
27 PSOCKET_INFORMATION *Sockets = NULL;
28 LIST_ENTRY SockHelpersListHead = {NULL};
29 ULONG SockAsyncThreadRefCount;
30 HANDLE SockAsyncHelperAfdHandle;
31 HANDLE SockAsyncCompletionPort;
32 BOOLEAN SockAsyncSelectCalled;
33
34 SOCKET
35 WSPAPI
36 WSPSocket(
37 int AddressFamily,
38 int SocketType,
39 int Protocol,
40 LPWSAPROTOCOL_INFOW lpProtocolInfo,
41 GROUP g,
42 DWORD dwFlags,
43 LPINT lpErrno)
44 /*
45 * FUNCTION: Creates a new socket
46 * ARGUMENTS:
47 * af = Address family
48 * type = Socket type
49 * protocol = Protocol type
50 * lpProtocolInfo = Pointer to protocol information
51 * g = Reserved
52 * dwFlags = Socket flags
53 * lpErrno = Address of buffer for error information
54 * RETURNS:
55 * Created socket, or INVALID_SOCKET if it could not be created
56 */
57 {
58 OBJECT_ATTRIBUTES Object;
59 IO_STATUS_BLOCK IOSB;
60 USHORT SizeOfPacket;
61 ULONG SizeOfEA;
62 PAFD_CREATE_PACKET AfdPacket;
63 HANDLE Sock;
64 PSOCKET_INFORMATION Socket = NULL;
65 PFILE_FULL_EA_INFORMATION EABuffer = NULL;
66 PHELPER_DATA HelperData;
67 PVOID HelperDLLContext;
68 DWORD HelperEvents;
69 DWORD IOOptions;
70 UNICODE_STRING TransportName;
71 UNICODE_STRING DevName;
72 LARGE_INTEGER GroupData;
73 INT Status;
74
75 AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
76 AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
77 AddressFamily, SocketType, Protocol));
78
79 /* Get Helper Data and Transport */
80 Status = SockGetTdiName (&AddressFamily,
81 &SocketType,
82 &Protocol,
83 g,
84 dwFlags,
85 &TransportName,
86 &HelperDLLContext,
87 &HelperData,
88 &HelperEvents);
89
90 /* Check for error */
91 if (Status != NO_ERROR) {
92 AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
93 goto error;
94 }
95
96 /* AFD Device Name */
97 RtlInitUnicodeString(&DevName, L"\\Device\\Afd\\Endpoint");
98
99 /* Set Socket Data */
100 Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
101 RtlZeroMemory(Socket, sizeof(*Socket));
102 Socket->RefCount = 2;
103 Socket->Handle = -1;
104 Socket->SharedData.State = SocketOpen;
105 Socket->SharedData.AddressFamily = AddressFamily;
106 Socket->SharedData.SocketType = SocketType;
107 Socket->SharedData.Protocol = Protocol;
108 Socket->HelperContext = HelperDLLContext;
109 Socket->HelperData = HelperData;
110 Socket->HelperEvents = HelperEvents;
111 Socket->LocalAddress = &Socket->WSLocalAddress;
112 Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
113 Socket->RemoteAddress = &Socket->WSRemoteAddress;
114 Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
115 Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
116 Socket->SharedData.CreateFlags = dwFlags;
117 Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
118 Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
119 Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
120 Socket->SharedData.GroupID = g;
121 Socket->SharedData.GroupType = 0;
122 Socket->SharedData.UseSAN = FALSE;
123 Socket->SanData = NULL;
124
125 /* Ask alex about this */
126 if( Socket->SharedData.SocketType == SOCK_DGRAM )
127 Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
128
129 /* Packet Size */
130 SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
131
132 /* EA Size */
133 SizeOfEA = SizeOfPacket + sizeof(FILE_FULL_EA_INFORMATION) + AFD_PACKET_COMMAND_LENGTH;
134
135 /* Set up EA Buffer */
136 EABuffer = HeapAlloc(GlobalHeap, 0, (SIZE_T)&SizeOfEA);
137 EABuffer->NextEntryOffset = 0;
138 EABuffer->Flags = 0;
139 EABuffer->EaNameLength = AFD_PACKET_COMMAND_LENGTH;
140 RtlCopyMemory (EABuffer->EaName,
141 AfdCommand,
142 AFD_PACKET_COMMAND_LENGTH + 1);
143 EABuffer->EaValueLength = SizeOfPacket;
144
145 /* Set up AFD Packet */
146 AfdPacket = (PAFD_CREATE_PACKET)(EABuffer->EaName + EABuffer->EaNameLength + 1);
147 AfdPacket->SizeOfTransportName = TransportName.Length;
148 RtlCopyMemory (AfdPacket->TransportName,
149 TransportName.Buffer,
150 TransportName.Length + sizeof(WCHAR));
151 AfdPacket->GroupID = g;
152
153 /* Set up Endpoint Flags */
154 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0) {
155 if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW)) {
156 goto error; /* Only RAW or UDP can be Connectionless */
157 }
158 AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
159 }
160
161 if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0) {
162 if (SocketType == SOCK_STREAM) {
163 if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0) {
164 goto error; /* The Provider doesn't actually support Message Oriented Streams */
165 }
166 }
167 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MESSAGE_ORIENTED;
168 }
169
170 if (SocketType == SOCK_RAW) AfdPacket->EndpointFlags |= AFD_ENDPOINT_RAW;
171
172 if (dwFlags & (WSA_FLAG_MULTIPOINT_C_ROOT |
173 WSA_FLAG_MULTIPOINT_C_LEAF |
174 WSA_FLAG_MULTIPOINT_D_ROOT |
175 WSA_FLAG_MULTIPOINT_D_LEAF)) {
176 if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0) {
177 goto error; /* The Provider doesn't actually support Multipoint */
178 }
179 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
180 if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT) {
181 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
182 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0)) {
183 goto error; /* The Provider doesn't support Control Planes, or you already gave a leaf */
184 }
185 AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
186 }
187 if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT) {
188 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
189 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0)) {
190 goto error; /* The Provider doesn't support Data Planes, or you already gave a leaf */
191 }
192 AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
193 }
194 }
195
196 /* Set up Object Attributes */
197 InitializeObjectAttributes (&Object,
198 &DevName,
199 OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
200 0,
201 0);
202
203 /* Set IO Flag */
204 if ((dwFlags & WSA_FLAG_OVERLAPPED) == 0) IOOptions = FILE_SYNCHRONOUS_IO_NONALERT;
205
206 /* Create the Socket */
207 ZwCreateFile(&Sock,
208 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
209 &Object,
210 &IOSB,
211 NULL,
212 0,
213 FILE_SHARE_READ | FILE_SHARE_WRITE,
214 FILE_OPEN_IF,
215 IOOptions,
216 EABuffer,
217 SizeOfEA);
218
219 /* Save Handle */
220 Socket->Handle = (SOCKET)Sock;
221
222 /* Save Group Info */
223 if (g != 0) {
224 GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
225
226 Socket->SharedData.GroupID = GroupData.u.LowPart;
227 Socket->SharedData.GroupType = GroupData.u.HighPart;
228 }
229
230 /* Get Window Sizes and Save them */
231 GetSocketInformation (Socket,
232 AFD_INFO_SEND_WINDOW_SIZE,
233 &Socket->SharedData.SizeOfSendBuffer,
234 NULL);
235 GetSocketInformation (Socket,
236 AFD_INFO_RECEIVE_WINDOW_SIZE,
237 &Socket->SharedData.SizeOfRecvBuffer,
238 NULL);
239
240 /* Save in Process Sockets List */
241 Sockets[SocketCount] = Socket;
242 SocketCount ++;
243
244 /* Create the Socket Context */
245 CreateContext(Socket);
246
247 /* Notify Winsock */
248 Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
249
250 /* Return Socket Handle */
251 return (SOCKET)Sock;
252
253 error:
254 AFD_DbgPrint(MID_TRACE,("Ending\n"));
255
256 if( lpErrno ) *lpErrno = Status;
257
258 return INVALID_SOCKET;
259 }
260
261
262 DWORD MsafdReturnWithErrno( NTSTATUS Status, LPINT Errno, DWORD Received,
263 LPDWORD ReturnedBytes ) {
264 switch (Status) {
265 case STATUS_CANT_WAIT: *Errno = WSAEWOULDBLOCK; break;
266 case STATUS_TIMEOUT:
267 case STATUS_SUCCESS:
268 /* Return Number of bytes Read */
269 if( ReturnedBytes ) *ReturnedBytes = Received; break;
270 case STATUS_PENDING: *Errno = WSA_IO_PENDING; break;
271 case STATUS_BUFFER_OVERFLOW: *Errno = WSAEMSGSIZE; break;
272 default: {
273 DbgPrint("MSAFD: Error %d is unknown\n", Errno);
274 *Errno = WSAEINVAL; break;
275 } break;
276 }
277
278 /* Success */
279 return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
280 }
281
282
283 INT
284 WSPAPI
285 WSPCloseSocket(
286 IN SOCKET s,
287 OUT LPINT lpErrno)
288 /*
289 * FUNCTION: Closes an open socket
290 * ARGUMENTS:
291 * s = Socket descriptor
292 * lpErrno = Address of buffer for error information
293 * RETURNS:
294 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
295 */
296 {
297 return 0;
298 }
299
300
301 INT
302 WSPAPI
303 WSPBind(
304 SOCKET Handle,
305 struct sockaddr *SocketAddress,
306 int SocketAddressLength,
307 LPINT lpErrno)
308 /*
309 * FUNCTION: Associates a local address with a socket
310 * ARGUMENTS:
311 * s = Socket descriptor
312 * name = Pointer to local address
313 * namelen = Length of name
314 * lpErrno = Address of buffer for error information
315 * RETURNS:
316 * 0, or SOCKET_ERROR if the socket could not be bound
317 */
318 {
319 IO_STATUS_BLOCK IOSB;
320 PAFD_BIND_DATA BindData;
321 PSOCKET_INFORMATION Socket = NULL;
322 NTSTATUS Status;
323 UCHAR BindBuffer[0x1A];
324 SOCKADDR_INFO SocketInfo;
325 HANDLE SockEvent;
326
327 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
328 NULL, 1, FALSE );
329
330 if( !NT_SUCCESS(Status) ) return -1;
331
332 /* Get the Socket Structure associate to this Socket*/
333 Socket = GetSocketStructure(Handle);
334
335 /* Dynamic Structure...ugh */
336 BindData = (PAFD_BIND_DATA)BindBuffer;
337
338 /* Set up Address in TDI Format */
339 BindData->Address.TAAddressCount = 1;
340 BindData->Address.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
341 BindData->Address.Address[0].AddressType = SocketAddress->sa_family;
342 RtlCopyMemory (BindData->Address.Address[0].Address,
343 SocketAddress->sa_data,
344 SocketAddressLength - sizeof(SocketAddress->sa_family));
345
346 /* Get Address Information */
347 Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
348 SocketAddressLength,
349 &SocketInfo);
350
351 /* Set the Share Type */
352 if (Socket->SharedData.ExclusiveAddressUse) {
353 BindData->ShareType = AFD_SHARE_EXCLUSIVE;
354 }
355 else if (SocketInfo.EndpointInfo == SockaddrEndpointInfoWildcard) {
356 BindData->ShareType = AFD_SHARE_WILDCARD;
357 }
358 else if (Socket->SharedData.ReuseAddresses) {
359 BindData->ShareType = AFD_SHARE_REUSE;
360 } else {
361 BindData->ShareType = AFD_SHARE_UNIQUE;
362 }
363
364 /* Send IOCTL */
365 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
366 SockEvent,
367 NULL,
368 NULL,
369 &IOSB,
370 IOCTL_AFD_BIND,
371 BindData,
372 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
373 BindData,
374 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
375
376 /* Set up Socket Data */
377 Socket->SharedData.State = SocketBound;
378 Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
379
380 NtClose( SockEvent );
381
382 return MsafdReturnWithErrno
383 ( IOSB.Status, lpErrno, IOSB.Information, NULL );
384 }
385
386 int
387 WSPAPI
388 WSPListen(
389 SOCKET Handle,
390 int Backlog,
391 LPINT lpErrno)
392 {
393 IO_STATUS_BLOCK IOSB;
394 AFD_LISTEN_DATA ListenData;
395 PSOCKET_INFORMATION Socket = NULL;
396 HANDLE SockEvent;
397 NTSTATUS Status;
398
399 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
400 NULL, 1, FALSE );
401
402 if( !NT_SUCCESS(Status) ) return -1;
403
404 /* Get the Socket Structure associate to this Socket*/
405 Socket = GetSocketStructure(Handle);
406
407 /* Set Up Listen Structure */
408 ListenData.UseSAN = FALSE;
409 ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
410 ListenData.Backlog = Backlog;
411
412 /* Send IOCTL */
413 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
414 SockEvent,
415 NULL,
416 NULL,
417 &IOSB,
418 IOCTL_AFD_START_LISTEN,
419 &ListenData,
420 sizeof(ListenData),
421 NULL,
422 0);
423
424 /* Set to Listening */
425 Socket->SharedData.Listening = TRUE;
426
427 NtClose( SockEvent );
428
429 return MsafdReturnWithErrno
430 ( IOSB.Status, lpErrno, IOSB.Information, NULL );
431 }
432
433
434 int
435 WSPAPI
436 WSPSelect(
437 int nfds,
438 fd_set *readfds,
439 fd_set *writefds,
440 fd_set *exceptfds,
441 struct timeval *timeout,
442 LPINT lpErrno)
443 {
444 IO_STATUS_BLOCK IOSB;
445 PAFD_POLL_INFO PollInfo;
446 NTSTATUS Status;
447 ULONG HandleCount, OutCount = 0;
448 ULONG PollBufferSize;
449 LARGE_INTEGER uSec;
450 PVOID PollBuffer;
451 ULONG i, j = 0, x;
452 HANDLE SockEvent;
453
454 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
455 NULL, 1, FALSE );
456
457 if( !NT_SUCCESS(Status) ) return SOCKET_ERROR;
458
459 /* Find out how many sockets we have, and how large the buffer needs
460 * to be */
461
462 HandleCount =
463 ( readfds ? readfds->fd_count : 0 ) +
464 ( writefds ? writefds->fd_count : 0 ) +
465 ( exceptfds ? exceptfds->fd_count : 0 );
466 PollBufferSize = sizeof(*PollInfo) +
467 (HandleCount * sizeof(AFD_HANDLE));
468
469 /* Allocate */
470 PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
471 PollInfo = (PAFD_POLL_INFO)PollBuffer;
472
473 RtlZeroMemory( PollInfo, PollBufferSize );
474
475 /* Convert Timeout to NT Format */
476 if (timeout == NULL) {
477 PollInfo->Timeout.u.LowPart = -1;
478 PollInfo->Timeout.u.HighPart = 0x7FFFFFFF;
479 } else {
480 PollInfo->Timeout = RtlEnlargedIntegerMultiply
481 ((timeout->tv_sec * 1000) + timeout->tv_usec, -10000);
482 PollInfo->Timeout.QuadPart += uSec.QuadPart;
483 }
484
485 /* Number of handles for AFD to Check */
486 PollInfo->HandleCount = HandleCount;
487 PollInfo->Exclusive = FALSE;
488
489 if (readfds != NULL) {
490 for (i = 0; i < readfds->fd_count; i++, j++) {
491 PollInfo->Handles[j].Handle = readfds->fd_array[i];
492 PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE | AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
493 }
494 }
495 if (writefds != NULL) {
496 for (i = 0; i < writefds->fd_count; i++, j++) {
497 PollInfo->Handles[j].Handle = writefds->fd_array[i];
498 PollInfo->Handles[j].Events |= AFD_EVENT_SEND;
499 }
500
501 }
502 if (exceptfds != NULL) {
503 for (i = 0; i < exceptfds->fd_count; i++, j++) {
504 PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
505 PollInfo->Handles[j].Events |=
506 AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
507 }
508 }
509
510 /* Send IOCTL */
511 Status = NtDeviceIoControlFile( (HANDLE)Sockets[0]->Handle,
512 SockEvent,
513 NULL,
514 NULL,
515 &IOSB,
516 IOCTL_AFD_SELECT,
517 PollInfo,
518 PollBufferSize,
519 PollInfo,
520 PollBufferSize);
521
522 AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
523
524 /* Wait for Completition */
525 if (Status == STATUS_PENDING) {
526 WaitForSingleObject(SockEvent, INFINITE);
527 }
528
529 /* Clear the Structures */
530 if( readfds ) FD_ZERO(readfds);
531 if( writefds ) FD_ZERO(writefds);
532 if( exceptfds ) FD_ZERO(exceptfds);
533
534 /* Loop through return structure */
535 HandleCount = PollInfo->HandleCount;
536
537 /* Return in FDSET Format */
538 for (i = 0; i < HandleCount; i++) {
539 for(x = 1; x; x<<=1) {
540 switch (PollInfo->Handles[i].Events & x) {
541 case AFD_EVENT_RECEIVE:
542 case AFD_EVENT_DISCONNECT:
543 case AFD_EVENT_ABORT:
544 case AFD_EVENT_ACCEPT:
545 case AFD_EVENT_CLOSE:
546 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
547 PollInfo->Handles[i].Events,
548 PollInfo->Handles[i].Handle));
549 OutCount++;
550 if( readfds ) FD_SET(PollInfo->Handles[i].Handle, readfds);
551 break;
552
553 case AFD_EVENT_SEND: case AFD_EVENT_CONNECT:
554 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
555 PollInfo->Handles[i].Events,
556 PollInfo->Handles[i].Handle));
557 OutCount++;
558 if( writefds ) FD_SET(PollInfo->Handles[i].Handle, writefds);
559 break;
560
561 case AFD_EVENT_OOB_RECEIVE: case AFD_EVENT_CONNECT_FAIL:
562 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
563 PollInfo->Handles[i].Events,
564 PollInfo->Handles[i].Handle));
565 OutCount++;
566 if( exceptfds ) FD_SET(PollInfo->Handles[i].Handle, exceptfds);
567 break;
568 }
569 }
570 }
571
572 NtClose( SockEvent );
573
574 AFD_DbgPrint(MID_TRACE,("lpErrno = %x\n", lpErrno));
575
576 if( lpErrno ) {
577 switch( IOSB.Status ) {
578 case STATUS_SUCCESS:
579 case STATUS_TIMEOUT: *lpErrno = 0; break;
580 default: *lpErrno = WSAEINVAL; break;
581 }
582 }
583
584 AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
585
586 return OutCount;
587 }
588
589 SOCKET
590 WSPAPI
591 WSPAccept(
592 SOCKET Handle,
593 struct sockaddr *SocketAddress,
594 int *SocketAddressLength,
595 LPCONDITIONPROC lpfnCondition,
596 DWORD_PTR dwCallbackData,
597 LPINT lpErrno)
598 {
599 IO_STATUS_BLOCK IOSB;
600 PAFD_RECEIVED_ACCEPT_DATA ListenReceiveData;
601 AFD_ACCEPT_DATA AcceptData;
602 AFD_DEFER_ACCEPT_DATA DeferData;
603 AFD_PENDING_ACCEPT_DATA PendingAcceptData;
604 PSOCKET_INFORMATION Socket = NULL;
605 NTSTATUS Status;
606 struct fd_set ReadSet;
607 struct timeval Timeout;
608 PVOID PendingData;
609 ULONG PendingDataLength;
610 PVOID CalleeDataBuffer;
611 WSABUF CallerData, CalleeID, CallerID, CalleeData;
612 PSOCKADDR RemoteAddress = NULL;
613 GROUP GroupID = 0;
614 ULONG CallBack;
615 WSAPROTOCOL_INFOW ProtocolInfo;
616 SOCKET AcceptSocket;
617 UCHAR ReceiveBuffer[0x1A];
618 HANDLE SockEvent;
619
620 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
621 NULL, 1, FALSE );
622
623 if( !NT_SUCCESS(Status) ) return -1;
624
625 /* Dynamic Structure...ugh */
626 ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
627
628 /* Get the Socket Structure associate to this Socket*/
629 Socket = GetSocketStructure(Handle);
630
631 /* If this is non-blocking, make sure there's something for us to accept */
632 FD_ZERO(&ReadSet);
633 FD_SET(Socket->Handle, &ReadSet);
634 Timeout.tv_sec=0;
635 Timeout.tv_usec=0;
636
637 WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
638
639 if (ReadSet.fd_array[0] != Socket->Handle) return 0;
640
641 /* Send IOCTL */
642 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
643 SockEvent,
644 NULL,
645 NULL,
646 &IOSB,
647 IOCTL_AFD_WAIT_FOR_LISTEN,
648 NULL,
649 0,
650 ListenReceiveData,
651 0xA + sizeof(*ListenReceiveData));
652
653 if (lpfnCondition != NULL) {
654 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0) {
655 /* Find out how much data is pending */
656 PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
657 PendingAcceptData.ReturnSize = TRUE;
658
659 /* Send IOCTL */
660 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
661 SockEvent,
662 NULL,
663 NULL,
664 &IOSB,
665 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
666 &PendingAcceptData,
667 sizeof(PendingAcceptData),
668 &PendingAcceptData,
669 sizeof(PendingAcceptData));
670
671 /* How much data to allocate */
672 PendingDataLength = IOSB.Information;
673
674 if (PendingDataLength) {
675 /* Allocate needed space */
676 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
677
678 /* We want the data now */
679 PendingAcceptData.ReturnSize = FALSE;
680
681 /* Send IOCTL */
682 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
683 SockEvent,
684 NULL,
685 NULL,
686 &IOSB,
687 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
688 &PendingAcceptData,
689 sizeof(PendingAcceptData),
690 PendingData,
691 PendingDataLength);
692 }
693 }
694
695 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0) {
696 /* I don't support this yet */
697 }
698
699 /* Build Callee ID */
700 CalleeID.buf = (PVOID)Socket->LocalAddress;
701 CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
702
703 /* Set up Address in SOCKADDR Format */
704 RtlCopyMemory (RemoteAddress,
705 &ListenReceiveData->Address.Address[0].AddressType,
706 sizeof(RemoteAddress));
707
708 /* Build Caller ID */
709 CallerID.buf = (PVOID)RemoteAddress;
710 CallerID.len = sizeof(RemoteAddress);
711
712 /* Build Caller Data */
713 CallerData.buf = PendingData;
714 CallerData.len = PendingDataLength;
715
716 /* Check if socket supports Conditional Accept */
717 if (Socket->SharedData.UseDelayedAcceptance != 0) {
718 /* Allocate Buffer for Callee Data */
719 CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
720 CalleeData.buf = CalleeDataBuffer;
721 CalleeData.len = 4096;
722 } else {
723 /* Nothing */
724 CalleeData.buf = 0;
725 CalleeData.len = 0;
726 }
727
728 /* Call the Condition Function */
729 CallBack = (lpfnCondition)( &CallerID,
730 CallerData.buf == NULL
731 ? NULL
732 : & CallerData,
733 NULL,
734 NULL,
735 &CalleeID,
736 CalleeData.buf == NULL
737 ? NULL
738 : & CalleeData,
739 &GroupID,
740 dwCallbackData);
741
742 if (((CallBack == CF_ACCEPT) && GroupID) != 0) {
743 /* TBD: Check for Validity */
744 }
745
746 if (CallBack == CF_ACCEPT) {
747
748 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0) {
749 /* I don't support this yet */
750 }
751
752 if (CalleeData.buf) {
753 // SockSetConnectData Sockets(SocketID), IOCTL_AFD_SET_CONNECT_DATA, CalleeData.Buffer, CalleeData.BuffSize, 0
754 }
755
756 } else {
757 /* Callback rejected. Build Defer Structure */
758 DeferData.SequenceNumber = ListenReceiveData->SequenceNumber;
759 DeferData.RejectConnection = (CallBack == CF_REJECT);
760
761 /* Send IOCTL */
762 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
763 SockEvent,
764 NULL,
765 NULL,
766 &IOSB,
767 IOCTL_AFD_DEFER_ACCEPT,
768 &DeferData,
769 sizeof(DeferData),
770 NULL,
771 0);
772
773 NtClose( SockEvent );
774
775 if (CallBack == CF_REJECT ) {
776 return WSAECONNREFUSED;
777 } else {
778 return WSATRY_AGAIN;
779 }
780 }
781 }
782
783 /* Create a new Socket */
784 ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
785 ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
786 ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
787
788 AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
789 Socket->SharedData.SocketType,
790 Socket->SharedData.Protocol,
791 &ProtocolInfo,
792 GroupID,
793 Socket->SharedData.CreateFlags,
794 NULL);
795
796 /* Set up the Accept Structure */
797 AcceptData.ListenHandle = AcceptSocket;
798 AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
799
800 /* Send IOCTL to Accept */
801 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
802 SockEvent,
803 NULL,
804 NULL,
805 &IOSB,
806 IOCTL_AFD_ACCEPT,
807 &AcceptData,
808 sizeof(AcceptData),
809 NULL,
810 0);
811
812 /* Return Address in SOCKADDR FORMAT */
813 RtlCopyMemory (SocketAddress,
814 &ListenReceiveData->Address.Address[0].AddressType,
815 sizeof(RemoteAddress));
816
817 NtClose( SockEvent );
818
819 AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
820
821 /* Return Socket */
822 return AcceptSocket;
823 }
824
825 int
826 WSPAPI
827 WSPConnect(
828 SOCKET Handle,
829 struct sockaddr * SocketAddress,
830 int SocketAddressLength,
831 LPWSABUF lpCallerData,
832 LPWSABUF lpCalleeData,
833 LPQOS lpSQOS,
834 LPQOS lpGQOS,
835 LPINT lpErrno)
836 {
837 IO_STATUS_BLOCK IOSB;
838 PAFD_CONNECT_INFO ConnectInfo;
839 PSOCKET_INFORMATION Socket = NULL;
840 NTSTATUS Status;
841 UCHAR ConnectBuffer[0x22];
842 ULONG ConnectDataLength;
843 ULONG InConnectDataLength;
844 UINT BindAddressLength;
845 PSOCKADDR BindAddress;
846 HANDLE SockEvent;
847
848 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
849 NULL, 1, FALSE );
850
851 if( !NT_SUCCESS(Status) ) return -1;
852
853 AFD_DbgPrint(MID_TRACE,("Called\n"));
854
855 /* Get the Socket Structure associate to this Socket*/
856 Socket = GetSocketStructure(Handle);
857
858 /* Bind us First */
859 if (Socket->SharedData.State == SocketOpen) {
860
861 /* Get the Wildcard Address */
862 BindAddressLength = Socket->HelperData->MaxWSAddressLength;
863 BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
864 Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
865 BindAddress,
866 &BindAddressLength);
867
868 /* Bind it */
869 WSPBind(Handle, BindAddress, BindAddressLength, NULL);
870 }
871
872 /* Set the Connect Data */
873 if (lpCallerData != NULL) {
874 ConnectDataLength = lpCallerData->len;
875 Status = NtDeviceIoControlFile((HANDLE)Handle,
876 SockEvent,
877 NULL,
878 NULL,
879 &IOSB,
880 IOCTL_AFD_SET_CONNECT_DATA,
881 lpCallerData->buf,
882 ConnectDataLength,
883 NULL,
884 0);
885 }
886
887 /* Dynamic Structure...ugh */
888 ConnectInfo = (PAFD_CONNECT_INFO)ConnectBuffer;
889
890 /* Set up Address in TDI Format */
891 ConnectInfo->RemoteAddress.TAAddressCount = 1;
892 ConnectInfo->RemoteAddress.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
893 ConnectInfo->RemoteAddress.Address[0].AddressType = SocketAddress->sa_family;
894 RtlCopyMemory (ConnectInfo->RemoteAddress.Address[0].Address,
895 SocketAddress->sa_data,
896 SocketAddressLength - sizeof(SocketAddress->sa_family));
897
898 /* Tell AFD that we want Connection Data back, have it allocate a buffer */
899 if (lpCalleeData != NULL) {
900 InConnectDataLength = lpCalleeData->len;
901 Status = NtDeviceIoControlFile((HANDLE)Handle,
902 SockEvent,
903 NULL,
904 NULL,
905 &IOSB,
906 IOCTL_AFD_SET_CONNECT_DATA_SIZE,
907 &InConnectDataLength,
908 sizeof(InConnectDataLength),
909 NULL,
910 0);
911 }
912
913 /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
914 ConnectInfo->Root = 0;
915 ConnectInfo->UseSAN = FALSE;
916 ConnectInfo->Unknown = 0;
917
918 /* Send IOCTL */
919 Status = NtDeviceIoControlFile((HANDLE)Handle,
920 SockEvent,
921 NULL,
922 NULL,
923 &IOSB,
924 IOCTL_AFD_CONNECT,
925 ConnectInfo,
926 0x22,
927 NULL,
928 0);
929
930 /* Get any pending connect data */
931 if (lpCalleeData != NULL) {
932 Status = NtDeviceIoControlFile((HANDLE)Handle,
933 SockEvent,
934 NULL,
935 NULL,
936 &IOSB,
937 IOCTL_AFD_GET_CONNECT_DATA,
938 NULL,
939 0,
940 lpCalleeData->buf,
941 lpCalleeData->len);
942 }
943
944 AFD_DbgPrint(MID_TRACE,("Ending\n"));
945
946 NtClose( SockEvent );
947
948 return MsafdReturnWithErrno( IOSB.Status, lpErrno, 0, NULL );
949 }
950 int
951 WSPAPI
952 WSPShutdown(
953 SOCKET Handle,
954 int HowTo,
955 LPINT lpErrno)
956
957 {
958 IO_STATUS_BLOCK IOSB;
959 AFD_DISCONNECT_INFO DisconnectInfo;
960 PSOCKET_INFORMATION Socket = NULL;
961 NTSTATUS Status;
962 HANDLE SockEvent;
963
964 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
965 NULL, 1, FALSE );
966
967 if( !NT_SUCCESS(Status) ) return -1;
968
969 AFD_DbgPrint(MID_TRACE,("Called\n"));
970
971 /* Get the Socket Structure associate to this Socket*/
972 Socket = GetSocketStructure(Handle);
973
974 /* Set AFD Disconnect Type */
975 switch (HowTo) {
976
977 case SD_RECEIVE:
978 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
979 Socket->SharedData.ReceiveShutdown = TRUE;
980 break;
981
982 case SD_SEND:
983 DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
984 Socket->SharedData.SendShutdown = TRUE;
985 break;
986
987 case SD_BOTH:
988 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
989 Socket->SharedData.ReceiveShutdown = TRUE;
990 Socket->SharedData.SendShutdown = TRUE;
991 break;
992 }
993
994 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
995
996 /* Send IOCTL */
997 Status = NtDeviceIoControlFile((HANDLE)Handle,
998 SockEvent,
999 NULL,
1000 NULL,
1001 &IOSB,
1002 IOCTL_AFD_DISCONNECT,
1003 &DisconnectInfo,
1004 sizeof(DisconnectInfo),
1005 NULL,
1006 0);
1007
1008 /* Wait for return */
1009 if (Status == STATUS_PENDING) {
1010 WaitForSingleObject(SockEvent, INFINITE);
1011 }
1012
1013 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1014
1015 NtClose( SockEvent );
1016
1017 return MsafdReturnWithErrno( IOSB.Status, lpErrno, 0, NULL );
1018 }
1019
1020
1021 INT
1022 WSPAPI
1023 WSPGetSockName(
1024 IN SOCKET Handle,
1025 OUT LPSOCKADDR Name,
1026 IN OUT LPINT NameLength,
1027 OUT LPINT lpErrno)
1028 {
1029 IO_STATUS_BLOCK IOSB;
1030 ULONG TdiAddressSize;
1031 PTDI_ADDRESS_INFO TdiAddress;
1032 PTRANSPORT_ADDRESS SocketAddress;
1033 PSOCKET_INFORMATION Socket = NULL;
1034 NTSTATUS Status;
1035 HANDLE SockEvent;
1036
1037 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1038 NULL, 1, FALSE );
1039
1040 if( !NT_SUCCESS(Status) ) return SOCKET_ERROR;
1041
1042 /* Get the Socket Structure associate to this Socket*/
1043 Socket = GetSocketStructure(Handle);
1044
1045 /* Allocate a buffer for the address */
1046 TdiAddressSize = FIELD_OFFSET(TDI_ADDRESS_INFO,
1047 Address.Address[0].Address) +
1048 Socket->SharedData.SizeOfLocalAddress;
1049 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1050
1051 if ( TdiAddress == NULL ) {
1052 NtClose( SockEvent );
1053 *lpErrno = WSAENOBUFS;
1054 return SOCKET_ERROR;
1055 }
1056
1057 SocketAddress = &TdiAddress->Address;
1058
1059 /* Send IOCTL */
1060 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1061 SockEvent,
1062 NULL,
1063 NULL,
1064 &IOSB,
1065 IOCTL_AFD_GET_SOCK_NAME,
1066 NULL,
1067 0,
1068 TdiAddress,
1069 TdiAddressSize);
1070
1071 /* Wait for return */
1072 if (Status == STATUS_PENDING) {
1073 WaitForSingleObject(SockEvent, INFINITE);
1074 Status = IOSB.Status;
1075 }
1076
1077 NtClose( SockEvent );
1078
1079 if (NT_SUCCESS(Status)) {
1080 if (*NameLength >= SocketAddress->Address[0].AddressLength) {
1081 Name->sa_family = SocketAddress->Address[0].AddressType;
1082 RtlCopyMemory (Name->sa_data,
1083 SocketAddress->Address[0].Address,
1084 SocketAddress->Address[0].AddressLength);
1085 HeapFree(GlobalHeap, 0, TdiAddress);
1086 return 0;
1087 } else {
1088 HeapFree(GlobalHeap, 0, TdiAddress);
1089 *lpErrno = WSAEFAULT;
1090 return SOCKET_ERROR;
1091 }
1092 }
1093
1094 return MsafdReturnWithErrno
1095 ( IOSB.Status, lpErrno, 0, NULL );
1096 }
1097
1098
1099 INT
1100 WSPAPI
1101 WSPGetPeerName(
1102 IN SOCKET s,
1103 OUT LPSOCKADDR name,
1104 IN OUT LPINT namelen,
1105 OUT LPINT lpErrno)
1106 {
1107
1108 return 0;
1109 }
1110
1111 INT
1112 WSPAPI
1113 WSPIoctl(
1114 IN SOCKET Handle,
1115 IN DWORD dwIoControlCode,
1116 IN LPVOID lpvInBuffer,
1117 IN DWORD cbInBuffer,
1118 OUT LPVOID lpvOutBuffer,
1119 IN DWORD cbOutBuffer,
1120 OUT LPDWORD lpcbBytesReturned,
1121 IN LPWSAOVERLAPPED lpOverlapped,
1122 IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
1123 IN LPWSATHREADID lpThreadId,
1124 OUT LPINT lpErrno)
1125 {
1126 PSOCKET_INFORMATION Socket = NULL;
1127
1128 /* Get the Socket Structure associate to this Socket*/
1129 Socket = GetSocketStructure(Handle);
1130
1131 switch( dwIoControlCode ) {
1132 case FIONBIO:
1133 if( cbInBuffer < sizeof(INT) ) return SOCKET_ERROR;
1134 Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
1135 AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n",
1136 Handle, Socket->SharedData.NonBlocking));
1137 return 0;
1138 default:
1139 return SOCKET_ERROR;
1140 }
1141 }
1142
1143
1144 INT
1145 WSPAPI
1146 WSPStartup(
1147 IN WORD wVersionRequested,
1148 OUT LPWSPDATA lpWSPData,
1149 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
1150 IN WSPUPCALLTABLE UpcallTable,
1151 OUT LPWSPPROC_TABLE lpProcTable)
1152 /*
1153 * FUNCTION: Initialize service provider for a client
1154 * ARGUMENTS:
1155 * wVersionRequested = Highest WinSock SPI version that the caller can use
1156 * lpWSPData = Address of WSPDATA structure to initialize
1157 * lpProtocolInfo = Pointer to structure that defines the desired protocol
1158 * UpcallTable = Pointer to upcall table of the WinSock DLL
1159 * lpProcTable = Address of procedure table to initialize
1160 * RETURNS:
1161 * Status of operation
1162 */
1163 {
1164 NTSTATUS Status;
1165
1166 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
1167
1168 Status = NO_ERROR;
1169
1170 Upcalls = UpcallTable;
1171
1172 if (Status == NO_ERROR) {
1173 lpProcTable->lpWSPAccept = WSPAccept;
1174 lpProcTable->lpWSPAddressToString = WSPAddressToString;
1175 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
1176 lpProcTable->lpWSPBind = (LPWSPBIND)WSPBind;
1177 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
1178 lpProcTable->lpWSPCleanup = WSPCleanup;
1179 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
1180 lpProcTable->lpWSPConnect = (LPWSPCONNECT)WSPConnect;
1181 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
1182 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
1183 lpProcTable->lpWSPEventSelect = WSPEventSelect;
1184 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
1185 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
1186 lpProcTable->lpWSPGetSockName = WSPGetSockName;
1187 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
1188 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
1189 lpProcTable->lpWSPIoctl = WSPIoctl;
1190 lpProcTable->lpWSPJoinLeaf = (LPWSPJOINLEAF)WSPJoinLeaf;
1191 lpProcTable->lpWSPListen = WSPListen;
1192 lpProcTable->lpWSPRecv = WSPRecv;
1193 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
1194 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
1195 lpProcTable->lpWSPSelect = WSPSelect;
1196 lpProcTable->lpWSPSend = WSPSend;
1197 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
1198 lpProcTable->lpWSPSendTo = (LPWSPSENDTO)WSPSendTo;
1199 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
1200 lpProcTable->lpWSPShutdown = WSPShutdown;
1201 lpProcTable->lpWSPSocket = WSPSocket;
1202 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
1203
1204 lpWSPData->wVersion = MAKEWORD(2, 2);
1205 lpWSPData->wHighVersion = MAKEWORD(2, 2);
1206 }
1207
1208 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
1209
1210 return Status;
1211 }
1212
1213
1214 INT
1215 WSPAPI
1216 WSPCleanup(
1217 OUT LPINT lpErrno)
1218 /*
1219 * FUNCTION: Cleans up service provider for a client
1220 * ARGUMENTS:
1221 * lpErrno = Address of buffer for error information
1222 * RETURNS:
1223 * 0 if successful, or SOCKET_ERROR if not
1224 */
1225 {
1226 AFD_DbgPrint(MAX_TRACE, ("\n"));
1227
1228
1229 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
1230
1231 *lpErrno = NO_ERROR;
1232
1233 return 0;
1234 }
1235
1236
1237
1238 int
1239 GetSocketInformation(
1240 PSOCKET_INFORMATION Socket,
1241 ULONG AfdInformationClass,
1242 PULONG Ulong OPTIONAL,
1243 PLARGE_INTEGER LargeInteger OPTIONAL)
1244 {
1245 IO_STATUS_BLOCK IOSB;
1246 AFD_INFO InfoData;
1247 NTSTATUS Status;
1248 HANDLE SockEvent;
1249
1250 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1251 NULL, 1, FALSE );
1252
1253 if( !NT_SUCCESS(Status) ) return -1;
1254
1255 /* Set Info Class */
1256 InfoData.InformationClass = AfdInformationClass;
1257
1258 /* Send IOCTL */
1259 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1260 SockEvent,
1261 NULL,
1262 NULL,
1263 &IOSB,
1264 IOCTL_AFD_GET_INFO,
1265 &InfoData,
1266 sizeof(InfoData),
1267 &InfoData,
1268 sizeof(InfoData));
1269
1270 /* Wait for return */
1271 if (Status == STATUS_PENDING) {
1272 WaitForSingleObject(SockEvent, INFINITE);
1273 }
1274
1275 /* Return Information */
1276 *Ulong = InfoData.Information.Ulong;
1277 if (LargeInteger != NULL) {
1278 *LargeInteger = InfoData.Information.LargeInteger;
1279 }
1280
1281 NtClose( SockEvent );
1282
1283 return 0;
1284
1285 }
1286
1287
1288 int
1289 SetSocketInformation(
1290 PSOCKET_INFORMATION Socket,
1291 ULONG AfdInformationClass,
1292 PULONG Ulong OPTIONAL,
1293 PLARGE_INTEGER LargeInteger OPTIONAL)
1294 {
1295 IO_STATUS_BLOCK IOSB;
1296 AFD_INFO InfoData;
1297 NTSTATUS Status;
1298 HANDLE SockEvent;
1299
1300 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1301 NULL, 1, FALSE );
1302
1303 if( !NT_SUCCESS(Status) ) return -1;
1304
1305 /* Set Info Class */
1306 InfoData.InformationClass = AfdInformationClass;
1307
1308 /* Set Information */
1309 InfoData.Information.Ulong = *Ulong;
1310 if (LargeInteger != NULL) {
1311 InfoData.Information.LargeInteger = *LargeInteger;
1312 }
1313
1314 /* Send IOCTL */
1315 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1316 SockEvent,
1317 NULL,
1318 NULL,
1319 &IOSB,
1320 IOCTL_AFD_GET_INFO,
1321 &InfoData,
1322 sizeof(InfoData),
1323 NULL,
1324 0);
1325
1326 /* Wait for return */
1327 if (Status == STATUS_PENDING) {
1328 WaitForSingleObject(SockEvent, INFINITE);
1329 }
1330
1331 NtClose( SockEvent );
1332
1333 return 0;
1334
1335 }
1336
1337 PSOCKET_INFORMATION
1338 GetSocketStructure(
1339 SOCKET Handle)
1340 {
1341 ULONG i;
1342
1343 for (i=0; i<SocketCount; i++) {
1344 if (Sockets[i]->Handle == Handle) {
1345 return Sockets[i];
1346 }
1347 }
1348 return 0;
1349 }
1350
1351 int CreateContext(PSOCKET_INFORMATION Socket)
1352 {
1353 IO_STATUS_BLOCK IOSB;
1354 SOCKET_CONTEXT ContextData;
1355 NTSTATUS Status;
1356 HANDLE SockEvent;
1357
1358 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1359 NULL, 1, FALSE );
1360
1361 if( !NT_SUCCESS(Status) ) return -1;
1362
1363 /* Create Context */
1364 ContextData.SharedData = Socket->SharedData;
1365 ContextData.SizeOfHelperData = 0;
1366 RtlCopyMemory (&ContextData.LocalAddress,
1367 Socket->LocalAddress,
1368 Socket->SharedData.SizeOfLocalAddress);
1369 RtlCopyMemory (&ContextData.RemoteAddress,
1370 Socket->RemoteAddress,
1371 Socket->SharedData.SizeOfRemoteAddress);
1372
1373 /* Send IOCTL */
1374 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1375 SockEvent,
1376 NULL,
1377 NULL,
1378 &IOSB,
1379 IOCTL_AFD_SET_CONTEXT,
1380 &ContextData,
1381 sizeof(ContextData),
1382 NULL,
1383 0);
1384
1385 /* Wait for Completition */
1386 if (Status == STATUS_PENDING) {
1387 WaitForSingleObject(SockEvent, INFINITE);
1388 }
1389
1390 NtClose( SockEvent );
1391
1392 return 0;
1393 }
1394
1395 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
1396 {
1397 HANDLE hAsyncThread;
1398 DWORD AsyncThreadId;
1399 HANDLE AsyncEvent;
1400 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
1401 NTSTATUS Status;
1402
1403 /* Check if the Thread Already Exists */
1404 if (SockAsyncThreadRefCount) {
1405 return TRUE;
1406 }
1407
1408 /* Create the Completion Port */
1409 if (!SockAsyncCompletionPort) {
1410 Status = NtCreateIoCompletion(&SockAsyncCompletionPort,
1411 IO_COMPLETION_ALL_ACCESS,
1412 NULL,
1413 2); // Allow 2 threads only
1414
1415 /* Protect Handle */
1416 HandleFlags.ProtectFromClose = TRUE;
1417 HandleFlags.Inherit = FALSE;
1418 Status = NtSetInformationObject(SockAsyncCompletionPort,
1419 ObjectHandleInformation,
1420 &HandleFlags,
1421 sizeof(HandleFlags));
1422 }
1423
1424 /* Create the Async Event */
1425 Status = NtCreateEvent(&AsyncEvent,
1426 EVENT_ALL_ACCESS,
1427 NULL,
1428 NotificationEvent,
1429 FALSE);
1430
1431 /* Create the Async Thread */
1432 hAsyncThread = CreateThread(NULL,
1433 0,
1434 (LPTHREAD_START_ROUTINE)SockAsyncThread,
1435 NULL,
1436 0,
1437 &AsyncThreadId);
1438
1439 /* Close the Handle */
1440 NtClose(hAsyncThread);
1441
1442 /* Increase the Reference Count */
1443 SockAsyncThreadRefCount++;
1444 return TRUE;
1445 }
1446
1447 int SockAsyncThread(PVOID ThreadParam)
1448 {
1449 PVOID AsyncContext;
1450 PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
1451 IO_STATUS_BLOCK IOSB;
1452 NTSTATUS Status;
1453 LARGE_INTEGER Timeout;
1454
1455 /* Make the Thread Higher Priority */
1456 SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
1457
1458 /* Do a KQUEUE/WorkItem Style Loop, thanks to IoCompletion Ports */
1459 do {
1460 Status = NtRemoveIoCompletion (SockAsyncCompletionPort,
1461 (PVOID*)&AsyncCompletionRoutine,
1462 &AsyncContext,
1463 &IOSB,
1464 &Timeout);
1465
1466 /* Call the Async Function */
1467 if (NT_SUCCESS(Status)) {
1468 //(*AsyncCompletionRoutine)(AsyncContext, IOSB);
1469 } else {
1470 /* It Failed, sleep for a second */
1471 Sleep(1000);
1472 }
1473 } while ((Status != STATUS_TIMEOUT));
1474
1475 /* The Thread has Ended */
1476 return 0;
1477 }
1478
1479 BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
1480 {
1481 UNICODE_STRING AfdHelper;
1482 OBJECT_ATTRIBUTES ObjectAttributes;
1483 IO_STATUS_BLOCK IoSb;
1484 NTSTATUS Status;
1485 FILE_COMPLETION_INFORMATION CompletionInfo;
1486 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
1487
1488 /* First, make sure we're not already intialized */
1489 if (SockAsyncHelperAfdHandle) {
1490 return TRUE;
1491 }
1492
1493 /* Set up Handle Name and Object */
1494 RtlInitUnicodeString(&AfdHelper, L"\\Device\\Afd\\AsyncSelectHlp" );
1495 InitializeObjectAttributes(&ObjectAttributes,
1496 &AfdHelper,
1497 OBJ_INHERIT | OBJ_CASE_INSENSITIVE,
1498 NULL,
1499 NULL);
1500
1501 /* Open the Handle to AFD */
1502 Status = NtCreateFile(&SockAsyncHelperAfdHandle,
1503 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
1504 &ObjectAttributes,
1505 &IoSb,
1506 NULL,
1507 0,
1508 FILE_SHARE_READ | FILE_SHARE_WRITE,
1509 FILE_OPEN_IF,
1510 0,
1511 NULL,
1512 0);
1513
1514 /*
1515 * Now Set up the Completion Port Information
1516 * This means that whenever a Poll is finished, the routine will be executed
1517 */
1518 CompletionInfo.Port = SockAsyncCompletionPort;
1519 CompletionInfo.Key = SockAsyncSelectCompletionRoutine;
1520 Status = NtSetInformationFile(SockAsyncHelperAfdHandle,
1521 &IoSb,
1522 &CompletionInfo,
1523 sizeof(CompletionInfo),
1524 FileCompletionInformation);
1525
1526
1527 /* Protect the Handle */
1528 HandleFlags.ProtectFromClose = TRUE;
1529 HandleFlags.Inherit = FALSE;
1530 Status = NtSetInformationObject(SockAsyncCompletionPort,
1531 ObjectHandleInformation,
1532 &HandleFlags,
1533 sizeof(HandleFlags));
1534
1535
1536 /* Set this variable to true so that Send/Recv/Accept will know wether to renable disabled events */
1537 SockAsyncSelectCalled = TRUE;
1538 return TRUE;
1539 }
1540
1541 VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
1542 {
1543
1544 PASYNC_DATA AsyncData = Context;
1545 PSOCKET_INFORMATION Socket;
1546 ULONG x;
1547
1548 /* Get the Socket */
1549 Socket = AsyncData->ParentSocket;
1550
1551 /* Check if the Sequence Number Changed behind our back */
1552 if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber ){
1553 return;
1554 }
1555
1556 /* Check we were manually called b/c of a failure */
1557 if (!NT_SUCCESS(IoStatusBlock->Status)) {
1558 /* FIXME: Perform Upcall */
1559 return;
1560 }
1561
1562
1563 for (x = 1; x; x<<=1) {
1564 switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x) {
1565 case AFD_EVENT_RECEIVE:
1566 if ((Socket->SharedData.AsyncEvents & FD_READ) && (!Socket->SharedData.AsyncDisabledEvents & FD_READ)) {
1567 /* Make the Notifcation */
1568 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
1569 Socket->SharedData.wMsg,
1570 Socket->Handle,
1571 WSAMAKESELECTREPLY(FD_READ, 0));
1572 /* Disable this event until the next read(); */
1573 Socket->SharedData.AsyncDisabledEvents |= FD_READ;
1574 }
1575
1576 case AFD_EVENT_OOB_RECEIVE:
1577 if ((Socket->SharedData.AsyncEvents & FD_OOB) && (!Socket->SharedData.AsyncDisabledEvents & FD_OOB)) {
1578 /* Make the Notifcation */
1579 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
1580 Socket->SharedData.wMsg,
1581 Socket->Handle,
1582 WSAMAKESELECTREPLY(FD_OOB, 0));
1583 /* Disable this event until the next read(); */
1584 Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
1585 }
1586
1587 case AFD_EVENT_SEND:
1588 if ((Socket->SharedData.AsyncEvents & FD_WRITE) && (!Socket->SharedData.AsyncDisabledEvents & FD_WRITE)) {
1589 /* Make the Notifcation */
1590 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
1591 Socket->SharedData.wMsg,
1592 Socket->Handle,
1593 WSAMAKESELECTREPLY(FD_WRITE, 0));
1594 /* Disable this event until the next write(); */
1595 Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
1596 }
1597
1598 case AFD_EVENT_ACCEPT:
1599 if ((Socket->SharedData.AsyncEvents & FD_ACCEPT) && (!Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT)) {
1600 /* Make the Notifcation */
1601 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
1602 Socket->SharedData.wMsg,
1603 Socket->Handle,
1604 WSAMAKESELECTREPLY(FD_ACCEPT, 0));
1605 /* Disable this event until the next accept(); */
1606 Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
1607 }
1608
1609 case AFD_EVENT_DISCONNECT:
1610 case AFD_EVENT_ABORT:
1611 case AFD_EVENT_CLOSE:
1612 if ((Socket->SharedData.AsyncEvents & FD_CLOSE) && (!Socket->SharedData.AsyncDisabledEvents & FD_CLOSE)) {
1613 /* Make the Notifcation */
1614 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
1615 Socket->SharedData.wMsg,
1616 Socket->Handle,
1617 WSAMAKESELECTREPLY(FD_CLOSE, 0));
1618 /* Disable this event forever; */
1619 Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
1620 }
1621
1622 /* FIXME: Support QOS */
1623 }
1624 }
1625
1626 /* Check if there are any events left for us to check */
1627 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 ) {
1628 return;
1629 }
1630
1631 /* Keep Polling */
1632 SockProcessAsyncSelect(Socket, AsyncData);
1633 return;
1634 }
1635
1636 VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
1637 {
1638
1639 ULONG lNetworkEvents;
1640 NTSTATUS Status;
1641
1642 /* Set up the Async Data Event Info */
1643 AsyncData->AsyncSelectInfo.Timeout.HighPart = 0x7FFFFFFF;
1644 AsyncData->AsyncSelectInfo.Timeout.LowPart = 0xFFFFFFFF;
1645 AsyncData->AsyncSelectInfo.HandleCount = 1;
1646 AsyncData->AsyncSelectInfo.Exclusive = TRUE;
1647 AsyncData->AsyncSelectInfo.Handles[0].Handle = Socket->Handle;
1648 AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
1649
1650 /* Remove unwanted events */
1651 lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
1652
1653 /* Set Events to wait for */
1654 if (lNetworkEvents & FD_READ) {
1655 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_RECEIVE;
1656 }
1657
1658 if (lNetworkEvents & FD_WRITE) {
1659 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_SEND;
1660 }
1661
1662 if (lNetworkEvents & FD_OOB) {
1663 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_OOB_RECEIVE;
1664 }
1665
1666 if (lNetworkEvents & FD_ACCEPT) {
1667 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_ACCEPT;
1668 }
1669
1670 if (lNetworkEvents & FD_CONNECT) {
1671 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_CONNECT | AFD_EVENT_CONNECT_FAIL;
1672 }
1673
1674 if (lNetworkEvents & FD_CLOSE) {
1675 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
1676 }
1677
1678 if (lNetworkEvents & FD_QOS) {
1679 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_QOS;
1680 }
1681
1682 if (lNetworkEvents & FD_GROUP_QOS) {
1683 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_GROUP_QOS;
1684 }
1685
1686 /* Send IOCTL */
1687 Status = NtDeviceIoControlFile (SockAsyncHelperAfdHandle,
1688 NULL,
1689 NULL,
1690 AsyncData,
1691 &AsyncData->IoStatusBlock,
1692 IOCTL_AFD_SELECT,
1693 &AsyncData->AsyncSelectInfo,
1694 sizeof(AsyncData->AsyncSelectInfo),
1695 &AsyncData->AsyncSelectInfo,
1696 sizeof(AsyncData->AsyncSelectInfo));
1697
1698 /* I/O Manager Won't call the completion routine, let's do it manually */
1699 if (NT_SUCCESS(Status)) {
1700 return;
1701 } else {
1702 AsyncData->IoStatusBlock.Status = Status;
1703 SockAsyncSelectCompletionRoutine(AsyncData, &AsyncData->IoStatusBlock);
1704 }
1705 }
1706
1707 VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
1708 {
1709 PASYNC_DATA AsyncData = Context;
1710 PSOCKET_INFORMATION Socket;
1711
1712 /* Get the Socket */
1713 Socket = AsyncData->ParentSocket;
1714
1715 /* If someone closed it, stop the function */
1716 if (Socket->SharedData.State != SocketClosed) {
1717 /* Check if the Sequence Number changed by now, in which case quit */
1718 if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber) {
1719 /* Do the actuall select, if needed */
1720 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents))) {
1721 SockProcessAsyncSelect(Socket, AsyncData);
1722 }
1723 }
1724 }
1725
1726 /* Free the Context */
1727 HeapFree(GetProcessHeap(), 0, AsyncData);
1728 return;
1729 }
1730
1731 BOOL
1732 STDCALL
1733 DllMain(HANDLE hInstDll,
1734 ULONG dwReason,
1735 PVOID Reserved)
1736 {
1737
1738 switch (dwReason) {
1739 case DLL_PROCESS_ATTACH:
1740
1741 AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
1742
1743 /* Don't need thread attach notifications
1744 so disable them to improve performance */
1745 DisableThreadLibraryCalls(hInstDll);
1746
1747 /* List of DLL Helpers */
1748 InitializeListHead(&SockHelpersListHead);
1749
1750 /* Heap to use when allocating */
1751 GlobalHeap = GetProcessHeap();
1752
1753 /* Allocate Heap for 1024 Sockets, can be expanded later */
1754 Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
1755
1756 AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
1757
1758 break;
1759
1760 case DLL_THREAD_ATTACH:
1761 break;
1762
1763 case DLL_THREAD_DETACH:
1764 break;
1765
1766 case DLL_PROCESS_DETACH:
1767 break;
1768 }
1769
1770 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
1771
1772 return TRUE;
1773 }
1774
1775 /* EOF */
1776
1777