Sync trunk r40500
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
4 * FILE: misc/dllmain.c
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7 * Alex Ionescu (alex@relsoft.net)
8 * REVISIONS:
9 * CSH 01/09-2000 Created
10 * Alex 16/07/2004 - Complete Rewrite
11 */
12
13 #include <msafd.h>
14
15 #include <debug.h>
16
17 #ifdef DBG
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
19 DWORD DebugTraceLevel = 0;
20 #endif /* DBG */
21
22 HANDLE GlobalHeap;
23 WSPUPCALLTABLE Upcalls;
24 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
25 ULONG SocketCount = 0;
26 PSOCKET_INFORMATION *Sockets = NULL;
27 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
28 ULONG SockAsyncThreadRefCount;
29 HANDLE SockAsyncHelperAfdHandle;
30 HANDLE SockAsyncCompletionPort;
31 BOOLEAN SockAsyncSelectCalled;
32
33
34
35 /*
36 * FUNCTION: Creates a new socket
37 * ARGUMENTS:
38 * af = Address family
39 * type = Socket type
40 * protocol = Protocol type
41 * lpProtocolInfo = Pointer to protocol information
42 * g = Reserved
43 * dwFlags = Socket flags
44 * lpErrno = Address of buffer for error information
45 * RETURNS:
46 * Created socket, or INVALID_SOCKET if it could not be created
47 */
48 SOCKET
49 WSPAPI
50 WSPSocket(int AddressFamily,
51 int SocketType,
52 int Protocol,
53 LPWSAPROTOCOL_INFOW lpProtocolInfo,
54 GROUP g,
55 DWORD dwFlags,
56 LPINT lpErrno)
57 {
58 OBJECT_ATTRIBUTES Object;
59 IO_STATUS_BLOCK IOSB;
60 USHORT SizeOfPacket;
61 ULONG SizeOfEA;
62 PAFD_CREATE_PACKET AfdPacket;
63 HANDLE Sock;
64 PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
65 PFILE_FULL_EA_INFORMATION EABuffer = NULL;
66 PHELPER_DATA HelperData;
67 PVOID HelperDLLContext;
68 DWORD HelperEvents;
69 UNICODE_STRING TransportName;
70 UNICODE_STRING DevName;
71 LARGE_INTEGER GroupData;
72 INT Status;
73
74 AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
75 AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
76 AddressFamily, SocketType, Protocol));
77
78 /* Get Helper Data and Transport */
79 Status = SockGetTdiName (&AddressFamily,
80 &SocketType,
81 &Protocol,
82 g,
83 dwFlags,
84 &TransportName,
85 &HelperDLLContext,
86 &HelperData,
87 &HelperEvents);
88
89 /* Check for error */
90 if (Status != NO_ERROR)
91 {
92 AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
93 goto error;
94 }
95
96 /* AFD Device Name */
97 RtlInitUnicodeString(&DevName, L"\\Device\\Afd\\Endpoint");
98
99 /* Set Socket Data */
100 Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
101 RtlZeroMemory(Socket, sizeof(*Socket));
102 Socket->RefCount = 2;
103 Socket->Handle = -1;
104 Socket->SharedData.Listening = FALSE;
105 Socket->SharedData.State = SocketOpen;
106 Socket->SharedData.AddressFamily = AddressFamily;
107 Socket->SharedData.SocketType = SocketType;
108 Socket->SharedData.Protocol = Protocol;
109 Socket->HelperContext = HelperDLLContext;
110 Socket->HelperData = HelperData;
111 Socket->HelperEvents = HelperEvents;
112 Socket->LocalAddress = &Socket->WSLocalAddress;
113 Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
114 Socket->RemoteAddress = &Socket->WSRemoteAddress;
115 Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
116 Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
117 Socket->SharedData.CreateFlags = dwFlags;
118 Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
119 Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
120 Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
121 Socket->SharedData.GroupID = g;
122 Socket->SharedData.GroupType = 0;
123 Socket->SharedData.UseSAN = FALSE;
124 Socket->SharedData.NonBlocking = FALSE; /* Sockets start blocking */
125 Socket->SanData = NULL;
126
127 /* Ask alex about this */
128 if( Socket->SharedData.SocketType == SOCK_DGRAM ||
129 Socket->SharedData.SocketType == SOCK_RAW )
130 {
131 AFD_DbgPrint(MID_TRACE,("Connectionless socket\n"));
132 Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
133 }
134
135 /* Packet Size */
136 SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
137
138 /* EA Size */
139 SizeOfEA = SizeOfPacket + sizeof(FILE_FULL_EA_INFORMATION) + AFD_PACKET_COMMAND_LENGTH;
140
141 /* Set up EA Buffer */
142 EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
143 RtlZeroMemory(EABuffer, SizeOfEA);
144 EABuffer->NextEntryOffset = 0;
145 EABuffer->Flags = 0;
146 EABuffer->EaNameLength = AFD_PACKET_COMMAND_LENGTH;
147 RtlCopyMemory (EABuffer->EaName,
148 AfdCommand,
149 AFD_PACKET_COMMAND_LENGTH + 1);
150 EABuffer->EaValueLength = SizeOfPacket;
151
152 /* Set up AFD Packet */
153 AfdPacket = (PAFD_CREATE_PACKET)(EABuffer->EaName + EABuffer->EaNameLength + 1);
154 AfdPacket->SizeOfTransportName = TransportName.Length;
155 RtlCopyMemory (AfdPacket->TransportName,
156 TransportName.Buffer,
157 TransportName.Length + sizeof(WCHAR));
158 AfdPacket->GroupID = g;
159
160 /* Set up Endpoint Flags */
161 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0)
162 {
163 if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW))
164 {
165 /* Only RAW or UDP can be Connectionless */
166 goto error;
167 }
168 AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
169 }
170
171 if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0)
172 {
173 if (SocketType == SOCK_STREAM)
174 {
175 if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
176 {
177 /* The Provider doesn't actually support Message Oriented Streams */
178 goto error;
179 }
180 }
181 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MESSAGE_ORIENTED;
182 }
183
184 if (SocketType == SOCK_RAW) AfdPacket->EndpointFlags |= AFD_ENDPOINT_RAW;
185
186 if (dwFlags & (WSA_FLAG_MULTIPOINT_C_ROOT |
187 WSA_FLAG_MULTIPOINT_C_LEAF |
188 WSA_FLAG_MULTIPOINT_D_ROOT |
189 WSA_FLAG_MULTIPOINT_D_LEAF))
190 {
191 if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
192 {
193 /* The Provider doesn't actually support Multipoint */
194 goto error;
195 }
196 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
197
198 if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT)
199 {
200 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
201 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0))
202 {
203 /* The Provider doesn't support Control Planes, or you already gave a leaf */
204 goto error;
205 }
206 AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
207 }
208
209 if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT)
210 {
211 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
212 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0))
213 {
214 /* The Provider doesn't support Data Planes, or you already gave a leaf */
215 goto error;
216 }
217 AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
218 }
219 }
220
221 /* Set up Object Attributes */
222 InitializeObjectAttributes (&Object,
223 &DevName,
224 OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
225 0,
226 0);
227
228 /* Create the Socket as asynchronous. That means we have to block
229 ourselves after every call to NtDeviceIoControlFile. This is
230 because the kernel doesn't support overlapping synchronous I/O
231 requests (made from multiple threads) at this time (Sep 2005) */
232 ZwCreateFile(&Sock,
233 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
234 &Object,
235 &IOSB,
236 NULL,
237 0,
238 FILE_SHARE_READ | FILE_SHARE_WRITE, FILE_OPEN_IF,
239 0,
240 EABuffer,
241 SizeOfEA);
242
243 /* Save Handle */
244 Socket->Handle = (SOCKET)Sock;
245
246 /* XXX See if there's a structure we can reuse -- We need to do this
247 * more properly. */
248 PrevSocket = GetSocketStructure( (SOCKET)Sock );
249
250 if( PrevSocket )
251 {
252 RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
253 RtlFreeHeap( GlobalHeap, 0, Socket );
254 Socket = PrevSocket;
255 }
256
257 /* Save Group Info */
258 if (g != 0)
259 {
260 GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
261 Socket->SharedData.GroupID = GroupData.u.LowPart;
262 Socket->SharedData.GroupType = GroupData.u.HighPart;
263 }
264
265 /* Get Window Sizes and Save them */
266 GetSocketInformation (Socket,
267 AFD_INFO_SEND_WINDOW_SIZE,
268 &Socket->SharedData.SizeOfSendBuffer,
269 NULL);
270
271 GetSocketInformation (Socket,
272 AFD_INFO_RECEIVE_WINDOW_SIZE,
273 &Socket->SharedData.SizeOfRecvBuffer,
274 NULL);
275
276 /* Save in Process Sockets List */
277 Sockets[SocketCount] = Socket;
278 SocketCount ++;
279
280 /* Create the Socket Context */
281 CreateContext(Socket);
282
283 /* Notify Winsock */
284 Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
285
286 /* Return Socket Handle */
287 AFD_DbgPrint(MID_TRACE,("Success %x\n", Sock));
288
289 return (SOCKET)Sock;
290
291 error:
292 AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
293
294 if( lpErrno )
295 *lpErrno = Status;
296
297 return INVALID_SOCKET;
298 }
299
300
301 DWORD MsafdReturnWithErrno(NTSTATUS Status,
302 LPINT Errno,
303 DWORD Received,
304 LPDWORD ReturnedBytes)
305 {
306 if( ReturnedBytes )
307 *ReturnedBytes = 0;
308 if( Errno )
309 {
310 switch (Status)
311 {
312 case STATUS_CANT_WAIT:
313 *Errno = WSAEWOULDBLOCK;
314 break;
315 case STATUS_TIMEOUT:
316 *Errno = WSAETIMEDOUT;
317 break;
318 case STATUS_SUCCESS:
319 /* Return Number of bytes Read */
320 if( ReturnedBytes )
321 *ReturnedBytes = Received;
322 break;
323 case STATUS_END_OF_FILE:
324 *Errno = WSAESHUTDOWN;
325 break;
326 case STATUS_PENDING:
327 *Errno = WSA_IO_PENDING;
328 break;
329 case STATUS_BUFFER_TOO_SMALL:
330 case STATUS_BUFFER_OVERFLOW:
331 DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
332 *Errno = WSAEMSGSIZE;
333 break;
334 case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
335 case STATUS_INSUFFICIENT_RESOURCES:
336 DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
337 *Errno = WSA_NOT_ENOUGH_MEMORY;
338 break;
339 case STATUS_INVALID_CONNECTION:
340 DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
341 *Errno = WSAEAFNOSUPPORT;
342 break;
343 case STATUS_REMOTE_NOT_LISTENING:
344 DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
345 *Errno = WSAECONNRESET;
346 break;
347 case STATUS_FILE_CLOSED:
348 DbgPrint("MSAFD: STATUS_FILE_CLOSED\n");
349 *Errno = WSAENOTSOCK;
350 break;
351 case STATUS_INVALID_PARAMETER:
352 DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
353 *Errno = WSAEINVAL;
354 break;
355 case STATUS_CANCELLED:
356 DbgPrint("MSAFD: STATUS_CANCELLED\n");
357 *Errno = WSAENOTSOCK;
358 break;
359 default:
360 DbgPrint("MSAFD: Error %x is unknown\n", Status);
361 *Errno = WSAEINVAL;
362 break;
363 }
364 }
365
366 /* Success */
367 return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
368 }
369
370 /*
371 * FUNCTION: Closes an open socket
372 * ARGUMENTS:
373 * s = Socket descriptor
374 * lpErrno = Address of buffer for error information
375 * RETURNS:
376 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
377 */
378 INT
379 WSPAPI
380 WSPCloseSocket(IN SOCKET Handle,
381 OUT LPINT lpErrno)
382 {
383 IO_STATUS_BLOCK IoStatusBlock;
384 PSOCKET_INFORMATION Socket = NULL;
385 NTSTATUS Status;
386 HANDLE SockEvent;
387 AFD_DISCONNECT_INFO DisconnectInfo;
388 SOCKET_STATE OldState;
389
390 /* Create the Wait Event */
391 Status = NtCreateEvent(&SockEvent,
392 GENERIC_READ | GENERIC_WRITE,
393 NULL,
394 1,
395 FALSE);
396
397 if(!NT_SUCCESS(Status))
398 return SOCKET_ERROR;
399
400 /* Get the Socket Structure associate to this Socket*/
401 Socket = GetSocketStructure(Handle);
402
403 /* If a Close is already in Process, give up */
404 if (Socket->SharedData.State == SocketClosed)
405 {
406 NtClose(SockEvent);
407 *lpErrno = WSAENOTSOCK;
408 return SOCKET_ERROR;
409 }
410
411 /* Set the state to close */
412 OldState = Socket->SharedData.State;
413 Socket->SharedData.State = SocketClosed;
414
415 /* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
416 /* FIXME: Should we do this on Datagram Sockets too? */
417 if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
418 {
419 ULONG LingerWait;
420 ULONG SendsInProgress;
421 ULONG SleepWait;
422
423 /* We need to respect the timeout */
424 SleepWait = 100;
425 LingerWait = Socket->SharedData.LingerData.l_linger * 1000;
426
427 /* Loop until no more sends are pending, within the timeout */
428 while (LingerWait)
429 {
430 /* Find out how many Sends are in Progress */
431 if (GetSocketInformation(Socket,
432 AFD_INFO_SENDS_IN_PROGRESS,
433 &SendsInProgress,
434 NULL))
435 {
436 /* Bail out if anything but NO_ERROR */
437 LingerWait = 0;
438 break;
439 }
440
441 /* Bail out if no more sends are pending */
442 if (!SendsInProgress)
443 break;
444 /*
445 * We have to execute a sleep, so it's kind of like
446 * a block. If the socket is Nonblock, we cannot
447 * go on since asyncronous operation is expected
448 * and we cannot offer it
449 */
450 if (Socket->SharedData.NonBlocking)
451 {
452 NtClose(SockEvent);
453 Socket->SharedData.State = OldState;
454 *lpErrno = WSAEWOULDBLOCK;
455 return SOCKET_ERROR;
456 }
457
458 /* Now we can sleep, and decrement the linger wait */
459 /*
460 * FIXME: It seems Windows does some funky acceleration
461 * since the waiting seems to be longer and longer. I
462 * don't think this improves performance so much, so we
463 * wait a fixed time instead.
464 */
465 Sleep(SleepWait);
466 LingerWait -= SleepWait;
467 }
468
469 /*
470 * We have reached the timeout or sends are over.
471 * Disconnect if the timeout has been reached.
472 */
473 if (LingerWait <= 0)
474 {
475 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
476 DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
477
478 /* Send IOCTL */
479 Status = NtDeviceIoControlFile((HANDLE)Handle,
480 SockEvent,
481 NULL,
482 NULL,
483 &IoStatusBlock,
484 IOCTL_AFD_DISCONNECT,
485 &DisconnectInfo,
486 sizeof(DisconnectInfo),
487 NULL,
488 0);
489
490 /* Wait for return */
491 if (Status == STATUS_PENDING)
492 {
493 WaitForSingleObject(SockEvent, INFINITE);
494 }
495 }
496 }
497
498 /* FIXME: We should notify the Helper DLL of WSH_NOTIFY_CLOSE */
499
500 /* Cleanup Time! */
501 Socket->HelperContext = NULL;
502 Socket->SharedData.AsyncDisabledEvents = -1;
503 NtClose(Socket->TdiAddressHandle);
504 Socket->TdiAddressHandle = NULL;
505 NtClose(Socket->TdiConnectionHandle);
506 Socket->TdiConnectionHandle = NULL;
507
508 /* Close the handle */
509 NtClose((HANDLE)Handle);
510 NtClose(SockEvent);
511
512 return NO_ERROR;
513 }
514
515
516 /*
517 * FUNCTION: Associates a local address with a socket
518 * ARGUMENTS:
519 * s = Socket descriptor
520 * name = Pointer to local address
521 * namelen = Length of name
522 * lpErrno = Address of buffer for error information
523 * RETURNS:
524 * 0, or SOCKET_ERROR if the socket could not be bound
525 */
526 INT
527 WSPAPI
528 WSPBind(SOCKET Handle,
529 const struct sockaddr *SocketAddress,
530 int SocketAddressLength,
531 LPINT lpErrno)
532 {
533 IO_STATUS_BLOCK IOSB;
534 PAFD_BIND_DATA BindData;
535 PSOCKET_INFORMATION Socket = NULL;
536 NTSTATUS Status;
537 UCHAR BindBuffer[0x1A];
538 SOCKADDR_INFO SocketInfo;
539 HANDLE SockEvent;
540
541 Status = NtCreateEvent(&SockEvent,
542 GENERIC_READ | GENERIC_WRITE,
543 NULL,
544 1,
545 FALSE);
546
547 if( !NT_SUCCESS(Status) )
548 return -1;
549
550 /* Get the Socket Structure associate to this Socket*/
551 Socket = GetSocketStructure(Handle);
552
553 /* Dynamic Structure...ugh */
554 BindData = (PAFD_BIND_DATA)BindBuffer;
555
556 /* Set up Address in TDI Format */
557 BindData->Address.TAAddressCount = 1;
558 BindData->Address.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
559 BindData->Address.Address[0].AddressType = SocketAddress->sa_family;
560 RtlCopyMemory (BindData->Address.Address[0].Address,
561 SocketAddress->sa_data,
562 SocketAddressLength - sizeof(SocketAddress->sa_family));
563
564 /* Get Address Information */
565 Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
566 SocketAddressLength,
567 &SocketInfo);
568
569 /* Set the Share Type */
570 if (Socket->SharedData.ExclusiveAddressUse)
571 {
572 BindData->ShareType = AFD_SHARE_EXCLUSIVE;
573 }
574 else if (SocketInfo.EndpointInfo == SockaddrEndpointInfoWildcard)
575 {
576 BindData->ShareType = AFD_SHARE_WILDCARD;
577 }
578 else if (Socket->SharedData.ReuseAddresses)
579 {
580 BindData->ShareType = AFD_SHARE_REUSE;
581 }
582 else
583 {
584 BindData->ShareType = AFD_SHARE_UNIQUE;
585 }
586
587 /* Send IOCTL */
588 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
589 SockEvent,
590 NULL,
591 NULL,
592 &IOSB,
593 IOCTL_AFD_BIND,
594 BindData,
595 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
596 BindData,
597 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
598
599 /* Wait for return */
600 if (Status == STATUS_PENDING)
601 {
602 WaitForSingleObject(SockEvent, INFINITE);
603 Status = IOSB.Status;
604 }
605
606 /* Set up Socket Data */
607 Socket->SharedData.State = SocketBound;
608 Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
609
610 NtClose( SockEvent );
611
612 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
613 }
614
615 int
616 WSPAPI
617 WSPListen(SOCKET Handle,
618 int Backlog,
619 LPINT lpErrno)
620 {
621 IO_STATUS_BLOCK IOSB;
622 AFD_LISTEN_DATA ListenData;
623 PSOCKET_INFORMATION Socket = NULL;
624 HANDLE SockEvent;
625 NTSTATUS Status;
626
627 /* Get the Socket Structure associate to this Socket*/
628 Socket = GetSocketStructure(Handle);
629
630 if (Socket->SharedData.Listening)
631 return 0;
632
633 Status = NtCreateEvent(&SockEvent,
634 GENERIC_READ | GENERIC_WRITE,
635 NULL,
636 1,
637 FALSE);
638
639 if( !NT_SUCCESS(Status) )
640 return -1;
641
642 /* Set Up Listen Structure */
643 ListenData.UseSAN = FALSE;
644 ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
645 ListenData.Backlog = Backlog;
646
647 /* Send IOCTL */
648 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
649 SockEvent,
650 NULL,
651 NULL,
652 &IOSB,
653 IOCTL_AFD_START_LISTEN,
654 &ListenData,
655 sizeof(ListenData),
656 NULL,
657 0);
658
659 /* Wait for return */
660 if (Status == STATUS_PENDING)
661 {
662 WaitForSingleObject(SockEvent, INFINITE);
663 Status = IOSB.Status;
664 }
665
666 /* Set to Listening */
667 Socket->SharedData.Listening = TRUE;
668
669 NtClose( SockEvent );
670
671 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
672 }
673
674
675 int
676 WSPAPI
677 WSPSelect(int nfds,
678 fd_set *readfds,
679 fd_set *writefds,
680 fd_set *exceptfds,
681 struct timeval *timeout,
682 LPINT lpErrno)
683 {
684 IO_STATUS_BLOCK IOSB;
685 PAFD_POLL_INFO PollInfo;
686 NTSTATUS Status;
687 LONG HandleCount, OutCount = 0;
688 ULONG PollBufferSize;
689 PVOID PollBuffer;
690 ULONG i, j = 0, x;
691 HANDLE SockEvent;
692 BOOL HandleCounted;
693 LARGE_INTEGER Timeout;
694
695 /* Find out how many sockets we have, and how large the buffer needs
696 * to be */
697
698 HandleCount = ( readfds ? readfds->fd_count : 0 ) +
699 ( writefds ? writefds->fd_count : 0 ) +
700 ( exceptfds ? exceptfds->fd_count : 0 );
701
702 if( HandleCount < 0 || nfds != 0 )
703 HandleCount = nfds * 3;
704
705 PollBufferSize = sizeof(*PollInfo) + (HandleCount * sizeof(AFD_HANDLE));
706
707 AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n",
708 HandleCount, PollBufferSize));
709
710 /* Convert Timeout to NT Format */
711 if (timeout == NULL)
712 {
713 Timeout.u.LowPart = -1;
714 Timeout.u.HighPart = 0x7FFFFFFF;
715 AFD_DbgPrint(MAX_TRACE,("Infinite timeout\n"));
716 }
717 else
718 {
719 Timeout = RtlEnlargedIntegerMultiply
720 ((timeout->tv_sec * 1000) + (timeout->tv_usec / 1000), -10000);
721 /* Negative timeouts are illegal. Since the kernel represents an
722 * incremental timeout as a negative number, we check for a positive
723 * result.
724 */
725 if (Timeout.QuadPart > 0)
726 {
727 if (lpErrno) *lpErrno = WSAEINVAL;
728 return SOCKET_ERROR;
729 }
730 AFD_DbgPrint(MAX_TRACE,("Timeout: Orig %d.%06d kernel %d\n",
731 timeout->tv_sec, timeout->tv_usec,
732 Timeout.u.LowPart));
733 }
734
735 Status = NtCreateEvent(&SockEvent,
736 GENERIC_READ | GENERIC_WRITE,
737 NULL,
738 1,
739 FALSE);
740
741 if( !NT_SUCCESS(Status) )
742 return SOCKET_ERROR;
743
744 /* Allocate */
745 PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
746
747 if (!PollBuffer)
748 {
749 if (lpErrno)
750 *lpErrno = WSAEFAULT;
751 NtClose(SockEvent);
752 return SOCKET_ERROR;
753 }
754
755 PollInfo = (PAFD_POLL_INFO)PollBuffer;
756
757 RtlZeroMemory( PollInfo, PollBufferSize );
758
759 /* Number of handles for AFD to Check */
760 PollInfo->Exclusive = FALSE;
761 PollInfo->Timeout = Timeout;
762
763 if (readfds != NULL) {
764 for (i = 0; i < readfds->fd_count; i++, j++)
765 {
766 PollInfo->Handles[j].Handle = readfds->fd_array[i];
767 PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
768 AFD_EVENT_DISCONNECT |
769 AFD_EVENT_ABORT |
770 AFD_EVENT_ACCEPT;
771 }
772 }
773 if (writefds != NULL)
774 {
775 for (i = 0; i < writefds->fd_count; i++, j++)
776 {
777 PollInfo->Handles[j].Handle = writefds->fd_array[i];
778 PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
779 }
780 }
781 if (exceptfds != NULL)
782 {
783 for (i = 0; i < exceptfds->fd_count; i++, j++)
784 {
785 PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
786 PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
787 }
788 }
789
790 PollInfo->HandleCount = j;
791 PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
792
793 /* Send IOCTL */
794 Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
795 SockEvent,
796 NULL,
797 NULL,
798 &IOSB,
799 IOCTL_AFD_SELECT,
800 PollInfo,
801 PollBufferSize,
802 PollInfo,
803 PollBufferSize);
804
805 AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
806
807 /* Wait for Completition */
808 if (Status == STATUS_PENDING)
809 {
810 WaitForSingleObject(SockEvent, INFINITE);
811 }
812
813 /* Clear the Structures */
814 if( readfds )
815 FD_ZERO(readfds);
816 if( writefds )
817 FD_ZERO(writefds);
818 if( exceptfds )
819 FD_ZERO(exceptfds);
820
821 /* Loop through return structure */
822 HandleCount = PollInfo->HandleCount;
823
824 /* Return in FDSET Format */
825 for (i = 0; i < HandleCount; i++)
826 {
827 HandleCounted = FALSE;
828 for(x = 1; x; x<<=1)
829 {
830 switch (PollInfo->Handles[i].Events & x)
831 {
832 case AFD_EVENT_RECEIVE:
833 case AFD_EVENT_DISCONNECT:
834 case AFD_EVENT_ABORT:
835 case AFD_EVENT_ACCEPT:
836 case AFD_EVENT_CLOSE:
837 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
838 PollInfo->Handles[i].Events,
839 PollInfo->Handles[i].Handle));
840 if (! HandleCounted)
841 {
842 OutCount++;
843 HandleCounted = TRUE;
844 }
845 if( readfds )
846 FD_SET(PollInfo->Handles[i].Handle, readfds);
847 break;
848 case AFD_EVENT_SEND:
849 case AFD_EVENT_CONNECT:
850 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
851 PollInfo->Handles[i].Events,
852 PollInfo->Handles[i].Handle));
853 if (! HandleCounted)
854 {
855 OutCount++;
856 HandleCounted = TRUE;
857 }
858 if( writefds )
859 FD_SET(PollInfo->Handles[i].Handle, writefds);
860 break;
861 case AFD_EVENT_OOB_RECEIVE:
862 case AFD_EVENT_CONNECT_FAIL:
863 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
864 PollInfo->Handles[i].Events,
865 PollInfo->Handles[i].Handle));
866 if (! HandleCounted)
867 {
868 OutCount++;
869 HandleCounted = TRUE;
870 }
871 if( exceptfds )
872 FD_SET(PollInfo->Handles[i].Handle, exceptfds);
873 break;
874 }
875 }
876 }
877
878 HeapFree( GlobalHeap, 0, PollBuffer );
879 NtClose( SockEvent );
880
881 if( lpErrno )
882 {
883 switch( IOSB.Status )
884 {
885 case STATUS_SUCCESS:
886 case STATUS_TIMEOUT:
887 *lpErrno = 0;
888 break;
889 default:
890 *lpErrno = WSAEINVAL;
891 break;
892 }
893 AFD_DbgPrint(MID_TRACE,("*lpErrno = %x\n", *lpErrno));
894 }
895
896 AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
897
898 return OutCount;
899 }
900
901 SOCKET
902 WSPAPI
903 WSPAccept(SOCKET Handle,
904 struct sockaddr *SocketAddress,
905 int *SocketAddressLength,
906 LPCONDITIONPROC lpfnCondition,
907 DWORD_PTR dwCallbackData,
908 LPINT lpErrno)
909 {
910 IO_STATUS_BLOCK IOSB;
911 PAFD_RECEIVED_ACCEPT_DATA ListenReceiveData;
912 AFD_ACCEPT_DATA AcceptData;
913 AFD_DEFER_ACCEPT_DATA DeferData;
914 AFD_PENDING_ACCEPT_DATA PendingAcceptData;
915 PSOCKET_INFORMATION Socket = NULL;
916 NTSTATUS Status;
917 struct fd_set ReadSet;
918 struct timeval Timeout;
919 PVOID PendingData = NULL;
920 ULONG PendingDataLength = 0;
921 PVOID CalleeDataBuffer;
922 WSABUF CallerData, CalleeID, CallerID, CalleeData;
923 PSOCKADDR RemoteAddress = NULL;
924 GROUP GroupID = 0;
925 ULONG CallBack;
926 WSAPROTOCOL_INFOW ProtocolInfo;
927 SOCKET AcceptSocket;
928 UCHAR ReceiveBuffer[0x1A];
929 HANDLE SockEvent;
930
931 Status = NtCreateEvent(&SockEvent,
932 GENERIC_READ | GENERIC_WRITE,
933 NULL,
934 1,
935 FALSE);
936
937 if( !NT_SUCCESS(Status) )
938 {
939 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
940 return INVALID_SOCKET;
941 }
942
943 /* Dynamic Structure...ugh */
944 ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
945
946 /* Get the Socket Structure associate to this Socket*/
947 Socket = GetSocketStructure(Handle);
948
949 /* If this is non-blocking, make sure there's something for us to accept */
950 FD_ZERO(&ReadSet);
951 FD_SET(Socket->Handle, &ReadSet);
952 Timeout.tv_sec=0;
953 Timeout.tv_usec=0;
954
955 WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
956
957 if (ReadSet.fd_array[0] != Socket->Handle)
958 {
959 NtClose(SockEvent);
960 return 0;
961 }
962
963 /* Send IOCTL */
964 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
965 SockEvent,
966 NULL,
967 NULL,
968 &IOSB,
969 IOCTL_AFD_WAIT_FOR_LISTEN,
970 NULL,
971 0,
972 ListenReceiveData,
973 0xA + sizeof(*ListenReceiveData));
974
975 /* Wait for return */
976 if (Status == STATUS_PENDING)
977 {
978 WaitForSingleObject(SockEvent, INFINITE);
979 Status = IOSB.Status;
980 }
981
982 if (!NT_SUCCESS(Status))
983 {
984 NtClose( SockEvent );
985 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
986 return INVALID_SOCKET;
987 }
988
989 if (lpfnCondition != NULL)
990 {
991 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0)
992 {
993 /* Find out how much data is pending */
994 PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
995 PendingAcceptData.ReturnSize = TRUE;
996
997 /* Send IOCTL */
998 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
999 SockEvent,
1000 NULL,
1001 NULL,
1002 &IOSB,
1003 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1004 &PendingAcceptData,
1005 sizeof(PendingAcceptData),
1006 &PendingAcceptData,
1007 sizeof(PendingAcceptData));
1008
1009 /* Wait for return */
1010 if (Status == STATUS_PENDING)
1011 {
1012 WaitForSingleObject(SockEvent, INFINITE);
1013 Status = IOSB.Status;
1014 }
1015
1016 if (!NT_SUCCESS(Status))
1017 {
1018 NtClose( SockEvent );
1019 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1020 return INVALID_SOCKET;
1021 }
1022
1023 /* How much data to allocate */
1024 PendingDataLength = IOSB.Information;
1025
1026 if (PendingDataLength)
1027 {
1028 /* Allocate needed space */
1029 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
1030
1031 /* We want the data now */
1032 PendingAcceptData.ReturnSize = FALSE;
1033
1034 /* Send IOCTL */
1035 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1036 SockEvent,
1037 NULL,
1038 NULL,
1039 &IOSB,
1040 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1041 &PendingAcceptData,
1042 sizeof(PendingAcceptData),
1043 PendingData,
1044 PendingDataLength);
1045
1046 /* Wait for return */
1047 if (Status == STATUS_PENDING)
1048 {
1049 WaitForSingleObject(SockEvent, INFINITE);
1050 Status = IOSB.Status;
1051 }
1052
1053 if (!NT_SUCCESS(Status))
1054 {
1055 NtClose( SockEvent );
1056 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1057 return INVALID_SOCKET;
1058 }
1059 }
1060 }
1061
1062 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1063 {
1064 /* I don't support this yet */
1065 }
1066
1067 /* Build Callee ID */
1068 CalleeID.buf = (PVOID)Socket->LocalAddress;
1069 CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
1070
1071 /* Set up Address in SOCKADDR Format */
1072 RtlCopyMemory (RemoteAddress,
1073 &ListenReceiveData->Address.Address[0].AddressType,
1074 sizeof(*RemoteAddress));
1075
1076 /* Build Caller ID */
1077 CallerID.buf = (PVOID)RemoteAddress;
1078 CallerID.len = sizeof(*RemoteAddress);
1079
1080 /* Build Caller Data */
1081 CallerData.buf = PendingData;
1082 CallerData.len = PendingDataLength;
1083
1084 /* Check if socket supports Conditional Accept */
1085 if (Socket->SharedData.UseDelayedAcceptance != 0)
1086 {
1087 /* Allocate Buffer for Callee Data */
1088 CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
1089 CalleeData.buf = CalleeDataBuffer;
1090 CalleeData.len = 4096;
1091 }
1092 else
1093 {
1094 /* Nothing */
1095 CalleeData.buf = 0;
1096 CalleeData.len = 0;
1097 }
1098
1099 /* Call the Condition Function */
1100 CallBack = (lpfnCondition)(&CallerID,
1101 CallerData.buf == NULL ? NULL : &CallerData,
1102 NULL,
1103 NULL,
1104 &CalleeID,
1105 CalleeData.buf == NULL ? NULL : &CalleeData,
1106 &GroupID,
1107 dwCallbackData);
1108
1109 if (((CallBack == CF_ACCEPT) && GroupID) != 0)
1110 {
1111 /* TBD: Check for Validity */
1112 }
1113
1114 if (CallBack == CF_ACCEPT)
1115 {
1116 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1117 {
1118 /* I don't support this yet */
1119 }
1120 if (CalleeData.buf)
1121 {
1122 // SockSetConnectData Sockets(SocketID), IOCTL_AFD_SET_CONNECT_DATA, CalleeData.Buffer, CalleeData.BuffSize, 0
1123 }
1124 }
1125 else
1126 {
1127 /* Callback rejected. Build Defer Structure */
1128 DeferData.SequenceNumber = ListenReceiveData->SequenceNumber;
1129 DeferData.RejectConnection = (CallBack == CF_REJECT);
1130
1131 /* Send IOCTL */
1132 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1133 SockEvent,
1134 NULL,
1135 NULL,
1136 &IOSB,
1137 IOCTL_AFD_DEFER_ACCEPT,
1138 &DeferData,
1139 sizeof(DeferData),
1140 NULL,
1141 0);
1142
1143 /* Wait for return */
1144 if (Status == STATUS_PENDING)
1145 {
1146 WaitForSingleObject(SockEvent, INFINITE);
1147 Status = IOSB.Status;
1148 }
1149
1150 NtClose( SockEvent );
1151
1152 if (!NT_SUCCESS(Status))
1153 {
1154 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1155 return INVALID_SOCKET;
1156 }
1157
1158 if (CallBack == CF_REJECT )
1159 {
1160 *lpErrno = WSAECONNREFUSED;
1161 return INVALID_SOCKET;
1162 }
1163 else
1164 {
1165 *lpErrno = WSAECONNREFUSED;
1166 return INVALID_SOCKET;
1167 }
1168 }
1169 }
1170
1171 /* Create a new Socket */
1172 ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
1173 ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
1174 ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
1175
1176 AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
1177 Socket->SharedData.SocketType,
1178 Socket->SharedData.Protocol,
1179 &ProtocolInfo,
1180 GroupID,
1181 Socket->SharedData.CreateFlags,
1182 NULL);
1183
1184 /* Set up the Accept Structure */
1185 AcceptData.ListenHandle = AcceptSocket;
1186 AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1187
1188 /* Send IOCTL to Accept */
1189 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1190 SockEvent,
1191 NULL,
1192 NULL,
1193 &IOSB,
1194 IOCTL_AFD_ACCEPT,
1195 &AcceptData,
1196 sizeof(AcceptData),
1197 NULL,
1198 0);
1199
1200 /* Wait for return */
1201 if (Status == STATUS_PENDING)
1202 {
1203 WaitForSingleObject(SockEvent, INFINITE);
1204 Status = IOSB.Status;
1205 }
1206
1207 if (!NT_SUCCESS(Status))
1208 {
1209 NtClose(SockEvent);
1210 WSPCloseSocket( AcceptSocket, lpErrno );
1211 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1212 return INVALID_SOCKET;
1213 }
1214
1215 /* Return Address in SOCKADDR FORMAT */
1216 if( SocketAddress )
1217 {
1218 RtlCopyMemory (SocketAddress,
1219 &ListenReceiveData->Address.Address[0].AddressType,
1220 sizeof(*RemoteAddress));
1221 if( SocketAddressLength )
1222 *SocketAddressLength = ListenReceiveData->Address.Address[0].AddressLength;
1223 }
1224
1225 NtClose( SockEvent );
1226
1227 /* Re-enable Async Event */
1228 SockReenableAsyncSelectEvent(Socket, FD_ACCEPT);
1229
1230 AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
1231
1232 *lpErrno = 0;
1233
1234 /* Return Socket */
1235 return AcceptSocket;
1236 }
1237
1238 int
1239 WSPAPI
1240 WSPConnect(SOCKET Handle,
1241 const struct sockaddr * SocketAddress,
1242 int SocketAddressLength,
1243 LPWSABUF lpCallerData,
1244 LPWSABUF lpCalleeData,
1245 LPQOS lpSQOS,
1246 LPQOS lpGQOS,
1247 LPINT lpErrno)
1248 {
1249 IO_STATUS_BLOCK IOSB;
1250 PAFD_CONNECT_INFO ConnectInfo;
1251 PSOCKET_INFORMATION Socket = NULL;
1252 NTSTATUS Status;
1253 UCHAR ConnectBuffer[0x22];
1254 ULONG ConnectDataLength;
1255 ULONG InConnectDataLength;
1256 INT BindAddressLength;
1257 PSOCKADDR BindAddress;
1258 HANDLE SockEvent;
1259
1260 Status = NtCreateEvent(&SockEvent,
1261 GENERIC_READ | GENERIC_WRITE,
1262 NULL,
1263 1,
1264 FALSE);
1265
1266 if( !NT_SUCCESS(Status) )
1267 return -1;
1268
1269 AFD_DbgPrint(MID_TRACE,("Called\n"));
1270
1271 /* Get the Socket Structure associate to this Socket*/
1272 Socket = GetSocketStructure(Handle);
1273
1274 /* Bind us First */
1275 if (Socket->SharedData.State == SocketOpen)
1276 {
1277 /* Get the Wildcard Address */
1278 BindAddressLength = Socket->HelperData->MaxWSAddressLength;
1279 BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
1280 Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
1281 BindAddress,
1282 &BindAddressLength);
1283 /* Bind it */
1284 WSPBind(Handle, BindAddress, BindAddressLength, NULL);
1285 }
1286
1287 /* Set the Connect Data */
1288 if (lpCallerData != NULL)
1289 {
1290 ConnectDataLength = lpCallerData->len;
1291 Status = NtDeviceIoControlFile((HANDLE)Handle,
1292 SockEvent,
1293 NULL,
1294 NULL,
1295 &IOSB,
1296 IOCTL_AFD_SET_CONNECT_DATA,
1297 lpCallerData->buf,
1298 ConnectDataLength,
1299 NULL,
1300 0);
1301 /* Wait for return */
1302 if (Status == STATUS_PENDING)
1303 {
1304 WaitForSingleObject(SockEvent, INFINITE);
1305 Status = IOSB.Status;
1306 }
1307 }
1308
1309 /* Dynamic Structure...ugh */
1310 ConnectInfo = (PAFD_CONNECT_INFO)ConnectBuffer;
1311
1312 /* Set up Address in TDI Format */
1313 ConnectInfo->RemoteAddress.TAAddressCount = 1;
1314 ConnectInfo->RemoteAddress.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
1315 ConnectInfo->RemoteAddress.Address[0].AddressType = SocketAddress->sa_family;
1316 RtlCopyMemory (ConnectInfo->RemoteAddress.Address[0].Address,
1317 SocketAddress->sa_data,
1318 SocketAddressLength - sizeof(SocketAddress->sa_family));
1319
1320 /*
1321 * Disable FD_WRITE and FD_CONNECT
1322 * The latter fixes a race condition where the FD_CONNECT is re-enabled
1323 * at the end of this function right after the Async Thread disables it.
1324 * This should only happen at the *next* WSPConnect
1325 */
1326 if (Socket->SharedData.AsyncEvents & FD_CONNECT)
1327 {
1328 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
1329 }
1330
1331 /* Tell AFD that we want Connection Data back, have it allocate a buffer */
1332 if (lpCalleeData != NULL)
1333 {
1334 InConnectDataLength = lpCalleeData->len;
1335 Status = NtDeviceIoControlFile((HANDLE)Handle,
1336 SockEvent,
1337 NULL,
1338 NULL,
1339 &IOSB,
1340 IOCTL_AFD_SET_CONNECT_DATA_SIZE,
1341 &InConnectDataLength,
1342 sizeof(InConnectDataLength),
1343 NULL,
1344 0);
1345
1346 /* Wait for return */
1347 if (Status == STATUS_PENDING)
1348 {
1349 WaitForSingleObject(SockEvent, INFINITE);
1350 Status = IOSB.Status;
1351 }
1352 }
1353
1354 /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
1355 ConnectInfo->Root = 0;
1356 ConnectInfo->UseSAN = FALSE;
1357 ConnectInfo->Unknown = 0;
1358
1359 /* FIXME: Handle Async Connect */
1360 if (Socket->SharedData.NonBlocking)
1361 {
1362 AFD_DbgPrint(MIN_TRACE, ("Async Connect UNIMPLEMENTED!\n"));
1363 }
1364
1365 /* Send IOCTL */
1366 Status = NtDeviceIoControlFile((HANDLE)Handle,
1367 SockEvent,
1368 NULL,
1369 NULL,
1370 &IOSB,
1371 IOCTL_AFD_CONNECT,
1372 ConnectInfo,
1373 0x22,
1374 NULL,
1375 0);
1376 /* Wait for return */
1377 if (Status == STATUS_PENDING)
1378 {
1379 WaitForSingleObject(SockEvent, INFINITE);
1380 Status = IOSB.Status;
1381 }
1382
1383 /* Get any pending connect data */
1384 if (lpCalleeData != NULL)
1385 {
1386 Status = NtDeviceIoControlFile((HANDLE)Handle,
1387 SockEvent,
1388 NULL,
1389 NULL,
1390 &IOSB,
1391 IOCTL_AFD_GET_CONNECT_DATA,
1392 NULL,
1393 0,
1394 lpCalleeData->buf,
1395 lpCalleeData->len);
1396 /* Wait for return */
1397 if (Status == STATUS_PENDING)
1398 {
1399 WaitForSingleObject(SockEvent, INFINITE);
1400 Status = IOSB.Status;
1401 }
1402 }
1403
1404 /* Re-enable Async Event */
1405 SockReenableAsyncSelectEvent(Socket, FD_WRITE);
1406
1407 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
1408 SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
1409
1410 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1411
1412 NtClose( SockEvent );
1413
1414 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1415 }
1416 int
1417 WSPAPI
1418 WSPShutdown(SOCKET Handle,
1419 int HowTo,
1420 LPINT lpErrno)
1421
1422 {
1423 IO_STATUS_BLOCK IOSB;
1424 AFD_DISCONNECT_INFO DisconnectInfo;
1425 PSOCKET_INFORMATION Socket = NULL;
1426 NTSTATUS Status;
1427 HANDLE SockEvent;
1428
1429 Status = NtCreateEvent(&SockEvent,
1430 GENERIC_READ | GENERIC_WRITE,
1431 NULL,
1432 1,
1433 FALSE);
1434
1435 if( !NT_SUCCESS(Status) )
1436 return -1;
1437
1438 AFD_DbgPrint(MID_TRACE,("Called\n"));
1439
1440 /* Get the Socket Structure associate to this Socket*/
1441 Socket = GetSocketStructure(Handle);
1442
1443 /* Set AFD Disconnect Type */
1444 switch (HowTo)
1445 {
1446 case SD_RECEIVE:
1447 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
1448 Socket->SharedData.ReceiveShutdown = TRUE;
1449 break;
1450 case SD_SEND:
1451 DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
1452 Socket->SharedData.SendShutdown = TRUE;
1453 break;
1454 case SD_BOTH:
1455 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
1456 Socket->SharedData.ReceiveShutdown = TRUE;
1457 Socket->SharedData.SendShutdown = TRUE;
1458 break;
1459 }
1460
1461 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
1462
1463 /* Send IOCTL */
1464 Status = NtDeviceIoControlFile((HANDLE)Handle,
1465 SockEvent,
1466 NULL,
1467 NULL,
1468 &IOSB,
1469 IOCTL_AFD_DISCONNECT,
1470 &DisconnectInfo,
1471 sizeof(DisconnectInfo),
1472 NULL,
1473 0);
1474
1475 /* Wait for return */
1476 if (Status == STATUS_PENDING)
1477 {
1478 WaitForSingleObject(SockEvent, INFINITE);
1479 Status = IOSB.Status;
1480 }
1481
1482 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1483
1484 NtClose( SockEvent );
1485
1486 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1487 }
1488
1489
1490 INT
1491 WSPAPI
1492 WSPGetSockName(IN SOCKET Handle,
1493 OUT LPSOCKADDR Name,
1494 IN OUT LPINT NameLength,
1495 OUT LPINT lpErrno)
1496 {
1497 IO_STATUS_BLOCK IOSB;
1498 ULONG TdiAddressSize;
1499 PTDI_ADDRESS_INFO TdiAddress;
1500 PTRANSPORT_ADDRESS SocketAddress;
1501 PSOCKET_INFORMATION Socket = NULL;
1502 NTSTATUS Status;
1503 HANDLE SockEvent;
1504
1505 Status = NtCreateEvent(&SockEvent,
1506 GENERIC_READ | GENERIC_WRITE,
1507 NULL,
1508 1,
1509 FALSE);
1510
1511 if( !NT_SUCCESS(Status) )
1512 return SOCKET_ERROR;
1513
1514 /* Get the Socket Structure associate to this Socket*/
1515 Socket = GetSocketStructure(Handle);
1516
1517 /* Allocate a buffer for the address */
1518 TdiAddressSize =
1519 sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
1520 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1521
1522 if ( TdiAddress == NULL )
1523 {
1524 NtClose( SockEvent );
1525 *lpErrno = WSAENOBUFS;
1526 return SOCKET_ERROR;
1527 }
1528
1529 SocketAddress = &TdiAddress->Address;
1530
1531 /* Send IOCTL */
1532 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1533 SockEvent,
1534 NULL,
1535 NULL,
1536 &IOSB,
1537 IOCTL_AFD_GET_SOCK_NAME,
1538 NULL,
1539 0,
1540 TdiAddress,
1541 TdiAddressSize);
1542
1543 /* Wait for return */
1544 if (Status == STATUS_PENDING)
1545 {
1546 WaitForSingleObject(SockEvent, INFINITE);
1547 Status = IOSB.Status;
1548 }
1549
1550 NtClose( SockEvent );
1551
1552 if (NT_SUCCESS(Status))
1553 {
1554 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1555 {
1556 Name->sa_family = SocketAddress->Address[0].AddressType;
1557 RtlCopyMemory (Name->sa_data,
1558 SocketAddress->Address[0].Address,
1559 SocketAddress->Address[0].AddressLength);
1560 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1561 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
1562 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1563 ((struct sockaddr_in *)Name)->sin_port));
1564 HeapFree(GlobalHeap, 0, TdiAddress);
1565 return 0;
1566 }
1567 else
1568 {
1569 HeapFree(GlobalHeap, 0, TdiAddress);
1570 *lpErrno = WSAEFAULT;
1571 return SOCKET_ERROR;
1572 }
1573 }
1574
1575 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1576 }
1577
1578
1579 INT
1580 WSPAPI
1581 WSPGetPeerName(IN SOCKET s,
1582 OUT LPSOCKADDR Name,
1583 IN OUT LPINT NameLength,
1584 OUT LPINT lpErrno)
1585 {
1586 IO_STATUS_BLOCK IOSB;
1587 ULONG TdiAddressSize;
1588 PTRANSPORT_ADDRESS SocketAddress;
1589 PSOCKET_INFORMATION Socket = NULL;
1590 NTSTATUS Status;
1591 HANDLE SockEvent;
1592
1593 Status = NtCreateEvent(&SockEvent,
1594 GENERIC_READ | GENERIC_WRITE,
1595 NULL,
1596 1,
1597 FALSE);
1598
1599 if( !NT_SUCCESS(Status) )
1600 return SOCKET_ERROR;
1601
1602 /* Get the Socket Structure associate to this Socket*/
1603 Socket = GetSocketStructure(s);
1604
1605 /* Allocate a buffer for the address */
1606 TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
1607 SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1608
1609 if ( SocketAddress == NULL )
1610 {
1611 NtClose( SockEvent );
1612 *lpErrno = WSAENOBUFS;
1613 return SOCKET_ERROR;
1614 }
1615
1616 /* Send IOCTL */
1617 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1618 SockEvent,
1619 NULL,
1620 NULL,
1621 &IOSB,
1622 IOCTL_AFD_GET_PEER_NAME,
1623 NULL,
1624 0,
1625 SocketAddress,
1626 TdiAddressSize);
1627
1628 /* Wait for return */
1629 if (Status == STATUS_PENDING)
1630 {
1631 WaitForSingleObject(SockEvent, INFINITE);
1632 Status = IOSB.Status;
1633 }
1634
1635 NtClose( SockEvent );
1636
1637 if (NT_SUCCESS(Status))
1638 {
1639 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1640 {
1641 Name->sa_family = SocketAddress->Address[0].AddressType;
1642 RtlCopyMemory (Name->sa_data,
1643 SocketAddress->Address[0].Address,
1644 SocketAddress->Address[0].AddressLength);
1645 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1646 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
1647 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1648 ((struct sockaddr_in *)Name)->sin_port));
1649 HeapFree(GlobalHeap, 0, SocketAddress);
1650 return 0;
1651 }
1652 else
1653 {
1654 HeapFree(GlobalHeap, 0, SocketAddress);
1655 *lpErrno = WSAEFAULT;
1656 return SOCKET_ERROR;
1657 }
1658 }
1659
1660 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1661 }
1662
1663 INT
1664 WSPAPI
1665 WSPIoctl(IN SOCKET Handle,
1666 IN DWORD dwIoControlCode,
1667 IN LPVOID lpvInBuffer,
1668 IN DWORD cbInBuffer,
1669 OUT LPVOID lpvOutBuffer,
1670 IN DWORD cbOutBuffer,
1671 OUT LPDWORD lpcbBytesReturned,
1672 IN LPWSAOVERLAPPED lpOverlapped,
1673 IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
1674 IN LPWSATHREADID lpThreadId,
1675 OUT LPINT lpErrno)
1676 {
1677 PSOCKET_INFORMATION Socket = NULL;
1678
1679 /* Get the Socket Structure associate to this Socket*/
1680 Socket = GetSocketStructure(Handle);
1681
1682 switch( dwIoControlCode )
1683 {
1684 case FIONBIO:
1685 if( cbInBuffer < sizeof(INT) )
1686 return SOCKET_ERROR;
1687 Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
1688 AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n", Handle, Socket->SharedData.NonBlocking));
1689 return 0;
1690 case FIONREAD:
1691 return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
1692 default:
1693 *lpErrno = WSAEINVAL;
1694 return SOCKET_ERROR;
1695 }
1696 }
1697
1698
1699 INT
1700 WSPAPI
1701 WSPGetSockOpt(IN SOCKET Handle,
1702 IN INT Level,
1703 IN INT OptionName,
1704 OUT CHAR FAR* OptionValue,
1705 IN OUT LPINT OptionLength,
1706 OUT LPINT lpErrno)
1707 {
1708 PSOCKET_INFORMATION Socket = NULL;
1709 PVOID Buffer;
1710 INT BufferSize;
1711 BOOLEAN BoolBuffer;
1712
1713 /* Get the Socket Structure associate to this Socket*/
1714 Socket = GetSocketStructure(Handle);
1715 if (Socket == NULL)
1716 {
1717 *lpErrno = WSAENOTSOCK;
1718 return SOCKET_ERROR;
1719 }
1720
1721 AFD_DbgPrint(MID_TRACE, ("Called\n"));
1722
1723 switch (Level)
1724 {
1725 case SOL_SOCKET:
1726 switch (OptionName)
1727 {
1728 case SO_TYPE:
1729 Buffer = &Socket->SharedData.SocketType;
1730 BufferSize = sizeof(INT);
1731 break;
1732
1733 case SO_RCVBUF:
1734 Buffer = &Socket->SharedData.SizeOfRecvBuffer;
1735 BufferSize = sizeof(INT);
1736 break;
1737
1738 case SO_SNDBUF:
1739 Buffer = &Socket->SharedData.SizeOfSendBuffer;
1740 BufferSize = sizeof(INT);
1741 break;
1742
1743 case SO_ACCEPTCONN:
1744 BoolBuffer = Socket->SharedData.Listening;
1745 Buffer = &BoolBuffer;
1746 BufferSize = sizeof(BOOLEAN);
1747 break;
1748
1749 case SO_BROADCAST:
1750 BoolBuffer = Socket->SharedData.Broadcast;
1751 Buffer = &BoolBuffer;
1752 BufferSize = sizeof(BOOLEAN);
1753 break;
1754
1755 case SO_DEBUG:
1756 BoolBuffer = Socket->SharedData.Debug;
1757 Buffer = &BoolBuffer;
1758 BufferSize = sizeof(BOOLEAN);
1759 break;
1760
1761 /* case SO_CONDITIONAL_ACCEPT: */
1762 case SO_DONTLINGER:
1763 case SO_DONTROUTE:
1764 case SO_ERROR:
1765 case SO_GROUP_ID:
1766 case SO_GROUP_PRIORITY:
1767 case SO_KEEPALIVE:
1768 case SO_LINGER:
1769 case SO_MAX_MSG_SIZE:
1770 case SO_OOBINLINE:
1771 case SO_PROTOCOL_INFO:
1772 case SO_REUSEADDR:
1773 AFD_DbgPrint(MID_TRACE, ("Unimplemented option (%x)\n",
1774 OptionName));
1775
1776 default:
1777 *lpErrno = WSAEINVAL;
1778 return SOCKET_ERROR;
1779 }
1780
1781 if (*OptionLength < BufferSize)
1782 {
1783 *lpErrno = WSAEFAULT;
1784 *OptionLength = BufferSize;
1785 return SOCKET_ERROR;
1786 }
1787 RtlCopyMemory(OptionValue, Buffer, BufferSize);
1788
1789 return 0;
1790
1791 case IPPROTO_TCP: /* FIXME */
1792 default:
1793 *lpErrno = WSAEINVAL;
1794 return SOCKET_ERROR;
1795 }
1796 }
1797
1798
1799 /*
1800 * FUNCTION: Initialize service provider for a client
1801 * ARGUMENTS:
1802 * wVersionRequested = Highest WinSock SPI version that the caller can use
1803 * lpWSPData = Address of WSPDATA structure to initialize
1804 * lpProtocolInfo = Pointer to structure that defines the desired protocol
1805 * UpcallTable = Pointer to upcall table of the WinSock DLL
1806 * lpProcTable = Address of procedure table to initialize
1807 * RETURNS:
1808 * Status of operation
1809 */
1810 INT
1811 WSPAPI
1812 WSPStartup(IN WORD wVersionRequested,
1813 OUT LPWSPDATA lpWSPData,
1814 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
1815 IN WSPUPCALLTABLE UpcallTable,
1816 OUT LPWSPPROC_TABLE lpProcTable)
1817
1818 {
1819 NTSTATUS Status;
1820
1821 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
1822 Status = NO_ERROR;
1823 Upcalls = UpcallTable;
1824
1825 if (Status == NO_ERROR)
1826 {
1827 lpProcTable->lpWSPAccept = WSPAccept;
1828 lpProcTable->lpWSPAddressToString = WSPAddressToString;
1829 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
1830 lpProcTable->lpWSPBind = WSPBind;
1831 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
1832 lpProcTable->lpWSPCleanup = WSPCleanup;
1833 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
1834 lpProcTable->lpWSPConnect = WSPConnect;
1835 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
1836 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
1837 lpProcTable->lpWSPEventSelect = WSPEventSelect;
1838 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
1839 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
1840 lpProcTable->lpWSPGetSockName = WSPGetSockName;
1841 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
1842 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
1843 lpProcTable->lpWSPIoctl = WSPIoctl;
1844 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
1845 lpProcTable->lpWSPListen = WSPListen;
1846 lpProcTable->lpWSPRecv = WSPRecv;
1847 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
1848 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
1849 lpProcTable->lpWSPSelect = WSPSelect;
1850 lpProcTable->lpWSPSend = WSPSend;
1851 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
1852 lpProcTable->lpWSPSendTo = WSPSendTo;
1853 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
1854 lpProcTable->lpWSPShutdown = WSPShutdown;
1855 lpProcTable->lpWSPSocket = WSPSocket;
1856 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
1857 lpWSPData->wVersion = MAKEWORD(2, 2);
1858 lpWSPData->wHighVersion = MAKEWORD(2, 2);
1859 }
1860
1861 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
1862
1863 return Status;
1864 }
1865
1866
1867 /*
1868 * FUNCTION: Cleans up service provider for a client
1869 * ARGUMENTS:
1870 * lpErrno = Address of buffer for error information
1871 * RETURNS:
1872 * 0 if successful, or SOCKET_ERROR if not
1873 */
1874 INT
1875 WSPAPI
1876 WSPCleanup(OUT LPINT lpErrno)
1877
1878 {
1879 AFD_DbgPrint(MAX_TRACE, ("\n"));
1880 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
1881 *lpErrno = NO_ERROR;
1882
1883 return 0;
1884 }
1885
1886
1887
1888 int
1889 GetSocketInformation(PSOCKET_INFORMATION Socket,
1890 ULONG AfdInformationClass,
1891 PULONG Ulong OPTIONAL,
1892 PLARGE_INTEGER LargeInteger OPTIONAL)
1893 {
1894 IO_STATUS_BLOCK IOSB;
1895 AFD_INFO InfoData;
1896 NTSTATUS Status;
1897 HANDLE SockEvent;
1898
1899 Status = NtCreateEvent(&SockEvent,
1900 GENERIC_READ | GENERIC_WRITE,
1901 NULL,
1902 1,
1903 FALSE);
1904
1905 if( !NT_SUCCESS(Status) )
1906 return -1;
1907
1908 /* Set Info Class */
1909 InfoData.InformationClass = AfdInformationClass;
1910
1911 /* Send IOCTL */
1912 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1913 SockEvent,
1914 NULL,
1915 NULL,
1916 &IOSB,
1917 IOCTL_AFD_GET_INFO,
1918 &InfoData,
1919 sizeof(InfoData),
1920 &InfoData,
1921 sizeof(InfoData));
1922
1923 /* Wait for return */
1924 if (Status == STATUS_PENDING)
1925 {
1926 WaitForSingleObject(SockEvent, INFINITE);
1927 }
1928
1929 /* Return Information */
1930 *Ulong = InfoData.Information.Ulong;
1931 if (LargeInteger != NULL)
1932 {
1933 *LargeInteger = InfoData.Information.LargeInteger;
1934 }
1935
1936 NtClose( SockEvent );
1937
1938 return 0;
1939
1940 }
1941
1942
1943 int
1944 SetSocketInformation(PSOCKET_INFORMATION Socket,
1945 ULONG AfdInformationClass,
1946 PULONG Ulong OPTIONAL,
1947 PLARGE_INTEGER LargeInteger OPTIONAL)
1948 {
1949 IO_STATUS_BLOCK IOSB;
1950 AFD_INFO InfoData;
1951 NTSTATUS Status;
1952 HANDLE SockEvent;
1953
1954 Status = NtCreateEvent(&SockEvent,
1955 GENERIC_READ | GENERIC_WRITE,
1956 NULL,
1957 1,
1958 FALSE);
1959
1960 if( !NT_SUCCESS(Status) )
1961 return -1;
1962
1963 /* Set Info Class */
1964 InfoData.InformationClass = AfdInformationClass;
1965
1966 /* Set Information */
1967 InfoData.Information.Ulong = *Ulong;
1968 if (LargeInteger != NULL)
1969 {
1970 InfoData.Information.LargeInteger = *LargeInteger;
1971 }
1972
1973 AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
1974 AfdInformationClass, *Ulong));
1975
1976 /* Send IOCTL */
1977 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1978 SockEvent,
1979 NULL,
1980 NULL,
1981 &IOSB,
1982 IOCTL_AFD_GET_INFO,
1983 &InfoData,
1984 sizeof(InfoData),
1985 NULL,
1986 0);
1987
1988 /* Wait for return */
1989 if (Status == STATUS_PENDING)
1990 {
1991 WaitForSingleObject(SockEvent, INFINITE);
1992 }
1993
1994 NtClose( SockEvent );
1995
1996 return 0;
1997
1998 }
1999
2000 PSOCKET_INFORMATION
2001 GetSocketStructure(SOCKET Handle)
2002 {
2003 ULONG i;
2004
2005 for (i=0; i<SocketCount; i++)
2006 {
2007 if (Sockets[i]->Handle == Handle)
2008 {
2009 return Sockets[i];
2010 }
2011 }
2012 return 0;
2013 }
2014
2015 int CreateContext(PSOCKET_INFORMATION Socket)
2016 {
2017 IO_STATUS_BLOCK IOSB;
2018 SOCKET_CONTEXT ContextData;
2019 NTSTATUS Status;
2020 HANDLE SockEvent;
2021
2022 Status = NtCreateEvent(&SockEvent,
2023 GENERIC_READ | GENERIC_WRITE,
2024 NULL,
2025 1,
2026 FALSE);
2027
2028 if( !NT_SUCCESS(Status) )
2029 return -1;
2030
2031 /* Create Context */
2032 ContextData.SharedData = Socket->SharedData;
2033 ContextData.SizeOfHelperData = 0;
2034 RtlCopyMemory (&ContextData.LocalAddress,
2035 Socket->LocalAddress,
2036 Socket->SharedData.SizeOfLocalAddress);
2037 RtlCopyMemory (&ContextData.RemoteAddress,
2038 Socket->RemoteAddress,
2039 Socket->SharedData.SizeOfRemoteAddress);
2040
2041 /* Send IOCTL */
2042 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
2043 SockEvent,
2044 NULL,
2045 NULL,
2046 &IOSB,
2047 IOCTL_AFD_SET_CONTEXT,
2048 &ContextData,
2049 sizeof(ContextData),
2050 NULL,
2051 0);
2052
2053 /* Wait for Completition */
2054 if (Status == STATUS_PENDING)
2055 {
2056 WaitForSingleObject(SockEvent, INFINITE);
2057 }
2058
2059 NtClose( SockEvent );
2060
2061 return 0;
2062 }
2063
2064 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
2065 {
2066 HANDLE hAsyncThread;
2067 DWORD AsyncThreadId;
2068 HANDLE AsyncEvent;
2069 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2070 NTSTATUS Status;
2071
2072 /* Check if the Thread Already Exists */
2073 if (SockAsyncThreadRefCount)
2074 {
2075 return TRUE;
2076 }
2077
2078 /* Create the Completion Port */
2079 if (!SockAsyncCompletionPort)
2080 {
2081 Status = NtCreateIoCompletion(&SockAsyncCompletionPort,
2082 IO_COMPLETION_ALL_ACCESS,
2083 NULL,
2084 2); // Allow 2 threads only
2085
2086 /* Protect Handle */
2087 HandleFlags.ProtectFromClose = TRUE;
2088 HandleFlags.Inherit = FALSE;
2089 Status = NtSetInformationObject(SockAsyncCompletionPort,
2090 ObjectHandleFlagInformation,
2091 &HandleFlags,
2092 sizeof(HandleFlags));
2093 }
2094
2095 /* Create the Async Event */
2096 Status = NtCreateEvent(&AsyncEvent,
2097 EVENT_ALL_ACCESS,
2098 NULL,
2099 NotificationEvent,
2100 FALSE);
2101
2102 /* Create the Async Thread */
2103 hAsyncThread = CreateThread(NULL,
2104 0,
2105 (LPTHREAD_START_ROUTINE)SockAsyncThread,
2106 NULL,
2107 0,
2108 &AsyncThreadId);
2109
2110 /* Close the Handle */
2111 NtClose(hAsyncThread);
2112
2113 /* Increase the Reference Count */
2114 SockAsyncThreadRefCount++;
2115 return TRUE;
2116 }
2117
2118 int SockAsyncThread(PVOID ThreadParam)
2119 {
2120 PVOID AsyncContext;
2121 PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
2122 IO_STATUS_BLOCK IOSB;
2123 NTSTATUS Status;
2124
2125 /* Make the Thread Higher Priority */
2126 SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
2127
2128 /* Do a KQUEUE/WorkItem Style Loop, thanks to IoCompletion Ports */
2129 do
2130 {
2131 Status = NtRemoveIoCompletion (SockAsyncCompletionPort,
2132 (PVOID*)&AsyncCompletionRoutine,
2133 &AsyncContext,
2134 &IOSB,
2135 NULL);
2136 /* Call the Async Function */
2137 if (NT_SUCCESS(Status))
2138 {
2139 (*AsyncCompletionRoutine)(AsyncContext, &IOSB);
2140 }
2141 else
2142 {
2143 /* It Failed, sleep for a second */
2144 Sleep(1000);
2145 }
2146 } while ((Status != STATUS_TIMEOUT));
2147
2148 /* The Thread has Ended */
2149 return 0;
2150 }
2151
2152 BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
2153 {
2154 UNICODE_STRING AfdHelper;
2155 OBJECT_ATTRIBUTES ObjectAttributes;
2156 IO_STATUS_BLOCK IoSb;
2157 NTSTATUS Status;
2158 FILE_COMPLETION_INFORMATION CompletionInfo;
2159 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2160
2161 /* First, make sure we're not already intialized */
2162 if (SockAsyncHelperAfdHandle)
2163 {
2164 return TRUE;
2165 }
2166
2167 /* Set up Handle Name and Object */
2168 RtlInitUnicodeString(&AfdHelper, L"\\Device\\Afd\\AsyncSelectHlp" );
2169 InitializeObjectAttributes(&ObjectAttributes,
2170 &AfdHelper,
2171 OBJ_INHERIT | OBJ_CASE_INSENSITIVE,
2172 NULL,
2173 NULL);
2174
2175 /* Open the Handle to AFD */
2176 Status = NtCreateFile(&SockAsyncHelperAfdHandle,
2177 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
2178 &ObjectAttributes,
2179 &IoSb,
2180 NULL,
2181 0,
2182 FILE_SHARE_READ | FILE_SHARE_WRITE,
2183 FILE_OPEN_IF,
2184 0,
2185 NULL,
2186 0);
2187
2188 /*
2189 * Now Set up the Completion Port Information
2190 * This means that whenever a Poll is finished, the routine will be executed
2191 */
2192 CompletionInfo.Port = SockAsyncCompletionPort;
2193 CompletionInfo.Key = SockAsyncSelectCompletionRoutine;
2194 Status = NtSetInformationFile(SockAsyncHelperAfdHandle,
2195 &IoSb,
2196 &CompletionInfo,
2197 sizeof(CompletionInfo),
2198 FileCompletionInformation);
2199
2200
2201 /* Protect the Handle */
2202 HandleFlags.ProtectFromClose = TRUE;
2203 HandleFlags.Inherit = FALSE;
2204 Status = NtSetInformationObject(SockAsyncCompletionPort,
2205 ObjectHandleFlagInformation,
2206 &HandleFlags,
2207 sizeof(HandleFlags));
2208
2209
2210 /* Set this variable to true so that Send/Recv/Accept will know wether to renable disabled events */
2211 SockAsyncSelectCalled = TRUE;
2212 return TRUE;
2213 }
2214
2215 VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2216 {
2217
2218 PASYNC_DATA AsyncData = Context;
2219 PSOCKET_INFORMATION Socket;
2220 ULONG x;
2221
2222 /* Get the Socket */
2223 Socket = AsyncData->ParentSocket;
2224
2225 /* Check if the Sequence Number Changed behind our back */
2226 if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber )
2227 {
2228 return;
2229 }
2230
2231 /* Check we were manually called b/c of a failure */
2232 if (!NT_SUCCESS(IoStatusBlock->Status))
2233 {
2234 /* FIXME: Perform Upcall */
2235 return;
2236 }
2237
2238 for (x = 1; x; x<<=1)
2239 {
2240 switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x)
2241 {
2242 case AFD_EVENT_RECEIVE:
2243 if (0 != (Socket->SharedData.AsyncEvents & FD_READ) &&
2244 0 == (Socket->SharedData.AsyncDisabledEvents & FD_READ))
2245 {
2246 /* Make the Notifcation */
2247 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2248 Socket->SharedData.wMsg,
2249 Socket->Handle,
2250 WSAMAKESELECTREPLY(FD_READ, 0));
2251 /* Disable this event until the next read(); */
2252 Socket->SharedData.AsyncDisabledEvents |= FD_READ;
2253 }
2254 break;
2255
2256 case AFD_EVENT_OOB_RECEIVE:
2257 if (0 != (Socket->SharedData.AsyncEvents & FD_OOB) &&
2258 0 == (Socket->SharedData.AsyncDisabledEvents & FD_OOB))
2259 {
2260 /* Make the Notifcation */
2261 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2262 Socket->SharedData.wMsg,
2263 Socket->Handle,
2264 WSAMAKESELECTREPLY(FD_OOB, 0));
2265 /* Disable this event until the next read(); */
2266 Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
2267 }
2268 break;
2269
2270 case AFD_EVENT_SEND:
2271 if (0 != (Socket->SharedData.AsyncEvents & FD_WRITE) &&
2272 0 == (Socket->SharedData.AsyncDisabledEvents & FD_WRITE))
2273 {
2274 /* Make the Notifcation */
2275 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2276 Socket->SharedData.wMsg,
2277 Socket->Handle,
2278 WSAMAKESELECTREPLY(FD_WRITE, 0));
2279 /* Disable this event until the next write(); */
2280 Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
2281 }
2282 break;
2283
2284 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2285 case AFD_EVENT_CONNECT:
2286 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
2287 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
2288 {
2289 /* Make the Notifcation */
2290 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2291 Socket->SharedData.wMsg,
2292 Socket->Handle,
2293 WSAMAKESELECTREPLY(FD_CONNECT, 0));
2294 /* Disable this event forever; */
2295 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT;
2296 }
2297 break;
2298
2299 case AFD_EVENT_ACCEPT:
2300 if (0 != (Socket->SharedData.AsyncEvents & FD_ACCEPT) &&
2301 0 == (Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT))
2302 {
2303 /* Make the Notifcation */
2304 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2305 Socket->SharedData.wMsg,
2306 Socket->Handle,
2307 WSAMAKESELECTREPLY(FD_ACCEPT, 0));
2308 /* Disable this event until the next accept(); */
2309 Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
2310 }
2311 break;
2312
2313 case AFD_EVENT_DISCONNECT:
2314 case AFD_EVENT_ABORT:
2315 case AFD_EVENT_CLOSE:
2316 if (0 != (Socket->SharedData.AsyncEvents & FD_CLOSE) &&
2317 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CLOSE))
2318 {
2319 /* Make the Notifcation */
2320 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2321 Socket->SharedData.wMsg,
2322 Socket->Handle,
2323 WSAMAKESELECTREPLY(FD_CLOSE, 0));
2324 /* Disable this event forever; */
2325 Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
2326 }
2327 break;
2328 /* FIXME: Support QOS */
2329 }
2330 }
2331
2332 /* Check if there are any events left for us to check */
2333 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2334 {
2335 return;
2336 }
2337
2338 /* Keep Polling */
2339 SockProcessAsyncSelect(Socket, AsyncData);
2340 return;
2341 }
2342
2343 VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
2344 {
2345
2346 ULONG lNetworkEvents;
2347 NTSTATUS Status;
2348
2349 /* Set up the Async Data Event Info */
2350 AsyncData->AsyncSelectInfo.Timeout.HighPart = 0x7FFFFFFF;
2351 AsyncData->AsyncSelectInfo.Timeout.LowPart = 0xFFFFFFFF;
2352 AsyncData->AsyncSelectInfo.HandleCount = 1;
2353 AsyncData->AsyncSelectInfo.Exclusive = TRUE;
2354 AsyncData->AsyncSelectInfo.Handles[0].Handle = Socket->Handle;
2355 AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
2356
2357 /* Remove unwanted events */
2358 lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
2359
2360 /* Set Events to wait for */
2361 if (lNetworkEvents & FD_READ)
2362 {
2363 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_RECEIVE;
2364 }
2365
2366 if (lNetworkEvents & FD_WRITE)
2367 {
2368 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_SEND;
2369 }
2370
2371 if (lNetworkEvents & FD_OOB)
2372 {
2373 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_OOB_RECEIVE;
2374 }
2375
2376 if (lNetworkEvents & FD_ACCEPT)
2377 {
2378 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_ACCEPT;
2379 }
2380
2381 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2382 if (lNetworkEvents & FD_CONNECT)
2383 {
2384 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_CONNECT | AFD_EVENT_CONNECT_FAIL;
2385 }
2386
2387 if (lNetworkEvents & FD_CLOSE)
2388 {
2389 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
2390 }
2391
2392 if (lNetworkEvents & FD_QOS)
2393 {
2394 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_QOS;
2395 }
2396
2397 if (lNetworkEvents & FD_GROUP_QOS)
2398 {
2399 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_GROUP_QOS;
2400 }
2401
2402 /* Send IOCTL */
2403 Status = NtDeviceIoControlFile (SockAsyncHelperAfdHandle,
2404 NULL,
2405 NULL,
2406 AsyncData,
2407 &AsyncData->IoStatusBlock,
2408 IOCTL_AFD_SELECT,
2409 &AsyncData->AsyncSelectInfo,
2410 sizeof(AsyncData->AsyncSelectInfo),
2411 &AsyncData->AsyncSelectInfo,
2412 sizeof(AsyncData->AsyncSelectInfo));
2413
2414 /* I/O Manager Won't call the completion routine, let's do it manually */
2415 if (NT_SUCCESS(Status))
2416 {
2417 return;
2418 }
2419 else
2420 {
2421 AsyncData->IoStatusBlock.Status = Status;
2422 SockAsyncSelectCompletionRoutine(AsyncData, &AsyncData->IoStatusBlock);
2423 }
2424 }
2425
2426 VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2427 {
2428 PASYNC_DATA AsyncData = Context;
2429 BOOL FreeContext = TRUE;
2430 PSOCKET_INFORMATION Socket;
2431
2432 /* Get the Socket */
2433 Socket = AsyncData->ParentSocket;
2434
2435 /* If someone closed it, stop the function */
2436 if (Socket->SharedData.State != SocketClosed)
2437 {
2438 /* Check if the Sequence Number changed by now, in which case quit */
2439 if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber)
2440 {
2441 /* Do the actuall select, if needed */
2442 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)))
2443 {
2444 SockProcessAsyncSelect(Socket, AsyncData);
2445 FreeContext = FALSE;
2446 }
2447 }
2448 }
2449
2450 /* Free the Context */
2451 if (FreeContext)
2452 {
2453 HeapFree(GetProcessHeap(), 0, AsyncData);
2454 }
2455
2456 return;
2457 }
2458
2459 VOID
2460 SockReenableAsyncSelectEvent (IN PSOCKET_INFORMATION Socket,
2461 IN ULONG Event)
2462 {
2463 PASYNC_DATA AsyncData;
2464
2465 /* Make sure the event is actually disabled */
2466 if (!(Socket->SharedData.AsyncDisabledEvents & Event))
2467 {
2468 return;
2469 }
2470
2471 /* Re-enable it */
2472 Socket->SharedData.AsyncDisabledEvents &= ~Event;
2473
2474 /* Return if no more events are being polled */
2475 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2476 {
2477 return;
2478 }
2479
2480 /* Wait on new events */
2481 AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
2482
2483 /* Create the Asynch Thread if Needed */
2484 SockCreateOrReferenceAsyncThread();
2485
2486 /* Increase the sequence number to stop anything else */
2487 Socket->SharedData.SequenceNumber++;
2488
2489 /* Set up the Async Data */
2490 AsyncData->ParentSocket = Socket;
2491 AsyncData->SequenceNumber = Socket->SharedData.SequenceNumber;
2492
2493 /* Begin Async Select by using I/O Completion */
2494 NtSetIoCompletion(SockAsyncCompletionPort,
2495 (PVOID)&SockProcessQueuedAsyncSelect,
2496 AsyncData,
2497 0,
2498 0);
2499
2500 /* All done */
2501 return;
2502 }
2503
2504 BOOL
2505 WINAPI
2506 DllMain(HANDLE hInstDll,
2507 ULONG dwReason,
2508 PVOID Reserved)
2509 {
2510
2511 switch (dwReason)
2512 {
2513 case DLL_PROCESS_ATTACH:
2514
2515 AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
2516
2517 /* Don't need thread attach notifications
2518 so disable them to improve performance */
2519 DisableThreadLibraryCalls(hInstDll);
2520
2521 /* List of DLL Helpers */
2522 InitializeListHead(&SockHelpersListHead);
2523
2524 /* Heap to use when allocating */
2525 GlobalHeap = GetProcessHeap();
2526
2527 /* Allocate Heap for 1024 Sockets, can be expanded later */
2528 Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
2529
2530 AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
2531
2532 break;
2533
2534 case DLL_THREAD_ATTACH:
2535 break;
2536
2537 case DLL_THREAD_DETACH:
2538 break;
2539
2540 case DLL_PROCESS_DETACH:
2541 break;
2542 }
2543
2544 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
2545
2546 return TRUE;
2547 }
2548
2549 /* EOF */
2550
2551