8632e1cdfb508890ba42b738fccba04af880ce61
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
1
2 /*
3 * COPYRIGHT: See COPYING in the top level directory
4 * PROJECT: ReactOS Ancillary Function Driver DLL
5 * FILE: misc/dllmain.c
6 * PURPOSE: DLL entry point
7 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
8 * Alex Ionescu (alex@relsoft.net)
9 * REVISIONS:
10 * CSH 01/09-2000 Created
11 * Alex 16/07/2004 - Complete Rewrite
12 */
13 #include <msafd.h>
14
15 #include <debug.h>
16
17 #ifdef DBG
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
19 DWORD DebugTraceLevel = 0;
20 #endif /* DBG */
21
22 HANDLE GlobalHeap;
23 WSPUPCALLTABLE Upcalls;
24 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
25 ULONG SocketCount = 0;
26 PSOCKET_INFORMATION *Sockets = NULL;
27 LIST_ENTRY SockHelpersListHead = {NULL};
28 ULONG SockAsyncThreadRefCount;
29 HANDLE SockAsyncHelperAfdHandle;
30 HANDLE SockAsyncCompletionPort;
31 BOOLEAN SockAsyncSelectCalled;
32
33 SOCKET
34 WSPAPI
35 WSPSocket(
36 int AddressFamily,
37 int SocketType,
38 int Protocol,
39 LPWSAPROTOCOL_INFOW lpProtocolInfo,
40 GROUP g,
41 DWORD dwFlags,
42 LPINT lpErrno)
43 /*
44 * FUNCTION: Creates a new socket
45 * ARGUMENTS:
46 * af = Address family
47 * type = Socket type
48 * protocol = Protocol type
49 * lpProtocolInfo = Pointer to protocol information
50 * g = Reserved
51 * dwFlags = Socket flags
52 * lpErrno = Address of buffer for error information
53 * RETURNS:
54 * Created socket, or INVALID_SOCKET if it could not be created
55 */
56 {
57 OBJECT_ATTRIBUTES Object;
58 IO_STATUS_BLOCK IOSB;
59 USHORT SizeOfPacket;
60 ULONG SizeOfEA;
61 PAFD_CREATE_PACKET AfdPacket;
62 HANDLE Sock;
63 PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
64 PFILE_FULL_EA_INFORMATION EABuffer = NULL;
65 PHELPER_DATA HelperData;
66 PVOID HelperDLLContext;
67 DWORD HelperEvents;
68 UNICODE_STRING TransportName;
69 UNICODE_STRING DevName;
70 LARGE_INTEGER GroupData;
71 INT Status;
72
73 AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
74 AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
75 AddressFamily, SocketType, Protocol));
76
77 /* Get Helper Data and Transport */
78 Status = SockGetTdiName (&AddressFamily,
79 &SocketType,
80 &Protocol,
81 g,
82 dwFlags,
83 &TransportName,
84 &HelperDLLContext,
85 &HelperData,
86 &HelperEvents);
87
88 /* Check for error */
89 if (Status != NO_ERROR) {
90 AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
91 goto error;
92 }
93
94 /* AFD Device Name */
95 RtlInitUnicodeString(&DevName, L"\\Device\\Afd\\Endpoint");
96
97 /* Set Socket Data */
98 Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
99 RtlZeroMemory(Socket, sizeof(*Socket));
100 Socket->RefCount = 2;
101 Socket->Handle = -1;
102 Socket->SharedData.State = SocketOpen;
103 Socket->SharedData.AddressFamily = AddressFamily;
104 Socket->SharedData.SocketType = SocketType;
105 Socket->SharedData.Protocol = Protocol;
106 Socket->HelperContext = HelperDLLContext;
107 Socket->HelperData = HelperData;
108 Socket->HelperEvents = HelperEvents;
109 Socket->LocalAddress = &Socket->WSLocalAddress;
110 Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
111 Socket->RemoteAddress = &Socket->WSRemoteAddress;
112 Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
113 Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
114 Socket->SharedData.CreateFlags = dwFlags;
115 Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
116 Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
117 Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
118 Socket->SharedData.GroupID = g;
119 Socket->SharedData.GroupType = 0;
120 Socket->SharedData.UseSAN = FALSE;
121 Socket->SharedData.NonBlocking = FALSE; /* Sockets start blocking */
122 Socket->SanData = NULL;
123
124 /* Ask alex about this */
125 if( Socket->SharedData.SocketType == SOCK_DGRAM ||
126 Socket->SharedData.SocketType == SOCK_RAW ) {
127 AFD_DbgPrint(MID_TRACE,("Connectionless socket\n"));
128 Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
129 }
130
131 /* Packet Size */
132 SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
133
134 /* EA Size */
135 SizeOfEA = SizeOfPacket + sizeof(FILE_FULL_EA_INFORMATION) + AFD_PACKET_COMMAND_LENGTH;
136
137 /* Set up EA Buffer */
138 EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
139 RtlZeroMemory(EABuffer, SizeOfEA);
140 EABuffer->NextEntryOffset = 0;
141 EABuffer->Flags = 0;
142 EABuffer->EaNameLength = AFD_PACKET_COMMAND_LENGTH;
143 RtlCopyMemory (EABuffer->EaName,
144 AfdCommand,
145 AFD_PACKET_COMMAND_LENGTH + 1);
146 EABuffer->EaValueLength = SizeOfPacket;
147
148 /* Set up AFD Packet */
149 AfdPacket = (PAFD_CREATE_PACKET)(EABuffer->EaName + EABuffer->EaNameLength + 1);
150 AfdPacket->SizeOfTransportName = TransportName.Length;
151 RtlCopyMemory (AfdPacket->TransportName,
152 TransportName.Buffer,
153 TransportName.Length + sizeof(WCHAR));
154 AfdPacket->GroupID = g;
155
156 /* Set up Endpoint Flags */
157 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0) {
158 if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW)) {
159 goto error; /* Only RAW or UDP can be Connectionless */
160 }
161 AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
162 }
163
164 if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0) {
165 if (SocketType == SOCK_STREAM) {
166 if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0) {
167 goto error; /* The Provider doesn't actually support Message Oriented Streams */
168 }
169 }
170 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MESSAGE_ORIENTED;
171 }
172
173 if (SocketType == SOCK_RAW) AfdPacket->EndpointFlags |= AFD_ENDPOINT_RAW;
174
175 if (dwFlags & (WSA_FLAG_MULTIPOINT_C_ROOT |
176 WSA_FLAG_MULTIPOINT_C_LEAF |
177 WSA_FLAG_MULTIPOINT_D_ROOT |
178 WSA_FLAG_MULTIPOINT_D_LEAF)) {
179 if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0) {
180 goto error; /* The Provider doesn't actually support Multipoint */
181 }
182 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
183 if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT) {
184 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
185 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0)) {
186 goto error; /* The Provider doesn't support Control Planes, or you already gave a leaf */
187 }
188 AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
189 }
190 if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT) {
191 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
192 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0)) {
193 goto error; /* The Provider doesn't support Data Planes, or you already gave a leaf */
194 }
195 AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
196 }
197 }
198
199 /* Set up Object Attributes */
200 InitializeObjectAttributes (&Object,
201 &DevName,
202 OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
203 0,
204 0);
205
206 /* Create the Socket as asynchronous. That means we have to block
207 ourselves after every call to NtDeviceIoControlFile. This is
208 because the kernel doesn't support overlapping synchronous I/O
209 requests (made from multiple threads) at this time (Sep 2005) */
210 ZwCreateFile(&Sock,
211 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
212 &Object,
213 &IOSB,
214 NULL,
215 0,
216 FILE_SHARE_READ | FILE_SHARE_WRITE,
217 FILE_OPEN_IF,
218 0,
219 EABuffer,
220 SizeOfEA);
221
222 /* Save Handle */
223 Socket->Handle = (SOCKET)Sock;
224
225 /* XXX See if there's a structure we can reuse -- We need to do this
226 * more properly. */
227 PrevSocket = GetSocketStructure( (SOCKET)Sock );
228
229 if( PrevSocket ) {
230 RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
231 RtlFreeHeap( GlobalHeap, 0, Socket );
232 Socket = PrevSocket;
233 }
234
235 /* Save Group Info */
236 if (g != 0) {
237 GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
238
239 Socket->SharedData.GroupID = GroupData.u.LowPart;
240 Socket->SharedData.GroupType = GroupData.u.HighPart;
241 }
242
243 /* Get Window Sizes and Save them */
244 GetSocketInformation (Socket,
245 AFD_INFO_SEND_WINDOW_SIZE,
246 &Socket->SharedData.SizeOfSendBuffer,
247 NULL);
248 GetSocketInformation (Socket,
249 AFD_INFO_RECEIVE_WINDOW_SIZE,
250 &Socket->SharedData.SizeOfRecvBuffer,
251 NULL);
252
253 /* Save in Process Sockets List */
254 Sockets[SocketCount] = Socket;
255 SocketCount ++;
256
257 /* Create the Socket Context */
258 CreateContext(Socket);
259
260 /* Notify Winsock */
261 Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
262
263 /* Return Socket Handle */
264 AFD_DbgPrint(MID_TRACE,("Success %x\n", Sock));
265 return (SOCKET)Sock;
266
267 error:
268 AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
269
270 if( lpErrno ) *lpErrno = Status;
271
272 return INVALID_SOCKET;
273 }
274
275
276 DWORD MsafdReturnWithErrno( NTSTATUS Status, LPINT Errno, DWORD Received,
277 LPDWORD ReturnedBytes ) {
278 if( ReturnedBytes ) *ReturnedBytes = 0;
279 if( Errno ) {
280 switch (Status) {
281 case STATUS_CANT_WAIT: *Errno = WSAEWOULDBLOCK; break;
282 case STATUS_TIMEOUT:
283 case STATUS_SUCCESS:
284 /* Return Number of bytes Read */
285 if( ReturnedBytes ) *ReturnedBytes = Received; break;
286 case STATUS_END_OF_FILE: *Errno = WSAESHUTDOWN; *ReturnedBytes = 0; break;
287 case STATUS_PENDING: *Errno = WSA_IO_PENDING; break;
288 case STATUS_BUFFER_OVERFLOW: *Errno = WSAEMSGSIZE; break;
289 default: {
290 DbgPrint("MSAFD: Error %x is unknown\n", Status);
291 *Errno = WSAEINVAL; break;
292 } break;
293 }
294 }
295
296 /* Success */
297 return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
298 }
299
300
301 INT
302 WSPAPI
303 WSPCloseSocket(
304 IN SOCKET Handle,
305 OUT LPINT lpErrno)
306 /*
307 * FUNCTION: Closes an open socket
308 * ARGUMENTS:
309 * s = Socket descriptor
310 * lpErrno = Address of buffer for error information
311 * RETURNS:
312 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
313 */
314 {
315 IO_STATUS_BLOCK IoStatusBlock;
316 PSOCKET_INFORMATION Socket = NULL;
317 NTSTATUS Status;
318 HANDLE SockEvent;
319 AFD_DISCONNECT_INFO DisconnectInfo;
320 SOCKET_STATE OldState;
321
322 /* Create the Wait Event */
323 Status = NtCreateEvent(&SockEvent,
324 GENERIC_READ | GENERIC_WRITE,
325 NULL,
326 1,
327 FALSE);
328
329 if(!NT_SUCCESS(Status)) return SOCKET_ERROR;
330
331 /* Get the Socket Structure associate to this Socket*/
332 Socket = GetSocketStructure(Handle);
333
334 /* If a Close is already in Process, give up */
335 if (Socket->SharedData.State == SocketClosed) {
336 *lpErrno = WSAENOTSOCK;
337 return SOCKET_ERROR;
338 }
339
340 /* Set the state to close */
341 OldState = Socket->SharedData.State;
342 Socket->SharedData.State = SocketClosed;
343
344 /* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
345 /* FIXME: Should we do this on Datagram Sockets too? */
346 if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff)) {
347 ULONG LingerWait;
348 ULONG SendsInProgress;
349 ULONG SleepWait;
350
351 /* We need to respect the timeout */
352 SleepWait = 100;
353 LingerWait = Socket->SharedData.LingerData.l_linger * 1000;
354
355 /* Loop until no more sends are pending, within the timeout */
356 while (LingerWait) {
357
358 /* Find out how many Sends are in Progress */
359 if (GetSocketInformation(Socket,
360 AFD_INFO_SENDS_IN_PROGRESS,
361 &SendsInProgress,
362 NULL)) {
363 /* Bail out if anything but NO_ERROR */
364 LingerWait = 0;
365 break;
366 }
367
368 /* Bail out if no more sends are pending */
369 if (!SendsInProgress) break;
370
371 /*
372 * We have to execute a sleep, so it's kind of like
373 * a block. If the socket is Nonblock, we cannot
374 * go on since asyncronous operation is expected
375 * and we cannot offer it
376 */
377 if (Socket->SharedData.NonBlocking) {
378 Socket->SharedData.State = OldState;
379 *lpErrno = WSAEWOULDBLOCK;
380 return SOCKET_ERROR;
381 }
382
383 /* Now we can sleep, and decrement the linger wait */
384 /*
385 * FIXME: It seems Windows does some funky acceleration
386 * since the waiting seems to be longer and longer. I
387 * don't think this improves performance so much, so we
388 * wait a fixed time instead.
389 */
390 Sleep(SleepWait);
391 LingerWait -= SleepWait;
392 }
393
394 /*
395 * We have reached the timeout or sends are over.
396 * Disconnect if the timeout has been reached.
397 */
398 if (LingerWait <= 0) {
399
400 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
401 DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
402
403 /* Send IOCTL */
404 Status = NtDeviceIoControlFile((HANDLE)Handle,
405 SockEvent,
406 NULL,
407 NULL,
408 &IoStatusBlock,
409 IOCTL_AFD_DISCONNECT,
410 &DisconnectInfo,
411 sizeof(DisconnectInfo),
412 NULL,
413 0);
414
415 /* Wait for return */
416 if (Status == STATUS_PENDING) {
417 WaitForSingleObject(SockEvent, INFINITE);
418 }
419 }
420 }
421
422 /* FIXME: We should notify the Helper DLL of WSH_NOTIFY_CLOSE */
423
424 /* Cleanup Time! */
425 Socket->HelperContext = NULL;
426 Socket->SharedData.AsyncDisabledEvents = -1;
427 NtClose(Socket->TdiAddressHandle);
428 Socket->TdiAddressHandle = NULL;
429 NtClose(Socket->TdiConnectionHandle);
430 Socket->TdiConnectionHandle = NULL;
431
432 /* Close the handle */
433 NtClose((HANDLE)Handle);
434
435 return NO_ERROR;
436 }
437
438 INT
439 WSPAPI
440 WSPBind(
441 SOCKET Handle,
442 const struct sockaddr *SocketAddress,
443 int SocketAddressLength,
444 LPINT lpErrno)
445 /*
446 * FUNCTION: Associates a local address with a socket
447 * ARGUMENTS:
448 * s = Socket descriptor
449 * name = Pointer to local address
450 * namelen = Length of name
451 * lpErrno = Address of buffer for error information
452 * RETURNS:
453 * 0, or SOCKET_ERROR if the socket could not be bound
454 */
455 {
456 IO_STATUS_BLOCK IOSB;
457 PAFD_BIND_DATA BindData;
458 PSOCKET_INFORMATION Socket = NULL;
459 NTSTATUS Status;
460 UCHAR BindBuffer[0x1A];
461 SOCKADDR_INFO SocketInfo;
462 HANDLE SockEvent;
463
464 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
465 NULL, 1, FALSE );
466
467 if( !NT_SUCCESS(Status) ) return -1;
468
469 /* Get the Socket Structure associate to this Socket*/
470 Socket = GetSocketStructure(Handle);
471
472 /* Dynamic Structure...ugh */
473 BindData = (PAFD_BIND_DATA)BindBuffer;
474
475 /* Set up Address in TDI Format */
476 BindData->Address.TAAddressCount = 1;
477 BindData->Address.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
478 BindData->Address.Address[0].AddressType = SocketAddress->sa_family;
479 RtlCopyMemory (BindData->Address.Address[0].Address,
480 SocketAddress->sa_data,
481 SocketAddressLength - sizeof(SocketAddress->sa_family));
482
483 /* Get Address Information */
484 Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
485 SocketAddressLength,
486 &SocketInfo);
487
488 /* Set the Share Type */
489 if (Socket->SharedData.ExclusiveAddressUse) {
490 BindData->ShareType = AFD_SHARE_EXCLUSIVE;
491 }
492 else if (SocketInfo.EndpointInfo == SockaddrEndpointInfoWildcard) {
493 BindData->ShareType = AFD_SHARE_WILDCARD;
494 }
495 else if (Socket->SharedData.ReuseAddresses) {
496 BindData->ShareType = AFD_SHARE_REUSE;
497 } else {
498 BindData->ShareType = AFD_SHARE_UNIQUE;
499 }
500
501 /* Send IOCTL */
502 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
503 SockEvent,
504 NULL,
505 NULL,
506 &IOSB,
507 IOCTL_AFD_BIND,
508 BindData,
509 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
510 BindData,
511 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
512
513 /* Wait for return */
514 if (Status == STATUS_PENDING) {
515 WaitForSingleObject(SockEvent, INFINITE);
516 }
517
518 /* Set up Socket Data */
519 Socket->SharedData.State = SocketBound;
520 Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
521
522 NtClose( SockEvent );
523
524 return MsafdReturnWithErrno
525 ( IOSB.Status, lpErrno, IOSB.Information, NULL );
526 }
527
528 int
529 WSPAPI
530 WSPListen(
531 SOCKET Handle,
532 int Backlog,
533 LPINT lpErrno)
534 {
535 IO_STATUS_BLOCK IOSB;
536 AFD_LISTEN_DATA ListenData;
537 PSOCKET_INFORMATION Socket = NULL;
538 HANDLE SockEvent;
539 NTSTATUS Status;
540
541 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
542 NULL, 1, FALSE );
543
544 if( !NT_SUCCESS(Status) ) return -1;
545
546 /* Get the Socket Structure associate to this Socket*/
547 Socket = GetSocketStructure(Handle);
548
549 /* Set Up Listen Structure */
550 ListenData.UseSAN = FALSE;
551 ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
552 ListenData.Backlog = Backlog;
553
554 /* Send IOCTL */
555 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
556 SockEvent,
557 NULL,
558 NULL,
559 &IOSB,
560 IOCTL_AFD_START_LISTEN,
561 &ListenData,
562 sizeof(ListenData),
563 NULL,
564 0);
565
566 /* Wait for return */
567 if (Status == STATUS_PENDING) {
568 WaitForSingleObject(SockEvent, INFINITE);
569 }
570
571 /* Set to Listening */
572 Socket->SharedData.Listening = TRUE;
573
574 NtClose( SockEvent );
575
576 return MsafdReturnWithErrno
577 ( IOSB.Status, lpErrno, IOSB.Information, NULL );
578 }
579
580
581 int
582 WSPAPI
583 WSPSelect(
584 int nfds,
585 fd_set *readfds,
586 fd_set *writefds,
587 fd_set *exceptfds,
588 struct timeval *timeout,
589 LPINT lpErrno)
590 {
591 IO_STATUS_BLOCK IOSB;
592 PAFD_POLL_INFO PollInfo;
593 NTSTATUS Status;
594 ULONG HandleCount, OutCount = 0;
595 ULONG PollBufferSize;
596 PVOID PollBuffer;
597 ULONG i, j = 0, x;
598 HANDLE SockEvent;
599 BOOL HandleCounted;
600 LARGE_INTEGER Timeout;
601
602 /* Find out how many sockets we have, and how large the buffer needs
603 * to be */
604
605 HandleCount =
606 ( readfds ? readfds->fd_count : 0 ) +
607 ( writefds ? writefds->fd_count : 0 ) +
608 ( exceptfds ? exceptfds->fd_count : 0 );
609
610 if( HandleCount < 0 || nfds != 0 ) HandleCount = nfds * 3;
611
612 PollBufferSize = sizeof(*PollInfo) + (HandleCount * sizeof(AFD_HANDLE));
613
614 AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n",
615 HandleCount, PollBufferSize));
616
617 /* Convert Timeout to NT Format */
618 if (timeout == NULL) {
619 Timeout.u.LowPart = -1;
620 Timeout.u.HighPart = 0x7FFFFFFF;
621 AFD_DbgPrint(MAX_TRACE,("Infinite timeout\n"));
622 } else {
623 Timeout = RtlEnlargedIntegerMultiply
624 ((timeout->tv_sec * 1000) + (timeout->tv_usec / 1000), -10000);
625 /* Negative timeouts are illegal. Since the kernel represents an
626 * incremental timeout as a negative number, we check for a positive
627 * result.
628 */
629 if (Timeout.QuadPart > 0) {
630 if (lpErrno) *lpErrno = WSAEINVAL;
631 return SOCKET_ERROR;
632 }
633 AFD_DbgPrint(MAX_TRACE,("Timeout: Orig %d.%06d kernel %d\n",
634 timeout->tv_sec, timeout->tv_usec,
635 Timeout.u.LowPart));
636 }
637
638 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
639 NULL, 1, FALSE );
640
641 if( !NT_SUCCESS(Status) ) return SOCKET_ERROR;
642
643 /* Allocate */
644 PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
645
646 if (!PollBuffer) {
647 if (lpErrno) *lpErrno = WSAEFAULT;
648 NtClose(SockEvent);
649 return SOCKET_ERROR;
650 }
651
652 PollInfo = (PAFD_POLL_INFO)PollBuffer;
653
654 RtlZeroMemory( PollInfo, PollBufferSize );
655
656 /* Number of handles for AFD to Check */
657 PollInfo->HandleCount = HandleCount;
658 PollInfo->Exclusive = FALSE;
659 PollInfo->Timeout = Timeout;
660
661 if (readfds != NULL) {
662 for (i = 0; i < readfds->fd_count; i++, j++) {
663 PollInfo->Handles[j].Handle = readfds->fd_array[i];
664 PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
665 AFD_EVENT_DISCONNECT |
666 AFD_EVENT_ABORT |
667 AFD_EVENT_ACCEPT;
668 }
669 }
670 if (writefds != NULL) {
671 for (i = 0; i < writefds->fd_count; i++, j++) {
672 PollInfo->Handles[j].Handle = writefds->fd_array[i];
673 PollInfo->Handles[j].Events = AFD_EVENT_SEND |
674 AFD_EVENT_CONNECT;
675 }
676 }
677 if (exceptfds != NULL) {
678 for (i = 0; i < exceptfds->fd_count; i++, j++) {
679 PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
680 PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE |
681 AFD_EVENT_CONNECT_FAIL;
682 }
683 }
684
685 /* Send IOCTL */
686 Status = NtDeviceIoControlFile( (HANDLE)PollInfo->Handles[0].Handle,
687 SockEvent,
688 NULL,
689 NULL,
690 &IOSB,
691 IOCTL_AFD_SELECT,
692 PollInfo,
693 PollBufferSize,
694 PollInfo,
695 PollBufferSize);
696
697 AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
698
699 /* Wait for Completition */
700 if (Status == STATUS_PENDING) {
701 WaitForSingleObject(SockEvent, INFINITE);
702 }
703
704 /* Clear the Structures */
705 if( readfds ) FD_ZERO(readfds);
706 if( writefds ) FD_ZERO(writefds);
707 if( exceptfds ) FD_ZERO(exceptfds);
708
709 /* Loop through return structure */
710 HandleCount = PollInfo->HandleCount;
711
712 /* Return in FDSET Format */
713 for (i = 0; i < HandleCount; i++) {
714 HandleCounted = FALSE;
715 for(x = 1; x; x<<=1) {
716 switch (PollInfo->Handles[i].Events & x) {
717 case AFD_EVENT_RECEIVE:
718 case AFD_EVENT_DISCONNECT:
719 case AFD_EVENT_ABORT:
720 case AFD_EVENT_ACCEPT:
721 case AFD_EVENT_CLOSE:
722 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
723 PollInfo->Handles[i].Events,
724 PollInfo->Handles[i].Handle));
725 if (! HandleCounted) {
726 OutCount++;
727 HandleCounted = TRUE;
728 }
729 if( readfds ) FD_SET(PollInfo->Handles[i].Handle, readfds);
730 break;
731
732 case AFD_EVENT_SEND: case AFD_EVENT_CONNECT:
733 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
734 PollInfo->Handles[i].Events,
735 PollInfo->Handles[i].Handle));
736 if (! HandleCounted) {
737 OutCount++;
738 HandleCounted = TRUE;
739 }
740 if( writefds ) FD_SET(PollInfo->Handles[i].Handle, writefds);
741 break;
742
743 case AFD_EVENT_OOB_RECEIVE: case AFD_EVENT_CONNECT_FAIL:
744 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
745 PollInfo->Handles[i].Events,
746 PollInfo->Handles[i].Handle));
747 if (! HandleCounted) {
748 OutCount++;
749 HandleCounted = TRUE;
750 }
751 if( exceptfds ) FD_SET(PollInfo->Handles[i].Handle, exceptfds);
752 break;
753 }
754 }
755 }
756
757 HeapFree( GlobalHeap, 0, PollBuffer );
758 NtClose( SockEvent );
759
760
761 if( lpErrno ) {
762 switch( IOSB.Status ) {
763 case STATUS_SUCCESS:
764 case STATUS_TIMEOUT: *lpErrno = 0; break;
765 default: *lpErrno = WSAEINVAL; break;
766 }
767 AFD_DbgPrint(MID_TRACE,("*lpErrno = %x\n", *lpErrno));
768 }
769
770 AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
771
772 return OutCount;
773 }
774
775 SOCKET
776 WSPAPI
777 WSPAccept(
778 SOCKET Handle,
779 struct sockaddr *SocketAddress,
780 int *SocketAddressLength,
781 LPCONDITIONPROC lpfnCondition,
782 DWORD_PTR dwCallbackData,
783 LPINT lpErrno)
784 {
785 IO_STATUS_BLOCK IOSB;
786 PAFD_RECEIVED_ACCEPT_DATA ListenReceiveData;
787 AFD_ACCEPT_DATA AcceptData;
788 AFD_DEFER_ACCEPT_DATA DeferData;
789 AFD_PENDING_ACCEPT_DATA PendingAcceptData;
790 PSOCKET_INFORMATION Socket = NULL;
791 NTSTATUS Status;
792 struct fd_set ReadSet;
793 struct timeval Timeout;
794 PVOID PendingData = NULL;
795 ULONG PendingDataLength = 0;
796 PVOID CalleeDataBuffer;
797 WSABUF CallerData, CalleeID, CallerID, CalleeData;
798 PSOCKADDR RemoteAddress = NULL;
799 GROUP GroupID = 0;
800 ULONG CallBack;
801 WSAPROTOCOL_INFOW ProtocolInfo;
802 SOCKET AcceptSocket;
803 UCHAR ReceiveBuffer[0x1A];
804 HANDLE SockEvent;
805
806 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
807 NULL, 1, FALSE );
808
809 if( !NT_SUCCESS(Status) ) {
810 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
811 return INVALID_SOCKET;
812 }
813
814 /* Dynamic Structure...ugh */
815 ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
816
817 /* Get the Socket Structure associate to this Socket*/
818 Socket = GetSocketStructure(Handle);
819
820 /* If this is non-blocking, make sure there's something for us to accept */
821 FD_ZERO(&ReadSet);
822 FD_SET(Socket->Handle, &ReadSet);
823 Timeout.tv_sec=0;
824 Timeout.tv_usec=0;
825
826 WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
827
828 if (ReadSet.fd_array[0] != Socket->Handle) return 0;
829
830 /* Send IOCTL */
831 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
832 SockEvent,
833 NULL,
834 NULL,
835 &IOSB,
836 IOCTL_AFD_WAIT_FOR_LISTEN,
837 NULL,
838 0,
839 ListenReceiveData,
840 0xA + sizeof(*ListenReceiveData));
841
842 /* Wait for return */
843 if (Status == STATUS_PENDING) {
844 WaitForSingleObject(SockEvent, INFINITE);
845 Status = IOSB.Status;
846 }
847
848 if (!NT_SUCCESS(Status)) {
849 NtClose( SockEvent );
850 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
851 return INVALID_SOCKET;
852 }
853
854 if (lpfnCondition != NULL) {
855 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0) {
856 /* Find out how much data is pending */
857 PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
858 PendingAcceptData.ReturnSize = TRUE;
859
860 /* Send IOCTL */
861 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
862 SockEvent,
863 NULL,
864 NULL,
865 &IOSB,
866 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
867 &PendingAcceptData,
868 sizeof(PendingAcceptData),
869 &PendingAcceptData,
870 sizeof(PendingAcceptData));
871
872 /* Wait for return */
873 if (Status == STATUS_PENDING) {
874 WaitForSingleObject(SockEvent, INFINITE);
875 Status = IOSB.Status;
876 }
877
878 if (!NT_SUCCESS(Status)) {
879 NtClose( SockEvent );
880 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
881 return INVALID_SOCKET;
882 }
883
884 /* How much data to allocate */
885 PendingDataLength = IOSB.Information;
886
887 if (PendingDataLength) {
888 /* Allocate needed space */
889 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
890
891 /* We want the data now */
892 PendingAcceptData.ReturnSize = FALSE;
893
894 /* Send IOCTL */
895 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
896 SockEvent,
897 NULL,
898 NULL,
899 &IOSB,
900 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
901 &PendingAcceptData,
902 sizeof(PendingAcceptData),
903 PendingData,
904 PendingDataLength);
905
906 /* Wait for return */
907 if (Status == STATUS_PENDING) {
908 WaitForSingleObject(SockEvent, INFINITE);
909 Status = IOSB.Status;
910 }
911
912 if (!NT_SUCCESS(Status)) {
913 NtClose( SockEvent );
914 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
915 return INVALID_SOCKET;
916 }
917 }
918 }
919
920 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0) {
921 /* I don't support this yet */
922 }
923
924 /* Build Callee ID */
925 CalleeID.buf = (PVOID)Socket->LocalAddress;
926 CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
927
928 /* Set up Address in SOCKADDR Format */
929 RtlCopyMemory (RemoteAddress,
930 &ListenReceiveData->Address.Address[0].AddressType,
931 sizeof(*RemoteAddress));
932
933 /* Build Caller ID */
934 CallerID.buf = (PVOID)RemoteAddress;
935 CallerID.len = sizeof(*RemoteAddress);
936
937 /* Build Caller Data */
938 CallerData.buf = PendingData;
939 CallerData.len = PendingDataLength;
940
941 /* Check if socket supports Conditional Accept */
942 if (Socket->SharedData.UseDelayedAcceptance != 0) {
943 /* Allocate Buffer for Callee Data */
944 CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
945 CalleeData.buf = CalleeDataBuffer;
946 CalleeData.len = 4096;
947 } else {
948 /* Nothing */
949 CalleeData.buf = 0;
950 CalleeData.len = 0;
951 }
952
953 /* Call the Condition Function */
954 CallBack = (lpfnCondition)( &CallerID,
955 CallerData.buf == NULL
956 ? NULL
957 : & CallerData,
958 NULL,
959 NULL,
960 &CalleeID,
961 CalleeData.buf == NULL
962 ? NULL
963 : & CalleeData,
964 &GroupID,
965 dwCallbackData);
966
967 if (((CallBack == CF_ACCEPT) && GroupID) != 0) {
968 /* TBD: Check for Validity */
969 }
970
971 if (CallBack == CF_ACCEPT) {
972
973 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0) {
974 /* I don't support this yet */
975 }
976
977 if (CalleeData.buf) {
978 // SockSetConnectData Sockets(SocketID), IOCTL_AFD_SET_CONNECT_DATA, CalleeData.Buffer, CalleeData.BuffSize, 0
979 }
980
981 } else {
982 /* Callback rejected. Build Defer Structure */
983 DeferData.SequenceNumber = ListenReceiveData->SequenceNumber;
984 DeferData.RejectConnection = (CallBack == CF_REJECT);
985
986 /* Send IOCTL */
987 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
988 SockEvent,
989 NULL,
990 NULL,
991 &IOSB,
992 IOCTL_AFD_DEFER_ACCEPT,
993 &DeferData,
994 sizeof(DeferData),
995 NULL,
996 0);
997
998 /* Wait for return */
999 if (Status == STATUS_PENDING) {
1000 WaitForSingleObject(SockEvent, INFINITE);
1001 Status = IOSB.Status;
1002 }
1003
1004 NtClose( SockEvent );
1005
1006 if (!NT_SUCCESS(Status)) {
1007 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1008 return INVALID_SOCKET;
1009 }
1010
1011 if (CallBack == CF_REJECT ) {
1012 *lpErrno = WSAECONNREFUSED;
1013 return INVALID_SOCKET;
1014 } else {
1015 *lpErrno = WSAECONNREFUSED;
1016 return INVALID_SOCKET;
1017 }
1018 }
1019 }
1020
1021 /* Create a new Socket */
1022 ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
1023 ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
1024 ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
1025
1026 AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
1027 Socket->SharedData.SocketType,
1028 Socket->SharedData.Protocol,
1029 &ProtocolInfo,
1030 GroupID,
1031 Socket->SharedData.CreateFlags,
1032 NULL);
1033
1034 /* Set up the Accept Structure */
1035 AcceptData.ListenHandle = AcceptSocket;
1036 AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1037
1038 /* Send IOCTL to Accept */
1039 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1040 SockEvent,
1041 NULL,
1042 NULL,
1043 &IOSB,
1044 IOCTL_AFD_ACCEPT,
1045 &AcceptData,
1046 sizeof(AcceptData),
1047 NULL,
1048 0);
1049
1050 /* Wait for return */
1051 if (Status == STATUS_PENDING) {
1052 WaitForSingleObject(SockEvent, INFINITE);
1053 Status = IOSB.Status;
1054 }
1055
1056 if (!NT_SUCCESS(Status)) {
1057 WSPCloseSocket( AcceptSocket, lpErrno );
1058 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1059 return INVALID_SOCKET;
1060 }
1061
1062 /* Return Address in SOCKADDR FORMAT */
1063 if( SocketAddress ) {
1064 RtlCopyMemory (SocketAddress,
1065 &ListenReceiveData->Address.Address[0].AddressType,
1066 sizeof(*RemoteAddress));
1067 if( SocketAddressLength )
1068 *SocketAddressLength =
1069 ListenReceiveData->Address.Address[0].AddressLength;
1070 }
1071
1072 NtClose( SockEvent );
1073
1074 /* Re-enable Async Event */
1075 SockReenableAsyncSelectEvent(Socket, FD_ACCEPT);
1076
1077 AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
1078
1079 *lpErrno = 0;
1080
1081 /* Return Socket */
1082 return AcceptSocket;
1083 }
1084
1085 int
1086 WSPAPI
1087 WSPConnect(
1088 SOCKET Handle,
1089 const struct sockaddr * SocketAddress,
1090 int SocketAddressLength,
1091 LPWSABUF lpCallerData,
1092 LPWSABUF lpCalleeData,
1093 LPQOS lpSQOS,
1094 LPQOS lpGQOS,
1095 LPINT lpErrno)
1096 {
1097 IO_STATUS_BLOCK IOSB;
1098 PAFD_CONNECT_INFO ConnectInfo;
1099 PSOCKET_INFORMATION Socket = NULL;
1100 NTSTATUS Status;
1101 UCHAR ConnectBuffer[0x22];
1102 ULONG ConnectDataLength;
1103 ULONG InConnectDataLength;
1104 INT BindAddressLength;
1105 PSOCKADDR BindAddress;
1106 HANDLE SockEvent;
1107
1108 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1109 NULL, 1, FALSE );
1110
1111 if( !NT_SUCCESS(Status) ) return -1;
1112
1113 AFD_DbgPrint(MID_TRACE,("Called\n"));
1114
1115 /* Get the Socket Structure associate to this Socket*/
1116 Socket = GetSocketStructure(Handle);
1117
1118 /* Bind us First */
1119 if (Socket->SharedData.State == SocketOpen) {
1120
1121 /* Get the Wildcard Address */
1122 BindAddressLength = Socket->HelperData->MaxWSAddressLength;
1123 BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
1124 Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
1125 BindAddress,
1126 &BindAddressLength);
1127
1128 /* Bind it */
1129 WSPBind(Handle, BindAddress, BindAddressLength, NULL);
1130 }
1131
1132 /* Set the Connect Data */
1133 if (lpCallerData != NULL) {
1134 ConnectDataLength = lpCallerData->len;
1135 Status = NtDeviceIoControlFile((HANDLE)Handle,
1136 SockEvent,
1137 NULL,
1138 NULL,
1139 &IOSB,
1140 IOCTL_AFD_SET_CONNECT_DATA,
1141 lpCallerData->buf,
1142 ConnectDataLength,
1143 NULL,
1144 0);
1145 /* Wait for return */
1146 if (Status == STATUS_PENDING) {
1147 WaitForSingleObject(SockEvent, INFINITE);
1148 Status = IOSB.Status;
1149 }
1150 }
1151
1152 /* Dynamic Structure...ugh */
1153 ConnectInfo = (PAFD_CONNECT_INFO)ConnectBuffer;
1154
1155 /* Set up Address in TDI Format */
1156 ConnectInfo->RemoteAddress.TAAddressCount = 1;
1157 ConnectInfo->RemoteAddress.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
1158 ConnectInfo->RemoteAddress.Address[0].AddressType = SocketAddress->sa_family;
1159 RtlCopyMemory (ConnectInfo->RemoteAddress.Address[0].Address,
1160 SocketAddress->sa_data,
1161 SocketAddressLength - sizeof(SocketAddress->sa_family));
1162
1163 /*
1164 * Disable FD_WRITE and FD_CONNECT
1165 * The latter fixes a race condition where the FD_CONNECT is re-enabled
1166 * at the end of this function right after the Async Thread disables it.
1167 * This should only happen at the *next* WSPConnect
1168 */
1169 if (Socket->SharedData.AsyncEvents & FD_CONNECT) {
1170 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
1171 }
1172
1173 /* Tell AFD that we want Connection Data back, have it allocate a buffer */
1174 if (lpCalleeData != NULL) {
1175 InConnectDataLength = lpCalleeData->len;
1176 Status = NtDeviceIoControlFile((HANDLE)Handle,
1177 SockEvent,
1178 NULL,
1179 NULL,
1180 &IOSB,
1181 IOCTL_AFD_SET_CONNECT_DATA_SIZE,
1182 &InConnectDataLength,
1183 sizeof(InConnectDataLength),
1184 NULL,
1185 0);
1186
1187 /* Wait for return */
1188 if (Status == STATUS_PENDING) {
1189 WaitForSingleObject(SockEvent, INFINITE);
1190 Status = IOSB.Status;
1191 }
1192 }
1193
1194 /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
1195 ConnectInfo->Root = 0;
1196 ConnectInfo->UseSAN = FALSE;
1197 ConnectInfo->Unknown = 0;
1198
1199 /* FIXME: Handle Async Connect */
1200 if (Socket->SharedData.NonBlocking) {
1201 AFD_DbgPrint(MIN_TRACE, ("Async Connect UNIMPLEMENTED!\n"));
1202 }
1203
1204 /* Send IOCTL */
1205 Status = NtDeviceIoControlFile((HANDLE)Handle,
1206 SockEvent,
1207 NULL,
1208 NULL,
1209 &IOSB,
1210 IOCTL_AFD_CONNECT,
1211 ConnectInfo,
1212 0x22,
1213 NULL,
1214 0);
1215 /* Wait for return */
1216 if (Status == STATUS_PENDING) {
1217 WaitForSingleObject(SockEvent, INFINITE);
1218 Status = IOSB.Status;
1219 }
1220
1221 /* Get any pending connect data */
1222 if (lpCalleeData != NULL) {
1223 Status = NtDeviceIoControlFile((HANDLE)Handle,
1224 SockEvent,
1225 NULL,
1226 NULL,
1227 &IOSB,
1228 IOCTL_AFD_GET_CONNECT_DATA,
1229 NULL,
1230 0,
1231 lpCalleeData->buf,
1232 lpCalleeData->len);
1233 /* Wait for return */
1234 if (Status == STATUS_PENDING) {
1235 WaitForSingleObject(SockEvent, INFINITE);
1236 Status = IOSB.Status;
1237 }
1238 }
1239
1240 /* Re-enable Async Event */
1241 SockReenableAsyncSelectEvent(Socket, FD_WRITE);
1242
1243 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
1244 SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
1245
1246 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1247
1248 NtClose( SockEvent );
1249
1250 return MsafdReturnWithErrno( IOSB.Status, lpErrno, 0, NULL );
1251 }
1252 int
1253 WSPAPI
1254 WSPShutdown(
1255 SOCKET Handle,
1256 int HowTo,
1257 LPINT lpErrno)
1258
1259 {
1260 IO_STATUS_BLOCK IOSB;
1261 AFD_DISCONNECT_INFO DisconnectInfo;
1262 PSOCKET_INFORMATION Socket = NULL;
1263 NTSTATUS Status;
1264 HANDLE SockEvent;
1265
1266 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1267 NULL, 1, FALSE );
1268
1269 if( !NT_SUCCESS(Status) ) return -1;
1270
1271 AFD_DbgPrint(MID_TRACE,("Called\n"));
1272
1273 /* Get the Socket Structure associate to this Socket*/
1274 Socket = GetSocketStructure(Handle);
1275
1276 /* Set AFD Disconnect Type */
1277 switch (HowTo) {
1278
1279 case SD_RECEIVE:
1280 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
1281 Socket->SharedData.ReceiveShutdown = TRUE;
1282 break;
1283
1284 case SD_SEND:
1285 DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
1286 Socket->SharedData.SendShutdown = TRUE;
1287 break;
1288
1289 case SD_BOTH:
1290 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
1291 Socket->SharedData.ReceiveShutdown = TRUE;
1292 Socket->SharedData.SendShutdown = TRUE;
1293 break;
1294 }
1295
1296 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
1297
1298 /* Send IOCTL */
1299 Status = NtDeviceIoControlFile((HANDLE)Handle,
1300 SockEvent,
1301 NULL,
1302 NULL,
1303 &IOSB,
1304 IOCTL_AFD_DISCONNECT,
1305 &DisconnectInfo,
1306 sizeof(DisconnectInfo),
1307 NULL,
1308 0);
1309
1310 /* Wait for return */
1311 if (Status == STATUS_PENDING) {
1312 WaitForSingleObject(SockEvent, INFINITE);
1313 }
1314
1315 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1316
1317 NtClose( SockEvent );
1318
1319 return MsafdReturnWithErrno( IOSB.Status, lpErrno, 0, NULL );
1320 }
1321
1322
1323 INT
1324 WSPAPI
1325 WSPGetSockName(
1326 IN SOCKET Handle,
1327 OUT LPSOCKADDR Name,
1328 IN OUT LPINT NameLength,
1329 OUT LPINT lpErrno)
1330 {
1331 IO_STATUS_BLOCK IOSB;
1332 ULONG TdiAddressSize;
1333 PTDI_ADDRESS_INFO TdiAddress;
1334 PTRANSPORT_ADDRESS SocketAddress;
1335 PSOCKET_INFORMATION Socket = NULL;
1336 NTSTATUS Status;
1337 HANDLE SockEvent;
1338
1339 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1340 NULL, 1, FALSE );
1341
1342 if( !NT_SUCCESS(Status) ) return SOCKET_ERROR;
1343
1344 /* Get the Socket Structure associate to this Socket*/
1345 Socket = GetSocketStructure(Handle);
1346
1347 /* Allocate a buffer for the address */
1348 TdiAddressSize = FIELD_OFFSET(TDI_ADDRESS_INFO,
1349 Address.Address[0].Address) +
1350 Socket->SharedData.SizeOfLocalAddress;
1351 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1352
1353 if ( TdiAddress == NULL ) {
1354 NtClose( SockEvent );
1355 *lpErrno = WSAENOBUFS;
1356 return SOCKET_ERROR;
1357 }
1358
1359 SocketAddress = &TdiAddress->Address;
1360
1361 /* Send IOCTL */
1362 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1363 SockEvent,
1364 NULL,
1365 NULL,
1366 &IOSB,
1367 IOCTL_AFD_GET_SOCK_NAME,
1368 NULL,
1369 0,
1370 TdiAddress,
1371 TdiAddressSize);
1372
1373 /* Wait for return */
1374 if (Status == STATUS_PENDING) {
1375 WaitForSingleObject(SockEvent, INFINITE);
1376 Status = IOSB.Status;
1377 }
1378
1379 NtClose( SockEvent );
1380
1381 if (NT_SUCCESS(Status)) {
1382 if (*NameLength >= SocketAddress->Address[0].AddressLength) {
1383 Name->sa_family = SocketAddress->Address[0].AddressType;
1384 RtlCopyMemory (Name->sa_data,
1385 SocketAddress->Address[0].Address,
1386 SocketAddress->Address[0].AddressLength);
1387 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1388 AFD_DbgPrint
1389 (MID_TRACE,
1390 ("NameLength %d Address: %x Port %x\n",
1391 *NameLength,
1392 ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1393 ((struct sockaddr_in *)Name)->sin_port));
1394 HeapFree(GlobalHeap, 0, TdiAddress);
1395 return 0;
1396 } else {
1397 HeapFree(GlobalHeap, 0, TdiAddress);
1398 *lpErrno = WSAEFAULT;
1399 return SOCKET_ERROR;
1400 }
1401 }
1402
1403 return MsafdReturnWithErrno
1404 ( IOSB.Status, lpErrno, 0, NULL );
1405 }
1406
1407
1408 INT
1409 WSPAPI
1410 WSPGetPeerName(
1411 IN SOCKET s,
1412 OUT LPSOCKADDR Name,
1413 IN OUT LPINT NameLength,
1414 OUT LPINT lpErrno)
1415 {
1416 IO_STATUS_BLOCK IOSB;
1417 ULONG TdiAddressSize;
1418 PTDI_ADDRESS_INFO TdiAddress;
1419 PTRANSPORT_ADDRESS SocketAddress;
1420 PSOCKET_INFORMATION Socket = NULL;
1421 NTSTATUS Status;
1422 HANDLE SockEvent;
1423
1424 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1425 NULL, 1, FALSE );
1426
1427 if( !NT_SUCCESS(Status) ) return SOCKET_ERROR;
1428
1429 /* Get the Socket Structure associate to this Socket*/
1430 Socket = GetSocketStructure(s);
1431
1432 /* Allocate a buffer for the address */
1433 TdiAddressSize = FIELD_OFFSET(TDI_ADDRESS_INFO,
1434 Address.Address[0].Address) +
1435 Socket->SharedData.SizeOfLocalAddress;
1436 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1437
1438 if ( TdiAddress == NULL ) {
1439 NtClose( SockEvent );
1440 *lpErrno = WSAENOBUFS;
1441 return SOCKET_ERROR;
1442 }
1443
1444 SocketAddress = &TdiAddress->Address;
1445
1446 /* Send IOCTL */
1447 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1448 SockEvent,
1449 NULL,
1450 NULL,
1451 &IOSB,
1452 IOCTL_AFD_GET_PEER_NAME,
1453 NULL,
1454 0,
1455 TdiAddress,
1456 TdiAddressSize);
1457
1458 /* Wait for return */
1459 if (Status == STATUS_PENDING) {
1460 WaitForSingleObject(SockEvent, INFINITE);
1461 Status = IOSB.Status;
1462 }
1463
1464 NtClose( SockEvent );
1465
1466 if (NT_SUCCESS(Status)) {
1467 if (*NameLength >= SocketAddress->Address[0].AddressLength) {
1468 Name->sa_family = SocketAddress->Address[0].AddressType;
1469 RtlCopyMemory (Name->sa_data,
1470 SocketAddress->Address[0].Address,
1471 SocketAddress->Address[0].AddressLength);
1472 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1473 AFD_DbgPrint
1474 (MID_TRACE,
1475 ("NameLength %d Address: %s Port %x\n",
1476 *NameLength,
1477 ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1478 ((struct sockaddr_in *)Name)->sin_port));
1479 HeapFree(GlobalHeap, 0, TdiAddress);
1480 return 0;
1481 } else {
1482 HeapFree(GlobalHeap, 0, TdiAddress);
1483 *lpErrno = WSAEFAULT;
1484 return SOCKET_ERROR;
1485 }
1486 }
1487
1488 return MsafdReturnWithErrno
1489 ( IOSB.Status, lpErrno, 0, NULL );
1490 }
1491
1492 INT
1493 WSPAPI
1494 WSPIoctl(
1495 IN SOCKET Handle,
1496 IN DWORD dwIoControlCode,
1497 IN LPVOID lpvInBuffer,
1498 IN DWORD cbInBuffer,
1499 OUT LPVOID lpvOutBuffer,
1500 IN DWORD cbOutBuffer,
1501 OUT LPDWORD lpcbBytesReturned,
1502 IN LPWSAOVERLAPPED lpOverlapped,
1503 IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
1504 IN LPWSATHREADID lpThreadId,
1505 OUT LPINT lpErrno)
1506 {
1507 PSOCKET_INFORMATION Socket = NULL;
1508
1509 /* Get the Socket Structure associate to this Socket*/
1510 Socket = GetSocketStructure(Handle);
1511
1512 switch( dwIoControlCode ) {
1513 case FIONBIO:
1514 if( cbInBuffer < sizeof(INT) ) return SOCKET_ERROR;
1515 Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
1516 AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n",
1517 Handle, Socket->SharedData.NonBlocking));
1518 return 0;
1519
1520 default:
1521 *lpErrno = WSAEINVAL;
1522 return SOCKET_ERROR;
1523 }
1524 }
1525
1526
1527 INT
1528 WSPAPI
1529 WSPGetSockOpt(
1530 IN SOCKET Handle,
1531 IN INT Level,
1532 IN INT OptionName,
1533 OUT CHAR FAR* OptionValue,
1534 IN OUT LPINT OptionLength,
1535 OUT LPINT lpErrno)
1536 {
1537 PSOCKET_INFORMATION Socket = NULL;
1538 PVOID Buffer;
1539 INT BufferSize;
1540 BOOLEAN BoolBuffer;
1541
1542 /* Get the Socket Structure associate to this Socket*/
1543 Socket = GetSocketStructure(Handle);
1544 if (Socket == NULL)
1545 {
1546 *lpErrno = WSAENOTSOCK;
1547 return SOCKET_ERROR;
1548 }
1549
1550 AFD_DbgPrint(MID_TRACE, ("Called\n"));
1551
1552 switch (Level)
1553 {
1554 case SOL_SOCKET:
1555 switch (OptionName)
1556 {
1557 case SO_TYPE:
1558 Buffer = &Socket->SharedData.SocketType;
1559 BufferSize = sizeof(INT);
1560 break;
1561
1562 case SO_RCVBUF:
1563 Buffer = &Socket->SharedData.SizeOfRecvBuffer;
1564 BufferSize = sizeof(INT);
1565 break;
1566
1567 case SO_SNDBUF:
1568 Buffer = &Socket->SharedData.SizeOfSendBuffer;
1569 BufferSize = sizeof(INT);
1570 break;
1571
1572 case SO_ACCEPTCONN:
1573 BoolBuffer = Socket->SharedData.Listening;
1574 Buffer = &BoolBuffer;
1575 BufferSize = sizeof(BOOLEAN);
1576 break;
1577
1578 case SO_BROADCAST:
1579 BoolBuffer = Socket->SharedData.Broadcast;
1580 Buffer = &BoolBuffer;
1581 BufferSize = sizeof(BOOLEAN);
1582 break;
1583
1584 case SO_DEBUG:
1585 BoolBuffer = Socket->SharedData.Debug;
1586 Buffer = &BoolBuffer;
1587 BufferSize = sizeof(BOOLEAN);
1588 break;
1589
1590 /* case SO_CONDITIONAL_ACCEPT: */
1591 case SO_DONTLINGER:
1592 case SO_DONTROUTE:
1593 case SO_ERROR:
1594 case SO_GROUP_ID:
1595 case SO_GROUP_PRIORITY:
1596 case SO_KEEPALIVE:
1597 case SO_LINGER:
1598 case SO_MAX_MSG_SIZE:
1599 case SO_OOBINLINE:
1600 case SO_PROTOCOL_INFO:
1601 case SO_REUSEADDR:
1602 AFD_DbgPrint(MID_TRACE, ("Unimplemented option (%x)\n",
1603 OptionName));
1604
1605 default:
1606 *lpErrno = WSAEINVAL;
1607 return SOCKET_ERROR;
1608 }
1609
1610 if (*OptionLength < BufferSize)
1611 {
1612 *lpErrno = WSAEFAULT;
1613 *OptionLength = BufferSize;
1614 return SOCKET_ERROR;
1615 }
1616 RtlCopyMemory(OptionValue, Buffer, BufferSize);
1617
1618 return 0;
1619
1620 case IPPROTO_TCP: /* FIXME */
1621 default:
1622 *lpErrno = WSAEINVAL;
1623 return SOCKET_ERROR;
1624 }
1625 }
1626
1627
1628 INT
1629 WSPAPI
1630 WSPStartup(
1631 IN WORD wVersionRequested,
1632 OUT LPWSPDATA lpWSPData,
1633 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
1634 IN WSPUPCALLTABLE UpcallTable,
1635 OUT LPWSPPROC_TABLE lpProcTable)
1636 /*
1637 * FUNCTION: Initialize service provider for a client
1638 * ARGUMENTS:
1639 * wVersionRequested = Highest WinSock SPI version that the caller can use
1640 * lpWSPData = Address of WSPDATA structure to initialize
1641 * lpProtocolInfo = Pointer to structure that defines the desired protocol
1642 * UpcallTable = Pointer to upcall table of the WinSock DLL
1643 * lpProcTable = Address of procedure table to initialize
1644 * RETURNS:
1645 * Status of operation
1646 */
1647 {
1648 NTSTATUS Status;
1649
1650 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
1651
1652 Status = NO_ERROR;
1653
1654 Upcalls = UpcallTable;
1655
1656 if (Status == NO_ERROR) {
1657 lpProcTable->lpWSPAccept = WSPAccept;
1658 lpProcTable->lpWSPAddressToString = WSPAddressToString;
1659 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
1660 lpProcTable->lpWSPBind = WSPBind;
1661 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
1662 lpProcTable->lpWSPCleanup = WSPCleanup;
1663 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
1664 lpProcTable->lpWSPConnect = WSPConnect;
1665 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
1666 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
1667 lpProcTable->lpWSPEventSelect = WSPEventSelect;
1668 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
1669 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
1670 lpProcTable->lpWSPGetSockName = WSPGetSockName;
1671 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
1672 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
1673 lpProcTable->lpWSPIoctl = WSPIoctl;
1674 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
1675 lpProcTable->lpWSPListen = WSPListen;
1676 lpProcTable->lpWSPRecv = WSPRecv;
1677 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
1678 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
1679 lpProcTable->lpWSPSelect = WSPSelect;
1680 lpProcTable->lpWSPSend = WSPSend;
1681 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
1682 lpProcTable->lpWSPSendTo = WSPSendTo;
1683 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
1684 lpProcTable->lpWSPShutdown = WSPShutdown;
1685 lpProcTable->lpWSPSocket = WSPSocket;
1686 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
1687
1688 lpWSPData->wVersion = MAKEWORD(2, 2);
1689 lpWSPData->wHighVersion = MAKEWORD(2, 2);
1690 }
1691
1692 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
1693
1694 return Status;
1695 }
1696
1697
1698 INT
1699 WSPAPI
1700 WSPCleanup(
1701 OUT LPINT lpErrno)
1702 /*
1703 * FUNCTION: Cleans up service provider for a client
1704 * ARGUMENTS:
1705 * lpErrno = Address of buffer for error information
1706 * RETURNS:
1707 * 0 if successful, or SOCKET_ERROR if not
1708 */
1709 {
1710 AFD_DbgPrint(MAX_TRACE, ("\n"));
1711
1712
1713 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
1714
1715 *lpErrno = NO_ERROR;
1716
1717 return 0;
1718 }
1719
1720
1721
1722 int
1723 GetSocketInformation(
1724 PSOCKET_INFORMATION Socket,
1725 ULONG AfdInformationClass,
1726 PULONG Ulong OPTIONAL,
1727 PLARGE_INTEGER LargeInteger OPTIONAL)
1728 {
1729 IO_STATUS_BLOCK IOSB;
1730 AFD_INFO InfoData;
1731 NTSTATUS Status;
1732 HANDLE SockEvent;
1733
1734 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1735 NULL, 1, FALSE );
1736
1737 if( !NT_SUCCESS(Status) ) return -1;
1738
1739 /* Set Info Class */
1740 InfoData.InformationClass = AfdInformationClass;
1741
1742 /* Send IOCTL */
1743 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1744 SockEvent,
1745 NULL,
1746 NULL,
1747 &IOSB,
1748 IOCTL_AFD_GET_INFO,
1749 &InfoData,
1750 sizeof(InfoData),
1751 &InfoData,
1752 sizeof(InfoData));
1753
1754 /* Wait for return */
1755 if (Status == STATUS_PENDING) {
1756 WaitForSingleObject(SockEvent, INFINITE);
1757 }
1758
1759 /* Return Information */
1760 *Ulong = InfoData.Information.Ulong;
1761 if (LargeInteger != NULL) {
1762 *LargeInteger = InfoData.Information.LargeInteger;
1763 }
1764
1765 NtClose( SockEvent );
1766
1767 return 0;
1768
1769 }
1770
1771
1772 int
1773 SetSocketInformation(
1774 PSOCKET_INFORMATION Socket,
1775 ULONG AfdInformationClass,
1776 PULONG Ulong OPTIONAL,
1777 PLARGE_INTEGER LargeInteger OPTIONAL)
1778 {
1779 IO_STATUS_BLOCK IOSB;
1780 AFD_INFO InfoData;
1781 NTSTATUS Status;
1782 HANDLE SockEvent;
1783
1784 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1785 NULL, 1, FALSE );
1786
1787 if( !NT_SUCCESS(Status) ) return -1;
1788
1789 /* Set Info Class */
1790 InfoData.InformationClass = AfdInformationClass;
1791
1792 /* Set Information */
1793 InfoData.Information.Ulong = *Ulong;
1794 if (LargeInteger != NULL) {
1795 InfoData.Information.LargeInteger = *LargeInteger;
1796 }
1797
1798 AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
1799 AfdInformationClass, *Ulong));
1800
1801 /* Send IOCTL */
1802 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1803 SockEvent,
1804 NULL,
1805 NULL,
1806 &IOSB,
1807 IOCTL_AFD_GET_INFO,
1808 &InfoData,
1809 sizeof(InfoData),
1810 NULL,
1811 0);
1812
1813 /* Wait for return */
1814 if (Status == STATUS_PENDING) {
1815 WaitForSingleObject(SockEvent, INFINITE);
1816 }
1817
1818 NtClose( SockEvent );
1819
1820 return 0;
1821
1822 }
1823
1824 PSOCKET_INFORMATION
1825 GetSocketStructure(
1826 SOCKET Handle)
1827 {
1828 ULONG i;
1829
1830 for (i=0; i<SocketCount; i++) {
1831 if (Sockets[i]->Handle == Handle) {
1832 return Sockets[i];
1833 }
1834 }
1835 return 0;
1836 }
1837
1838 int CreateContext(PSOCKET_INFORMATION Socket)
1839 {
1840 IO_STATUS_BLOCK IOSB;
1841 SOCKET_CONTEXT ContextData;
1842 NTSTATUS Status;
1843 HANDLE SockEvent;
1844
1845 Status = NtCreateEvent( &SockEvent, GENERIC_READ | GENERIC_WRITE,
1846 NULL, 1, FALSE );
1847
1848 if( !NT_SUCCESS(Status) ) return -1;
1849
1850 /* Create Context */
1851 ContextData.SharedData = Socket->SharedData;
1852 ContextData.SizeOfHelperData = 0;
1853 RtlCopyMemory (&ContextData.LocalAddress,
1854 Socket->LocalAddress,
1855 Socket->SharedData.SizeOfLocalAddress);
1856 RtlCopyMemory (&ContextData.RemoteAddress,
1857 Socket->RemoteAddress,
1858 Socket->SharedData.SizeOfRemoteAddress);
1859
1860 /* Send IOCTL */
1861 Status = NtDeviceIoControlFile( (HANDLE)Socket->Handle,
1862 SockEvent,
1863 NULL,
1864 NULL,
1865 &IOSB,
1866 IOCTL_AFD_SET_CONTEXT,
1867 &ContextData,
1868 sizeof(ContextData),
1869 NULL,
1870 0);
1871
1872 /* Wait for Completition */
1873 if (Status == STATUS_PENDING) {
1874 WaitForSingleObject(SockEvent, INFINITE);
1875 }
1876
1877 NtClose( SockEvent );
1878
1879 return 0;
1880 }
1881
1882 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
1883 {
1884 HANDLE hAsyncThread;
1885 DWORD AsyncThreadId;
1886 HANDLE AsyncEvent;
1887 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
1888 NTSTATUS Status;
1889
1890 /* Check if the Thread Already Exists */
1891 if (SockAsyncThreadRefCount) {
1892 return TRUE;
1893 }
1894
1895 /* Create the Completion Port */
1896 if (!SockAsyncCompletionPort) {
1897 Status = NtCreateIoCompletion(&SockAsyncCompletionPort,
1898 IO_COMPLETION_ALL_ACCESS,
1899 NULL,
1900 2); // Allow 2 threads only
1901
1902 /* Protect Handle */
1903 HandleFlags.ProtectFromClose = TRUE;
1904 HandleFlags.Inherit = FALSE;
1905 Status = NtSetInformationObject(SockAsyncCompletionPort,
1906 ObjectHandleFlagInformation,
1907 &HandleFlags,
1908 sizeof(HandleFlags));
1909 }
1910
1911 /* Create the Async Event */
1912 Status = NtCreateEvent(&AsyncEvent,
1913 EVENT_ALL_ACCESS,
1914 NULL,
1915 NotificationEvent,
1916 FALSE);
1917
1918 /* Create the Async Thread */
1919 hAsyncThread = CreateThread(NULL,
1920 0,
1921 (LPTHREAD_START_ROUTINE)SockAsyncThread,
1922 NULL,
1923 0,
1924 &AsyncThreadId);
1925
1926 /* Close the Handle */
1927 NtClose(hAsyncThread);
1928
1929 /* Increase the Reference Count */
1930 SockAsyncThreadRefCount++;
1931 return TRUE;
1932 }
1933
1934 int SockAsyncThread(PVOID ThreadParam)
1935 {
1936 PVOID AsyncContext;
1937 PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
1938 IO_STATUS_BLOCK IOSB;
1939 NTSTATUS Status;
1940
1941 /* Make the Thread Higher Priority */
1942 SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
1943
1944 /* Do a KQUEUE/WorkItem Style Loop, thanks to IoCompletion Ports */
1945 do {
1946 Status = NtRemoveIoCompletion (SockAsyncCompletionPort,
1947 (PVOID*)&AsyncCompletionRoutine,
1948 &AsyncContext,
1949 &IOSB,
1950 NULL);
1951
1952 /* Call the Async Function */
1953 if (NT_SUCCESS(Status)) {
1954 (*AsyncCompletionRoutine)(AsyncContext, &IOSB);
1955 } else {
1956 /* It Failed, sleep for a second */
1957 Sleep(1000);
1958 }
1959 } while ((Status != STATUS_TIMEOUT));
1960
1961 /* The Thread has Ended */
1962 return 0;
1963 }
1964
1965 BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
1966 {
1967 UNICODE_STRING AfdHelper;
1968 OBJECT_ATTRIBUTES ObjectAttributes;
1969 IO_STATUS_BLOCK IoSb;
1970 NTSTATUS Status;
1971 FILE_COMPLETION_INFORMATION CompletionInfo;
1972 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
1973
1974 /* First, make sure we're not already intialized */
1975 if (SockAsyncHelperAfdHandle) {
1976 return TRUE;
1977 }
1978
1979 /* Set up Handle Name and Object */
1980 RtlInitUnicodeString(&AfdHelper, L"\\Device\\Afd\\AsyncSelectHlp" );
1981 InitializeObjectAttributes(&ObjectAttributes,
1982 &AfdHelper,
1983 OBJ_INHERIT | OBJ_CASE_INSENSITIVE,
1984 NULL,
1985 NULL);
1986
1987 /* Open the Handle to AFD */
1988 Status = NtCreateFile(&SockAsyncHelperAfdHandle,
1989 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
1990 &ObjectAttributes,
1991 &IoSb,
1992 NULL,
1993 0,
1994 FILE_SHARE_READ | FILE_SHARE_WRITE,
1995 FILE_OPEN_IF,
1996 0,
1997 NULL,
1998 0);
1999
2000 /*
2001 * Now Set up the Completion Port Information
2002 * This means that whenever a Poll is finished, the routine will be executed
2003 */
2004 CompletionInfo.Port = SockAsyncCompletionPort;
2005 CompletionInfo.Key = SockAsyncSelectCompletionRoutine;
2006 Status = NtSetInformationFile(SockAsyncHelperAfdHandle,
2007 &IoSb,
2008 &CompletionInfo,
2009 sizeof(CompletionInfo),
2010 FileCompletionInformation);
2011
2012
2013 /* Protect the Handle */
2014 HandleFlags.ProtectFromClose = TRUE;
2015 HandleFlags.Inherit = FALSE;
2016 Status = NtSetInformationObject(SockAsyncCompletionPort,
2017 ObjectHandleFlagInformation,
2018 &HandleFlags,
2019 sizeof(HandleFlags));
2020
2021
2022 /* Set this variable to true so that Send/Recv/Accept will know wether to renable disabled events */
2023 SockAsyncSelectCalled = TRUE;
2024 return TRUE;
2025 }
2026
2027 VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2028 {
2029
2030 PASYNC_DATA AsyncData = Context;
2031 PSOCKET_INFORMATION Socket;
2032 ULONG x;
2033
2034 /* Get the Socket */
2035 Socket = AsyncData->ParentSocket;
2036
2037 /* Check if the Sequence Number Changed behind our back */
2038 if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber ){
2039 return;
2040 }
2041
2042 /* Check we were manually called b/c of a failure */
2043 if (!NT_SUCCESS(IoStatusBlock->Status)) {
2044 /* FIXME: Perform Upcall */
2045 return;
2046 }
2047
2048 for (x = 1; x; x<<=1) {
2049 switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x) {
2050 case AFD_EVENT_RECEIVE:
2051 if (0 != (Socket->SharedData.AsyncEvents & FD_READ) && 0 == (Socket->SharedData.AsyncDisabledEvents & FD_READ)) {
2052 /* Make the Notifcation */
2053 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2054 Socket->SharedData.wMsg,
2055 Socket->Handle,
2056 WSAMAKESELECTREPLY(FD_READ, 0));
2057 /* Disable this event until the next read(); */
2058 Socket->SharedData.AsyncDisabledEvents |= FD_READ;
2059 }
2060 break;
2061
2062 case AFD_EVENT_OOB_RECEIVE:
2063 if (0 != (Socket->SharedData.AsyncEvents & FD_OOB) && 0 == (Socket->SharedData.AsyncDisabledEvents & FD_OOB)) {
2064 /* Make the Notifcation */
2065 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2066 Socket->SharedData.wMsg,
2067 Socket->Handle,
2068 WSAMAKESELECTREPLY(FD_OOB, 0));
2069 /* Disable this event until the next read(); */
2070 Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
2071 }
2072 break;
2073
2074 case AFD_EVENT_SEND:
2075 if (0 != (Socket->SharedData.AsyncEvents & FD_WRITE) && 0 == (Socket->SharedData.AsyncDisabledEvents & FD_WRITE)) {
2076 /* Make the Notifcation */
2077 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2078 Socket->SharedData.wMsg,
2079 Socket->Handle,
2080 WSAMAKESELECTREPLY(FD_WRITE, 0));
2081 /* Disable this event until the next write(); */
2082 Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
2083 }
2084 break;
2085
2086 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2087 case AFD_EVENT_CONNECT:
2088 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) && 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT)) {
2089 /* Make the Notifcation */
2090 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2091 Socket->SharedData.wMsg,
2092 Socket->Handle,
2093 WSAMAKESELECTREPLY(FD_CONNECT, 0));
2094 /* Disable this event forever; */
2095 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT;
2096 }
2097 break;
2098
2099 case AFD_EVENT_ACCEPT:
2100 if (0 != (Socket->SharedData.AsyncEvents & FD_ACCEPT) && 0 == (Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT)) {
2101 /* Make the Notifcation */
2102 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2103 Socket->SharedData.wMsg,
2104 Socket->Handle,
2105 WSAMAKESELECTREPLY(FD_ACCEPT, 0));
2106 /* Disable this event until the next accept(); */
2107 Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
2108 }
2109 break;
2110
2111 case AFD_EVENT_DISCONNECT:
2112 case AFD_EVENT_ABORT:
2113 case AFD_EVENT_CLOSE:
2114 if (0 != (Socket->SharedData.AsyncEvents & FD_CLOSE) && 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CLOSE)) {
2115 /* Make the Notifcation */
2116 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2117 Socket->SharedData.wMsg,
2118 Socket->Handle,
2119 WSAMAKESELECTREPLY(FD_CLOSE, 0));
2120 /* Disable this event forever; */
2121 Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
2122 }
2123 break;
2124
2125 /* FIXME: Support QOS */
2126 }
2127 }
2128
2129 /* Check if there are any events left for us to check */
2130 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 ) {
2131 return;
2132 }
2133
2134 /* Keep Polling */
2135 SockProcessAsyncSelect(Socket, AsyncData);
2136 return;
2137 }
2138
2139 VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
2140 {
2141
2142 ULONG lNetworkEvents;
2143 NTSTATUS Status;
2144
2145 /* Set up the Async Data Event Info */
2146 AsyncData->AsyncSelectInfo.Timeout.HighPart = 0x7FFFFFFF;
2147 AsyncData->AsyncSelectInfo.Timeout.LowPart = 0xFFFFFFFF;
2148 AsyncData->AsyncSelectInfo.HandleCount = 1;
2149 AsyncData->AsyncSelectInfo.Exclusive = TRUE;
2150 AsyncData->AsyncSelectInfo.Handles[0].Handle = Socket->Handle;
2151 AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
2152
2153 /* Remove unwanted events */
2154 lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
2155
2156 /* Set Events to wait for */
2157 if (lNetworkEvents & FD_READ) {
2158 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_RECEIVE;
2159 }
2160
2161 if (lNetworkEvents & FD_WRITE) {
2162 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_SEND;
2163 }
2164
2165 if (lNetworkEvents & FD_OOB) {
2166 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_OOB_RECEIVE;
2167 }
2168
2169 if (lNetworkEvents & FD_ACCEPT) {
2170 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_ACCEPT;
2171 }
2172
2173 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2174 if (lNetworkEvents & FD_CONNECT) {
2175 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_CONNECT | AFD_EVENT_CONNECT_FAIL;
2176 }
2177
2178 if (lNetworkEvents & FD_CLOSE) {
2179 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
2180 }
2181
2182 if (lNetworkEvents & FD_QOS) {
2183 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_QOS;
2184 }
2185
2186 if (lNetworkEvents & FD_GROUP_QOS) {
2187 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_GROUP_QOS;
2188 }
2189
2190 /* Send IOCTL */
2191 Status = NtDeviceIoControlFile (SockAsyncHelperAfdHandle,
2192 NULL,
2193 NULL,
2194 AsyncData,
2195 &AsyncData->IoStatusBlock,
2196 IOCTL_AFD_SELECT,
2197 &AsyncData->AsyncSelectInfo,
2198 sizeof(AsyncData->AsyncSelectInfo),
2199 &AsyncData->AsyncSelectInfo,
2200 sizeof(AsyncData->AsyncSelectInfo));
2201
2202 /* I/O Manager Won't call the completion routine, let's do it manually */
2203 if (NT_SUCCESS(Status)) {
2204 return;
2205 } else {
2206 AsyncData->IoStatusBlock.Status = Status;
2207 SockAsyncSelectCompletionRoutine(AsyncData, &AsyncData->IoStatusBlock);
2208 }
2209 }
2210
2211 VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2212 {
2213 PASYNC_DATA AsyncData = Context;
2214 BOOL FreeContext = TRUE;
2215 PSOCKET_INFORMATION Socket;
2216
2217 /* Get the Socket */
2218 Socket = AsyncData->ParentSocket;
2219
2220 /* If someone closed it, stop the function */
2221 if (Socket->SharedData.State != SocketClosed) {
2222 /* Check if the Sequence Number changed by now, in which case quit */
2223 if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber) {
2224 /* Do the actuall select, if needed */
2225 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents))) {
2226 SockProcessAsyncSelect(Socket, AsyncData);
2227 FreeContext = FALSE;
2228 }
2229 }
2230 }
2231
2232 /* Free the Context */
2233 if (FreeContext) {
2234 HeapFree(GetProcessHeap(), 0, AsyncData);
2235 }
2236
2237 return;
2238 }
2239
2240 VOID
2241 SockReenableAsyncSelectEvent (
2242 IN PSOCKET_INFORMATION Socket,
2243 IN ULONG Event
2244 )
2245 {
2246 PASYNC_DATA AsyncData;
2247
2248 /* Make sure the event is actually disabled */
2249 if (!(Socket->SharedData.AsyncDisabledEvents & Event)) {
2250 return;
2251 }
2252
2253 /* Re-enable it */
2254 Socket->SharedData.AsyncDisabledEvents &= ~Event;
2255
2256 /* Return if no more events are being polled */
2257 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 ) {
2258 return;
2259 }
2260
2261 /* Wait on new events */
2262 AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
2263
2264 /* Create the Asynch Thread if Needed */
2265 SockCreateOrReferenceAsyncThread();
2266
2267 /* Increase the sequence number to stop anything else */
2268 Socket->SharedData.SequenceNumber++;
2269
2270 /* Set up the Async Data */
2271 AsyncData->ParentSocket = Socket;
2272 AsyncData->SequenceNumber = Socket->SharedData.SequenceNumber;
2273
2274 /* Begin Async Select by using I/O Completion */
2275 NtSetIoCompletion(SockAsyncCompletionPort,
2276 (PVOID)&SockProcessQueuedAsyncSelect,
2277 AsyncData,
2278 0,
2279 0);
2280
2281 /* All done */
2282 return;
2283 }
2284
2285 BOOL
2286 STDCALL
2287 DllMain(HANDLE hInstDll,
2288 ULONG dwReason,
2289 PVOID Reserved)
2290 {
2291
2292 switch (dwReason) {
2293 case DLL_PROCESS_ATTACH:
2294
2295 AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
2296
2297 /* Don't need thread attach notifications
2298 so disable them to improve performance */
2299 DisableThreadLibraryCalls(hInstDll);
2300
2301 /* List of DLL Helpers */
2302 InitializeListHead(&SockHelpersListHead);
2303
2304 /* Heap to use when allocating */
2305 GlobalHeap = GetProcessHeap();
2306
2307 /* Allocate Heap for 1024 Sockets, can be expanded later */
2308 Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
2309
2310 AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
2311
2312 break;
2313
2314 case DLL_THREAD_ATTACH:
2315 break;
2316
2317 case DLL_THREAD_DETACH:
2318 break;
2319
2320 case DLL_PROCESS_DETACH:
2321 break;
2322 }
2323
2324 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
2325
2326 return TRUE;
2327 }
2328
2329 /* EOF */
2330
2331