- Fix a bad typo
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
4 * FILE: misc/dllmain.c
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7 * Alex Ionescu (alex@relsoft.net)
8 * REVISIONS:
9 * CSH 01/09-2000 Created
10 * Alex 16/07/2004 - Complete Rewrite
11 */
12
13 #include <msafd.h>
14
15 #include <debug.h>
16
17 #ifdef DBG
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
19 DWORD DebugTraceLevel = 0;
20 #endif /* DBG */
21
22 HANDLE GlobalHeap;
23 WSPUPCALLTABLE Upcalls;
24 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
25 ULONG SocketCount = 0;
26 PSOCKET_INFORMATION *Sockets = NULL;
27 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
28 ULONG SockAsyncThreadRefCount;
29 HANDLE SockAsyncHelperAfdHandle;
30 HANDLE SockAsyncCompletionPort;
31 BOOLEAN SockAsyncSelectCalled;
32
33
34
35 /*
36 * FUNCTION: Creates a new socket
37 * ARGUMENTS:
38 * af = Address family
39 * type = Socket type
40 * protocol = Protocol type
41 * lpProtocolInfo = Pointer to protocol information
42 * g = Reserved
43 * dwFlags = Socket flags
44 * lpErrno = Address of buffer for error information
45 * RETURNS:
46 * Created socket, or INVALID_SOCKET if it could not be created
47 */
48 SOCKET
49 WSPAPI
50 WSPSocket(int AddressFamily,
51 int SocketType,
52 int Protocol,
53 LPWSAPROTOCOL_INFOW lpProtocolInfo,
54 GROUP g,
55 DWORD dwFlags,
56 LPINT lpErrno)
57 {
58 OBJECT_ATTRIBUTES Object;
59 IO_STATUS_BLOCK IOSB;
60 USHORT SizeOfPacket;
61 ULONG SizeOfEA;
62 PAFD_CREATE_PACKET AfdPacket;
63 HANDLE Sock;
64 PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
65 PFILE_FULL_EA_INFORMATION EABuffer = NULL;
66 PHELPER_DATA HelperData;
67 PVOID HelperDLLContext;
68 DWORD HelperEvents;
69 UNICODE_STRING TransportName;
70 UNICODE_STRING DevName;
71 LARGE_INTEGER GroupData;
72 INT Status;
73
74 AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
75 AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
76 AddressFamily, SocketType, Protocol));
77
78 /* Get Helper Data and Transport */
79 Status = SockGetTdiName (&AddressFamily,
80 &SocketType,
81 &Protocol,
82 g,
83 dwFlags,
84 &TransportName,
85 &HelperDLLContext,
86 &HelperData,
87 &HelperEvents);
88
89 /* Check for error */
90 if (Status != NO_ERROR)
91 {
92 AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
93 goto error;
94 }
95
96 /* AFD Device Name */
97 RtlInitUnicodeString(&DevName, L"\\Device\\Afd\\Endpoint");
98
99 /* Set Socket Data */
100 Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
101 RtlZeroMemory(Socket, sizeof(*Socket));
102 Socket->RefCount = 2;
103 Socket->Handle = -1;
104 Socket->SharedData.Listening = FALSE;
105 Socket->SharedData.State = SocketOpen;
106 Socket->SharedData.AddressFamily = AddressFamily;
107 Socket->SharedData.SocketType = SocketType;
108 Socket->SharedData.Protocol = Protocol;
109 Socket->HelperContext = HelperDLLContext;
110 Socket->HelperData = HelperData;
111 Socket->HelperEvents = HelperEvents;
112 Socket->LocalAddress = &Socket->WSLocalAddress;
113 Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
114 Socket->RemoteAddress = &Socket->WSRemoteAddress;
115 Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
116 Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
117 Socket->SharedData.CreateFlags = dwFlags;
118 Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
119 Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
120 Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
121 Socket->SharedData.GroupID = g;
122 Socket->SharedData.GroupType = 0;
123 Socket->SharedData.UseSAN = FALSE;
124 Socket->SharedData.NonBlocking = FALSE; /* Sockets start blocking */
125 Socket->SanData = NULL;
126
127 /* Ask alex about this */
128 if( Socket->SharedData.SocketType == SOCK_DGRAM ||
129 Socket->SharedData.SocketType == SOCK_RAW )
130 {
131 AFD_DbgPrint(MID_TRACE,("Connectionless socket\n"));
132 Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
133 }
134
135 /* Packet Size */
136 SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
137
138 /* EA Size */
139 SizeOfEA = SizeOfPacket + sizeof(FILE_FULL_EA_INFORMATION) + AFD_PACKET_COMMAND_LENGTH;
140
141 /* Set up EA Buffer */
142 EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
143 RtlZeroMemory(EABuffer, SizeOfEA);
144 EABuffer->NextEntryOffset = 0;
145 EABuffer->Flags = 0;
146 EABuffer->EaNameLength = AFD_PACKET_COMMAND_LENGTH;
147 RtlCopyMemory (EABuffer->EaName,
148 AfdCommand,
149 AFD_PACKET_COMMAND_LENGTH + 1);
150 EABuffer->EaValueLength = SizeOfPacket;
151
152 /* Set up AFD Packet */
153 AfdPacket = (PAFD_CREATE_PACKET)(EABuffer->EaName + EABuffer->EaNameLength + 1);
154 AfdPacket->SizeOfTransportName = TransportName.Length;
155 RtlCopyMemory (AfdPacket->TransportName,
156 TransportName.Buffer,
157 TransportName.Length + sizeof(WCHAR));
158 AfdPacket->GroupID = g;
159
160 /* Set up Endpoint Flags */
161 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0)
162 {
163 if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW))
164 {
165 /* Only RAW or UDP can be Connectionless */
166 goto error;
167 }
168 AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
169 }
170
171 if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0)
172 {
173 if (SocketType == SOCK_STREAM)
174 {
175 if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
176 {
177 /* The Provider doesn't actually support Message Oriented Streams */
178 goto error;
179 }
180 }
181 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MESSAGE_ORIENTED;
182 }
183
184 if (SocketType == SOCK_RAW) AfdPacket->EndpointFlags |= AFD_ENDPOINT_RAW;
185
186 if (dwFlags & (WSA_FLAG_MULTIPOINT_C_ROOT |
187 WSA_FLAG_MULTIPOINT_C_LEAF |
188 WSA_FLAG_MULTIPOINT_D_ROOT |
189 WSA_FLAG_MULTIPOINT_D_LEAF))
190 {
191 if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
192 {
193 /* The Provider doesn't actually support Multipoint */
194 goto error;
195 }
196 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
197
198 if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT)
199 {
200 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
201 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0))
202 {
203 /* The Provider doesn't support Control Planes, or you already gave a leaf */
204 goto error;
205 }
206 AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
207 }
208
209 if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT)
210 {
211 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
212 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0))
213 {
214 /* The Provider doesn't support Data Planes, or you already gave a leaf */
215 goto error;
216 }
217 AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
218 }
219 }
220
221 /* Set up Object Attributes */
222 InitializeObjectAttributes (&Object,
223 &DevName,
224 OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
225 0,
226 0);
227
228 /* Create the Socket as asynchronous. That means we have to block
229 ourselves after every call to NtDeviceIoControlFile. This is
230 because the kernel doesn't support overlapping synchronous I/O
231 requests (made from multiple threads) at this time (Sep 2005) */
232 ZwCreateFile(&Sock,
233 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
234 &Object,
235 &IOSB,
236 NULL,
237 0,
238 FILE_SHARE_READ | FILE_SHARE_WRITE, FILE_OPEN_IF,
239 0,
240 EABuffer,
241 SizeOfEA);
242
243 /* Save Handle */
244 Socket->Handle = (SOCKET)Sock;
245
246 /* XXX See if there's a structure we can reuse -- We need to do this
247 * more properly. */
248 PrevSocket = GetSocketStructure( (SOCKET)Sock );
249
250 if( PrevSocket )
251 {
252 RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
253 RtlFreeHeap( GlobalHeap, 0, Socket );
254 Socket = PrevSocket;
255 }
256
257 /* Save Group Info */
258 if (g != 0)
259 {
260 GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
261 Socket->SharedData.GroupID = GroupData.u.LowPart;
262 Socket->SharedData.GroupType = GroupData.u.HighPart;
263 }
264
265 /* Get Window Sizes and Save them */
266 GetSocketInformation (Socket,
267 AFD_INFO_SEND_WINDOW_SIZE,
268 &Socket->SharedData.SizeOfSendBuffer,
269 NULL);
270
271 GetSocketInformation (Socket,
272 AFD_INFO_RECEIVE_WINDOW_SIZE,
273 &Socket->SharedData.SizeOfRecvBuffer,
274 NULL);
275
276 /* Save in Process Sockets List */
277 Sockets[SocketCount] = Socket;
278 SocketCount ++;
279
280 /* Create the Socket Context */
281 CreateContext(Socket);
282
283 /* Notify Winsock */
284 Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
285
286 /* Return Socket Handle */
287 AFD_DbgPrint(MID_TRACE,("Success %x\n", Sock));
288
289 return (SOCKET)Sock;
290
291 error:
292 AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
293
294 if( lpErrno )
295 *lpErrno = Status;
296
297 return INVALID_SOCKET;
298 }
299
300
301 DWORD MsafdReturnWithErrno(NTSTATUS Status,
302 LPINT Errno,
303 DWORD Received,
304 LPDWORD ReturnedBytes)
305 {
306 if( ReturnedBytes )
307 *ReturnedBytes = 0;
308 if( Errno )
309 {
310 switch (Status)
311 {
312 case STATUS_CANT_WAIT:
313 *Errno = WSAEWOULDBLOCK;
314 break;
315 case STATUS_TIMEOUT:
316 *Errno = WSAETIMEDOUT;
317 break;
318 case STATUS_SUCCESS:
319 /* Return Number of bytes Read */
320 if( ReturnedBytes )
321 *ReturnedBytes = Received;
322 break;
323 case STATUS_FILE_CLOSED:
324 case STATUS_END_OF_FILE:
325 *Errno = WSAESHUTDOWN;
326 break;
327 case STATUS_PENDING:
328 *Errno = WSA_IO_PENDING;
329 break;
330 case STATUS_BUFFER_TOO_SMALL:
331 case STATUS_BUFFER_OVERFLOW:
332 DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
333 *Errno = WSAEMSGSIZE;
334 break;
335 case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
336 case STATUS_INSUFFICIENT_RESOURCES:
337 DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
338 *Errno = WSA_NOT_ENOUGH_MEMORY;
339 break;
340 case STATUS_INVALID_CONNECTION:
341 DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
342 *Errno = WSAEAFNOSUPPORT;
343 break;
344 case STATUS_REMOTE_NOT_LISTENING:
345 DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
346 *Errno = WSAECONNREFUSED;
347 break;
348 case STATUS_NETWORK_UNREACHABLE:
349 DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
350 *Errno = WSAENETUNREACH;
351 break;
352 case STATUS_INVALID_PARAMETER:
353 DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
354 *Errno = WSAEINVAL;
355 break;
356 case STATUS_CANCELLED:
357 DbgPrint("MSAFD: STATUS_CANCELLED\n");
358 *Errno = WSA_OPERATION_ABORTED;
359 break;
360 default:
361 DbgPrint("MSAFD: Error %x is unknown\n", Status);
362 *Errno = WSAEINVAL;
363 break;
364 }
365 }
366
367 /* Success */
368 return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
369 }
370
371 /*
372 * FUNCTION: Closes an open socket
373 * ARGUMENTS:
374 * s = Socket descriptor
375 * lpErrno = Address of buffer for error information
376 * RETURNS:
377 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
378 */
379 INT
380 WSPAPI
381 WSPCloseSocket(IN SOCKET Handle,
382 OUT LPINT lpErrno)
383 {
384 IO_STATUS_BLOCK IoStatusBlock;
385 PSOCKET_INFORMATION Socket = NULL;
386 NTSTATUS Status;
387 HANDLE SockEvent;
388 AFD_DISCONNECT_INFO DisconnectInfo;
389 SOCKET_STATE OldState;
390
391 /* Create the Wait Event */
392 Status = NtCreateEvent(&SockEvent,
393 GENERIC_READ | GENERIC_WRITE,
394 NULL,
395 1,
396 FALSE);
397
398 if(!NT_SUCCESS(Status))
399 return SOCKET_ERROR;
400
401 /* Get the Socket Structure associate to this Socket*/
402 Socket = GetSocketStructure(Handle);
403
404 /* If a Close is already in Process, give up */
405 if (Socket->SharedData.State == SocketClosed)
406 {
407 NtClose(SockEvent);
408 *lpErrno = WSAENOTSOCK;
409 return SOCKET_ERROR;
410 }
411
412 /* Set the state to close */
413 OldState = Socket->SharedData.State;
414 Socket->SharedData.State = SocketClosed;
415
416 /* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
417 /* FIXME: Should we do this on Datagram Sockets too? */
418 if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
419 {
420 ULONG LingerWait;
421 ULONG SendsInProgress;
422 ULONG SleepWait;
423
424 /* We need to respect the timeout */
425 SleepWait = 100;
426 LingerWait = Socket->SharedData.LingerData.l_linger * 1000;
427
428 /* Loop until no more sends are pending, within the timeout */
429 while (LingerWait)
430 {
431 /* Find out how many Sends are in Progress */
432 if (GetSocketInformation(Socket,
433 AFD_INFO_SENDS_IN_PROGRESS,
434 &SendsInProgress,
435 NULL))
436 {
437 /* Bail out if anything but NO_ERROR */
438 LingerWait = 0;
439 break;
440 }
441
442 /* Bail out if no more sends are pending */
443 if (!SendsInProgress)
444 break;
445 /*
446 * We have to execute a sleep, so it's kind of like
447 * a block. If the socket is Nonblock, we cannot
448 * go on since asyncronous operation is expected
449 * and we cannot offer it
450 */
451 if (Socket->SharedData.NonBlocking)
452 {
453 NtClose(SockEvent);
454 Socket->SharedData.State = OldState;
455 *lpErrno = WSAEWOULDBLOCK;
456 return SOCKET_ERROR;
457 }
458
459 /* Now we can sleep, and decrement the linger wait */
460 /*
461 * FIXME: It seems Windows does some funky acceleration
462 * since the waiting seems to be longer and longer. I
463 * don't think this improves performance so much, so we
464 * wait a fixed time instead.
465 */
466 Sleep(SleepWait);
467 LingerWait -= SleepWait;
468 }
469
470 /*
471 * We have reached the timeout or sends are over.
472 * Disconnect if the timeout has been reached.
473 */
474 if (LingerWait <= 0)
475 {
476 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
477 DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
478
479 /* Send IOCTL */
480 Status = NtDeviceIoControlFile((HANDLE)Handle,
481 SockEvent,
482 NULL,
483 NULL,
484 &IoStatusBlock,
485 IOCTL_AFD_DISCONNECT,
486 &DisconnectInfo,
487 sizeof(DisconnectInfo),
488 NULL,
489 0);
490
491 /* Wait for return */
492 if (Status == STATUS_PENDING)
493 {
494 WaitForSingleObject(SockEvent, INFINITE);
495 }
496 }
497 }
498
499 /* FIXME: We should notify the Helper DLL of WSH_NOTIFY_CLOSE */
500
501 /* Cleanup Time! */
502 Socket->HelperContext = NULL;
503 Socket->SharedData.AsyncDisabledEvents = -1;
504 NtClose(Socket->TdiAddressHandle);
505 Socket->TdiAddressHandle = NULL;
506 NtClose(Socket->TdiConnectionHandle);
507 Socket->TdiConnectionHandle = NULL;
508
509 /* Close the handle */
510 NtClose((HANDLE)Handle);
511 NtClose(SockEvent);
512
513 return NO_ERROR;
514 }
515
516
517 /*
518 * FUNCTION: Associates a local address with a socket
519 * ARGUMENTS:
520 * s = Socket descriptor
521 * name = Pointer to local address
522 * namelen = Length of name
523 * lpErrno = Address of buffer for error information
524 * RETURNS:
525 * 0, or SOCKET_ERROR if the socket could not be bound
526 */
527 INT
528 WSPAPI
529 WSPBind(SOCKET Handle,
530 const struct sockaddr *SocketAddress,
531 int SocketAddressLength,
532 LPINT lpErrno)
533 {
534 IO_STATUS_BLOCK IOSB;
535 PAFD_BIND_DATA BindData;
536 PSOCKET_INFORMATION Socket = NULL;
537 NTSTATUS Status;
538 UCHAR BindBuffer[0x1A];
539 SOCKADDR_INFO SocketInfo;
540 HANDLE SockEvent;
541
542 Status = NtCreateEvent(&SockEvent,
543 GENERIC_READ | GENERIC_WRITE,
544 NULL,
545 1,
546 FALSE);
547
548 if( !NT_SUCCESS(Status) )
549 return -1;
550
551 /* Get the Socket Structure associate to this Socket*/
552 Socket = GetSocketStructure(Handle);
553
554 /* Dynamic Structure...ugh */
555 BindData = (PAFD_BIND_DATA)BindBuffer;
556
557 /* Set up Address in TDI Format */
558 BindData->Address.TAAddressCount = 1;
559 BindData->Address.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
560 BindData->Address.Address[0].AddressType = SocketAddress->sa_family;
561 RtlCopyMemory (BindData->Address.Address[0].Address,
562 SocketAddress->sa_data,
563 SocketAddressLength - sizeof(SocketAddress->sa_family));
564
565 /* Get Address Information */
566 Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
567 SocketAddressLength,
568 &SocketInfo);
569
570 /* Set the Share Type */
571 if (Socket->SharedData.ExclusiveAddressUse)
572 {
573 BindData->ShareType = AFD_SHARE_EXCLUSIVE;
574 }
575 else if (SocketInfo.EndpointInfo == SockaddrEndpointInfoWildcard)
576 {
577 BindData->ShareType = AFD_SHARE_WILDCARD;
578 }
579 else if (Socket->SharedData.ReuseAddresses)
580 {
581 BindData->ShareType = AFD_SHARE_REUSE;
582 }
583 else
584 {
585 BindData->ShareType = AFD_SHARE_UNIQUE;
586 }
587
588 /* Send IOCTL */
589 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
590 SockEvent,
591 NULL,
592 NULL,
593 &IOSB,
594 IOCTL_AFD_BIND,
595 BindData,
596 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
597 BindData,
598 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
599
600 /* Wait for return */
601 if (Status == STATUS_PENDING)
602 {
603 WaitForSingleObject(SockEvent, INFINITE);
604 Status = IOSB.Status;
605 }
606
607 /* Set up Socket Data */
608 Socket->SharedData.State = SocketBound;
609 Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
610
611 NtClose( SockEvent );
612
613 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
614 }
615
616 int
617 WSPAPI
618 WSPListen(SOCKET Handle,
619 int Backlog,
620 LPINT lpErrno)
621 {
622 IO_STATUS_BLOCK IOSB;
623 AFD_LISTEN_DATA ListenData;
624 PSOCKET_INFORMATION Socket = NULL;
625 HANDLE SockEvent;
626 NTSTATUS Status;
627
628 /* Get the Socket Structure associate to this Socket*/
629 Socket = GetSocketStructure(Handle);
630
631 if (Socket->SharedData.Listening)
632 return 0;
633
634 Status = NtCreateEvent(&SockEvent,
635 GENERIC_READ | GENERIC_WRITE,
636 NULL,
637 1,
638 FALSE);
639
640 if( !NT_SUCCESS(Status) )
641 return -1;
642
643 /* Set Up Listen Structure */
644 ListenData.UseSAN = FALSE;
645 ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
646 ListenData.Backlog = Backlog;
647
648 /* Send IOCTL */
649 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
650 SockEvent,
651 NULL,
652 NULL,
653 &IOSB,
654 IOCTL_AFD_START_LISTEN,
655 &ListenData,
656 sizeof(ListenData),
657 NULL,
658 0);
659
660 /* Wait for return */
661 if (Status == STATUS_PENDING)
662 {
663 WaitForSingleObject(SockEvent, INFINITE);
664 Status = IOSB.Status;
665 }
666
667 /* Set to Listening */
668 Socket->SharedData.Listening = TRUE;
669
670 NtClose( SockEvent );
671
672 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
673 }
674
675
676 int
677 WSPAPI
678 WSPSelect(int nfds,
679 fd_set *readfds,
680 fd_set *writefds,
681 fd_set *exceptfds,
682 struct timeval *timeout,
683 LPINT lpErrno)
684 {
685 IO_STATUS_BLOCK IOSB;
686 PAFD_POLL_INFO PollInfo;
687 NTSTATUS Status;
688 LONG HandleCount, OutCount = 0;
689 ULONG PollBufferSize;
690 PVOID PollBuffer;
691 ULONG i, j = 0, x;
692 HANDLE SockEvent;
693 BOOL HandleCounted;
694 LARGE_INTEGER Timeout;
695
696 /* Find out how many sockets we have, and how large the buffer needs
697 * to be */
698
699 HandleCount = ( readfds ? readfds->fd_count : 0 ) +
700 ( writefds ? writefds->fd_count : 0 ) +
701 ( exceptfds ? exceptfds->fd_count : 0 );
702
703 if( HandleCount < 0 || nfds != 0 )
704 HandleCount = nfds * 3;
705
706 PollBufferSize = sizeof(*PollInfo) + (HandleCount * sizeof(AFD_HANDLE));
707
708 AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n",
709 HandleCount, PollBufferSize));
710
711 /* Convert Timeout to NT Format */
712 if (timeout == NULL)
713 {
714 Timeout.u.LowPart = -1;
715 Timeout.u.HighPart = 0x7FFFFFFF;
716 AFD_DbgPrint(MAX_TRACE,("Infinite timeout\n"));
717 }
718 else
719 {
720 Timeout = RtlEnlargedIntegerMultiply
721 ((timeout->tv_sec * 1000) + (timeout->tv_usec / 1000), -10000);
722 /* Negative timeouts are illegal. Since the kernel represents an
723 * incremental timeout as a negative number, we check for a positive
724 * result.
725 */
726 if (Timeout.QuadPart > 0)
727 {
728 if (lpErrno) *lpErrno = WSAEINVAL;
729 return SOCKET_ERROR;
730 }
731 AFD_DbgPrint(MAX_TRACE,("Timeout: Orig %d.%06d kernel %d\n",
732 timeout->tv_sec, timeout->tv_usec,
733 Timeout.u.LowPart));
734 }
735
736 Status = NtCreateEvent(&SockEvent,
737 GENERIC_READ | GENERIC_WRITE,
738 NULL,
739 1,
740 FALSE);
741
742 if( !NT_SUCCESS(Status) )
743 return SOCKET_ERROR;
744
745 /* Allocate */
746 PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
747
748 if (!PollBuffer)
749 {
750 if (lpErrno)
751 *lpErrno = WSAEFAULT;
752 NtClose(SockEvent);
753 return SOCKET_ERROR;
754 }
755
756 PollInfo = (PAFD_POLL_INFO)PollBuffer;
757
758 RtlZeroMemory( PollInfo, PollBufferSize );
759
760 /* Number of handles for AFD to Check */
761 PollInfo->Exclusive = FALSE;
762 PollInfo->Timeout = Timeout;
763
764 if (readfds != NULL) {
765 for (i = 0; i < readfds->fd_count; i++, j++)
766 {
767 PollInfo->Handles[j].Handle = readfds->fd_array[i];
768 PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
769 AFD_EVENT_DISCONNECT |
770 AFD_EVENT_ABORT |
771 AFD_EVENT_ACCEPT;
772 }
773 }
774 if (writefds != NULL)
775 {
776 for (i = 0; i < writefds->fd_count; i++, j++)
777 {
778 PollInfo->Handles[j].Handle = writefds->fd_array[i];
779 PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
780 }
781 }
782 if (exceptfds != NULL)
783 {
784 for (i = 0; i < exceptfds->fd_count; i++, j++)
785 {
786 PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
787 PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
788 }
789 }
790
791 PollInfo->HandleCount = j;
792 PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
793
794 /* Send IOCTL */
795 Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
796 SockEvent,
797 NULL,
798 NULL,
799 &IOSB,
800 IOCTL_AFD_SELECT,
801 PollInfo,
802 PollBufferSize,
803 PollInfo,
804 PollBufferSize);
805
806 AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
807
808 /* Wait for Completition */
809 if (Status == STATUS_PENDING)
810 {
811 WaitForSingleObject(SockEvent, INFINITE);
812 }
813
814 /* Clear the Structures */
815 if( readfds )
816 FD_ZERO(readfds);
817 if( writefds )
818 FD_ZERO(writefds);
819 if( exceptfds )
820 FD_ZERO(exceptfds);
821
822 /* Loop through return structure */
823 HandleCount = PollInfo->HandleCount;
824
825 /* Return in FDSET Format */
826 for (i = 0; i < HandleCount; i++)
827 {
828 HandleCounted = FALSE;
829 for(x = 1; x; x<<=1)
830 {
831 switch (PollInfo->Handles[i].Events & x)
832 {
833 case AFD_EVENT_RECEIVE:
834 case AFD_EVENT_DISCONNECT:
835 case AFD_EVENT_ABORT:
836 case AFD_EVENT_ACCEPT:
837 case AFD_EVENT_CLOSE:
838 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
839 PollInfo->Handles[i].Events,
840 PollInfo->Handles[i].Handle));
841 if (! HandleCounted)
842 {
843 OutCount++;
844 HandleCounted = TRUE;
845 }
846 if( readfds )
847 FD_SET(PollInfo->Handles[i].Handle, readfds);
848 break;
849 case AFD_EVENT_SEND:
850 case AFD_EVENT_CONNECT:
851 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
852 PollInfo->Handles[i].Events,
853 PollInfo->Handles[i].Handle));
854 if (! HandleCounted)
855 {
856 OutCount++;
857 HandleCounted = TRUE;
858 }
859 if( writefds )
860 FD_SET(PollInfo->Handles[i].Handle, writefds);
861 break;
862 case AFD_EVENT_OOB_RECEIVE:
863 case AFD_EVENT_CONNECT_FAIL:
864 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
865 PollInfo->Handles[i].Events,
866 PollInfo->Handles[i].Handle));
867 if (! HandleCounted)
868 {
869 OutCount++;
870 HandleCounted = TRUE;
871 }
872 if( exceptfds )
873 FD_SET(PollInfo->Handles[i].Handle, exceptfds);
874 break;
875 }
876 }
877 }
878
879 HeapFree( GlobalHeap, 0, PollBuffer );
880 NtClose( SockEvent );
881
882 if( lpErrno )
883 {
884 switch( IOSB.Status )
885 {
886 case STATUS_SUCCESS:
887 case STATUS_TIMEOUT:
888 *lpErrno = 0;
889 break;
890 default:
891 *lpErrno = WSAEINVAL;
892 break;
893 }
894 AFD_DbgPrint(MID_TRACE,("*lpErrno = %x\n", *lpErrno));
895 }
896
897 AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
898
899 return OutCount;
900 }
901
902 SOCKET
903 WSPAPI
904 WSPAccept(SOCKET Handle,
905 struct sockaddr *SocketAddress,
906 int *SocketAddressLength,
907 LPCONDITIONPROC lpfnCondition,
908 DWORD_PTR dwCallbackData,
909 LPINT lpErrno)
910 {
911 IO_STATUS_BLOCK IOSB;
912 PAFD_RECEIVED_ACCEPT_DATA ListenReceiveData;
913 AFD_ACCEPT_DATA AcceptData;
914 AFD_DEFER_ACCEPT_DATA DeferData;
915 AFD_PENDING_ACCEPT_DATA PendingAcceptData;
916 PSOCKET_INFORMATION Socket = NULL;
917 NTSTATUS Status;
918 struct fd_set ReadSet;
919 struct timeval Timeout;
920 PVOID PendingData = NULL;
921 ULONG PendingDataLength = 0;
922 PVOID CalleeDataBuffer;
923 WSABUF CallerData, CalleeID, CallerID, CalleeData;
924 PSOCKADDR RemoteAddress = NULL;
925 GROUP GroupID = 0;
926 ULONG CallBack;
927 WSAPROTOCOL_INFOW ProtocolInfo;
928 SOCKET AcceptSocket;
929 UCHAR ReceiveBuffer[0x1A];
930 HANDLE SockEvent;
931
932 Status = NtCreateEvent(&SockEvent,
933 GENERIC_READ | GENERIC_WRITE,
934 NULL,
935 1,
936 FALSE);
937
938 if( !NT_SUCCESS(Status) )
939 {
940 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
941 return INVALID_SOCKET;
942 }
943
944 /* Dynamic Structure...ugh */
945 ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
946
947 /* Get the Socket Structure associate to this Socket*/
948 Socket = GetSocketStructure(Handle);
949
950 /* If this is non-blocking, make sure there's something for us to accept */
951 FD_ZERO(&ReadSet);
952 FD_SET(Socket->Handle, &ReadSet);
953 Timeout.tv_sec=0;
954 Timeout.tv_usec=0;
955
956 WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
957
958 if (ReadSet.fd_array[0] != Socket->Handle)
959 {
960 NtClose(SockEvent);
961 return 0;
962 }
963
964 /* Send IOCTL */
965 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
966 SockEvent,
967 NULL,
968 NULL,
969 &IOSB,
970 IOCTL_AFD_WAIT_FOR_LISTEN,
971 NULL,
972 0,
973 ListenReceiveData,
974 0xA + sizeof(*ListenReceiveData));
975
976 /* Wait for return */
977 if (Status == STATUS_PENDING)
978 {
979 WaitForSingleObject(SockEvent, INFINITE);
980 Status = IOSB.Status;
981 }
982
983 if (!NT_SUCCESS(Status))
984 {
985 NtClose( SockEvent );
986 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
987 return INVALID_SOCKET;
988 }
989
990 if (lpfnCondition != NULL)
991 {
992 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0)
993 {
994 /* Find out how much data is pending */
995 PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
996 PendingAcceptData.ReturnSize = TRUE;
997
998 /* Send IOCTL */
999 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1000 SockEvent,
1001 NULL,
1002 NULL,
1003 &IOSB,
1004 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1005 &PendingAcceptData,
1006 sizeof(PendingAcceptData),
1007 &PendingAcceptData,
1008 sizeof(PendingAcceptData));
1009
1010 /* Wait for return */
1011 if (Status == STATUS_PENDING)
1012 {
1013 WaitForSingleObject(SockEvent, INFINITE);
1014 Status = IOSB.Status;
1015 }
1016
1017 if (!NT_SUCCESS(Status))
1018 {
1019 NtClose( SockEvent );
1020 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1021 return INVALID_SOCKET;
1022 }
1023
1024 /* How much data to allocate */
1025 PendingDataLength = IOSB.Information;
1026
1027 if (PendingDataLength)
1028 {
1029 /* Allocate needed space */
1030 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
1031
1032 /* We want the data now */
1033 PendingAcceptData.ReturnSize = FALSE;
1034
1035 /* Send IOCTL */
1036 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1037 SockEvent,
1038 NULL,
1039 NULL,
1040 &IOSB,
1041 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1042 &PendingAcceptData,
1043 sizeof(PendingAcceptData),
1044 PendingData,
1045 PendingDataLength);
1046
1047 /* Wait for return */
1048 if (Status == STATUS_PENDING)
1049 {
1050 WaitForSingleObject(SockEvent, INFINITE);
1051 Status = IOSB.Status;
1052 }
1053
1054 if (!NT_SUCCESS(Status))
1055 {
1056 NtClose( SockEvent );
1057 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1058 return INVALID_SOCKET;
1059 }
1060 }
1061 }
1062
1063 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1064 {
1065 /* I don't support this yet */
1066 }
1067
1068 /* Build Callee ID */
1069 CalleeID.buf = (PVOID)Socket->LocalAddress;
1070 CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
1071
1072 /* Set up Address in SOCKADDR Format */
1073 RtlCopyMemory (RemoteAddress,
1074 &ListenReceiveData->Address.Address[0].AddressType,
1075 sizeof(*RemoteAddress));
1076
1077 /* Build Caller ID */
1078 CallerID.buf = (PVOID)RemoteAddress;
1079 CallerID.len = sizeof(*RemoteAddress);
1080
1081 /* Build Caller Data */
1082 CallerData.buf = PendingData;
1083 CallerData.len = PendingDataLength;
1084
1085 /* Check if socket supports Conditional Accept */
1086 if (Socket->SharedData.UseDelayedAcceptance != 0)
1087 {
1088 /* Allocate Buffer for Callee Data */
1089 CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
1090 CalleeData.buf = CalleeDataBuffer;
1091 CalleeData.len = 4096;
1092 }
1093 else
1094 {
1095 /* Nothing */
1096 CalleeData.buf = 0;
1097 CalleeData.len = 0;
1098 }
1099
1100 /* Call the Condition Function */
1101 CallBack = (lpfnCondition)(&CallerID,
1102 CallerData.buf == NULL ? NULL : &CallerData,
1103 NULL,
1104 NULL,
1105 &CalleeID,
1106 CalleeData.buf == NULL ? NULL : &CalleeData,
1107 &GroupID,
1108 dwCallbackData);
1109
1110 if (((CallBack == CF_ACCEPT) && GroupID) != 0)
1111 {
1112 /* TBD: Check for Validity */
1113 }
1114
1115 if (CallBack == CF_ACCEPT)
1116 {
1117 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1118 {
1119 /* I don't support this yet */
1120 }
1121 if (CalleeData.buf)
1122 {
1123 // SockSetConnectData Sockets(SocketID), IOCTL_AFD_SET_CONNECT_DATA, CalleeData.Buffer, CalleeData.BuffSize, 0
1124 }
1125 }
1126 else
1127 {
1128 /* Callback rejected. Build Defer Structure */
1129 DeferData.SequenceNumber = ListenReceiveData->SequenceNumber;
1130 DeferData.RejectConnection = (CallBack == CF_REJECT);
1131
1132 /* Send IOCTL */
1133 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1134 SockEvent,
1135 NULL,
1136 NULL,
1137 &IOSB,
1138 IOCTL_AFD_DEFER_ACCEPT,
1139 &DeferData,
1140 sizeof(DeferData),
1141 NULL,
1142 0);
1143
1144 /* Wait for return */
1145 if (Status == STATUS_PENDING)
1146 {
1147 WaitForSingleObject(SockEvent, INFINITE);
1148 Status = IOSB.Status;
1149 }
1150
1151 NtClose( SockEvent );
1152
1153 if (!NT_SUCCESS(Status))
1154 {
1155 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1156 return INVALID_SOCKET;
1157 }
1158
1159 if (CallBack == CF_REJECT )
1160 {
1161 *lpErrno = WSAECONNREFUSED;
1162 return INVALID_SOCKET;
1163 }
1164 else
1165 {
1166 *lpErrno = WSAECONNREFUSED;
1167 return INVALID_SOCKET;
1168 }
1169 }
1170 }
1171
1172 /* Create a new Socket */
1173 ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
1174 ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
1175 ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
1176
1177 AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
1178 Socket->SharedData.SocketType,
1179 Socket->SharedData.Protocol,
1180 &ProtocolInfo,
1181 GroupID,
1182 Socket->SharedData.CreateFlags,
1183 NULL);
1184
1185 /* Set up the Accept Structure */
1186 AcceptData.ListenHandle = AcceptSocket;
1187 AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1188
1189 /* Send IOCTL to Accept */
1190 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1191 SockEvent,
1192 NULL,
1193 NULL,
1194 &IOSB,
1195 IOCTL_AFD_ACCEPT,
1196 &AcceptData,
1197 sizeof(AcceptData),
1198 NULL,
1199 0);
1200
1201 /* Wait for return */
1202 if (Status == STATUS_PENDING)
1203 {
1204 WaitForSingleObject(SockEvent, INFINITE);
1205 Status = IOSB.Status;
1206 }
1207
1208 if (!NT_SUCCESS(Status))
1209 {
1210 NtClose(SockEvent);
1211 WSPCloseSocket( AcceptSocket, lpErrno );
1212 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1213 return INVALID_SOCKET;
1214 }
1215
1216 /* Return Address in SOCKADDR FORMAT */
1217 if( SocketAddress )
1218 {
1219 RtlCopyMemory (SocketAddress,
1220 &ListenReceiveData->Address.Address[0].AddressType,
1221 sizeof(*RemoteAddress));
1222 if( SocketAddressLength )
1223 *SocketAddressLength = ListenReceiveData->Address.Address[0].AddressLength;
1224 }
1225
1226 NtClose( SockEvent );
1227
1228 /* Re-enable Async Event */
1229 SockReenableAsyncSelectEvent(Socket, FD_ACCEPT);
1230
1231 AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
1232
1233 *lpErrno = 0;
1234
1235 /* Return Socket */
1236 return AcceptSocket;
1237 }
1238
1239 int
1240 WSPAPI
1241 WSPConnect(SOCKET Handle,
1242 const struct sockaddr * SocketAddress,
1243 int SocketAddressLength,
1244 LPWSABUF lpCallerData,
1245 LPWSABUF lpCalleeData,
1246 LPQOS lpSQOS,
1247 LPQOS lpGQOS,
1248 LPINT lpErrno)
1249 {
1250 IO_STATUS_BLOCK IOSB;
1251 PAFD_CONNECT_INFO ConnectInfo;
1252 PSOCKET_INFORMATION Socket = NULL;
1253 NTSTATUS Status;
1254 UCHAR ConnectBuffer[0x22];
1255 ULONG ConnectDataLength;
1256 ULONG InConnectDataLength;
1257 INT BindAddressLength;
1258 PSOCKADDR BindAddress;
1259 HANDLE SockEvent;
1260
1261 Status = NtCreateEvent(&SockEvent,
1262 GENERIC_READ | GENERIC_WRITE,
1263 NULL,
1264 1,
1265 FALSE);
1266
1267 if( !NT_SUCCESS(Status) )
1268 return -1;
1269
1270 AFD_DbgPrint(MID_TRACE,("Called\n"));
1271
1272 /* Get the Socket Structure associate to this Socket*/
1273 Socket = GetSocketStructure(Handle);
1274
1275 /* Bind us First */
1276 if (Socket->SharedData.State == SocketOpen)
1277 {
1278 /* Get the Wildcard Address */
1279 BindAddressLength = Socket->HelperData->MaxWSAddressLength;
1280 BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
1281 Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
1282 BindAddress,
1283 &BindAddressLength);
1284 /* Bind it */
1285 WSPBind(Handle, BindAddress, BindAddressLength, NULL);
1286 }
1287
1288 /* Set the Connect Data */
1289 if (lpCallerData != NULL)
1290 {
1291 ConnectDataLength = lpCallerData->len;
1292 Status = NtDeviceIoControlFile((HANDLE)Handle,
1293 SockEvent,
1294 NULL,
1295 NULL,
1296 &IOSB,
1297 IOCTL_AFD_SET_CONNECT_DATA,
1298 lpCallerData->buf,
1299 ConnectDataLength,
1300 NULL,
1301 0);
1302 /* Wait for return */
1303 if (Status == STATUS_PENDING)
1304 {
1305 WaitForSingleObject(SockEvent, INFINITE);
1306 Status = IOSB.Status;
1307 }
1308 }
1309
1310 /* Dynamic Structure...ugh */
1311 ConnectInfo = (PAFD_CONNECT_INFO)ConnectBuffer;
1312
1313 /* Set up Address in TDI Format */
1314 ConnectInfo->RemoteAddress.TAAddressCount = 1;
1315 ConnectInfo->RemoteAddress.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
1316 ConnectInfo->RemoteAddress.Address[0].AddressType = SocketAddress->sa_family;
1317 RtlCopyMemory (ConnectInfo->RemoteAddress.Address[0].Address,
1318 SocketAddress->sa_data,
1319 SocketAddressLength - sizeof(SocketAddress->sa_family));
1320
1321 /*
1322 * Disable FD_WRITE and FD_CONNECT
1323 * The latter fixes a race condition where the FD_CONNECT is re-enabled
1324 * at the end of this function right after the Async Thread disables it.
1325 * This should only happen at the *next* WSPConnect
1326 */
1327 if (Socket->SharedData.AsyncEvents & FD_CONNECT)
1328 {
1329 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
1330 }
1331
1332 /* Tell AFD that we want Connection Data back, have it allocate a buffer */
1333 if (lpCalleeData != NULL)
1334 {
1335 InConnectDataLength = lpCalleeData->len;
1336 Status = NtDeviceIoControlFile((HANDLE)Handle,
1337 SockEvent,
1338 NULL,
1339 NULL,
1340 &IOSB,
1341 IOCTL_AFD_SET_CONNECT_DATA_SIZE,
1342 &InConnectDataLength,
1343 sizeof(InConnectDataLength),
1344 NULL,
1345 0);
1346
1347 /* Wait for return */
1348 if (Status == STATUS_PENDING)
1349 {
1350 WaitForSingleObject(SockEvent, INFINITE);
1351 Status = IOSB.Status;
1352 }
1353 }
1354
1355 /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
1356 ConnectInfo->Root = 0;
1357 ConnectInfo->UseSAN = FALSE;
1358 ConnectInfo->Unknown = 0;
1359
1360 /* FIXME: Handle Async Connect */
1361 if (Socket->SharedData.NonBlocking)
1362 {
1363 AFD_DbgPrint(MIN_TRACE, ("Async Connect UNIMPLEMENTED!\n"));
1364 }
1365
1366 /* Send IOCTL */
1367 Status = NtDeviceIoControlFile((HANDLE)Handle,
1368 SockEvent,
1369 NULL,
1370 NULL,
1371 &IOSB,
1372 IOCTL_AFD_CONNECT,
1373 ConnectInfo,
1374 0x22,
1375 NULL,
1376 0);
1377 /* Wait for return */
1378 if (Status == STATUS_PENDING)
1379 {
1380 WaitForSingleObject(SockEvent, INFINITE);
1381 Status = IOSB.Status;
1382 }
1383
1384 /* Get any pending connect data */
1385 if (lpCalleeData != NULL)
1386 {
1387 Status = NtDeviceIoControlFile((HANDLE)Handle,
1388 SockEvent,
1389 NULL,
1390 NULL,
1391 &IOSB,
1392 IOCTL_AFD_GET_CONNECT_DATA,
1393 NULL,
1394 0,
1395 lpCalleeData->buf,
1396 lpCalleeData->len);
1397 /* Wait for return */
1398 if (Status == STATUS_PENDING)
1399 {
1400 WaitForSingleObject(SockEvent, INFINITE);
1401 Status = IOSB.Status;
1402 }
1403 }
1404
1405 /* Re-enable Async Event */
1406 SockReenableAsyncSelectEvent(Socket, FD_WRITE);
1407
1408 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
1409 SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
1410
1411 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1412
1413 NtClose( SockEvent );
1414
1415 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1416 }
1417 int
1418 WSPAPI
1419 WSPShutdown(SOCKET Handle,
1420 int HowTo,
1421 LPINT lpErrno)
1422
1423 {
1424 IO_STATUS_BLOCK IOSB;
1425 AFD_DISCONNECT_INFO DisconnectInfo;
1426 PSOCKET_INFORMATION Socket = NULL;
1427 NTSTATUS Status;
1428 HANDLE SockEvent;
1429
1430 Status = NtCreateEvent(&SockEvent,
1431 GENERIC_READ | GENERIC_WRITE,
1432 NULL,
1433 1,
1434 FALSE);
1435
1436 if( !NT_SUCCESS(Status) )
1437 return -1;
1438
1439 AFD_DbgPrint(MID_TRACE,("Called\n"));
1440
1441 /* Get the Socket Structure associate to this Socket*/
1442 Socket = GetSocketStructure(Handle);
1443
1444 /* Set AFD Disconnect Type */
1445 switch (HowTo)
1446 {
1447 case SD_RECEIVE:
1448 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
1449 Socket->SharedData.ReceiveShutdown = TRUE;
1450 break;
1451 case SD_SEND:
1452 DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
1453 Socket->SharedData.SendShutdown = TRUE;
1454 break;
1455 case SD_BOTH:
1456 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
1457 Socket->SharedData.ReceiveShutdown = TRUE;
1458 Socket->SharedData.SendShutdown = TRUE;
1459 break;
1460 }
1461
1462 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
1463
1464 /* Send IOCTL */
1465 Status = NtDeviceIoControlFile((HANDLE)Handle,
1466 SockEvent,
1467 NULL,
1468 NULL,
1469 &IOSB,
1470 IOCTL_AFD_DISCONNECT,
1471 &DisconnectInfo,
1472 sizeof(DisconnectInfo),
1473 NULL,
1474 0);
1475
1476 /* Wait for return */
1477 if (Status == STATUS_PENDING)
1478 {
1479 WaitForSingleObject(SockEvent, INFINITE);
1480 Status = IOSB.Status;
1481 }
1482
1483 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1484
1485 NtClose( SockEvent );
1486
1487 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1488 }
1489
1490
1491 INT
1492 WSPAPI
1493 WSPGetSockName(IN SOCKET Handle,
1494 OUT LPSOCKADDR Name,
1495 IN OUT LPINT NameLength,
1496 OUT LPINT lpErrno)
1497 {
1498 IO_STATUS_BLOCK IOSB;
1499 ULONG TdiAddressSize;
1500 PTDI_ADDRESS_INFO TdiAddress;
1501 PTRANSPORT_ADDRESS SocketAddress;
1502 PSOCKET_INFORMATION Socket = NULL;
1503 NTSTATUS Status;
1504 HANDLE SockEvent;
1505
1506 Status = NtCreateEvent(&SockEvent,
1507 GENERIC_READ | GENERIC_WRITE,
1508 NULL,
1509 1,
1510 FALSE);
1511
1512 if( !NT_SUCCESS(Status) )
1513 return SOCKET_ERROR;
1514
1515 /* Get the Socket Structure associate to this Socket*/
1516 Socket = GetSocketStructure(Handle);
1517
1518 /* Allocate a buffer for the address */
1519 TdiAddressSize =
1520 sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
1521 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1522
1523 if ( TdiAddress == NULL )
1524 {
1525 NtClose( SockEvent );
1526 *lpErrno = WSAENOBUFS;
1527 return SOCKET_ERROR;
1528 }
1529
1530 SocketAddress = &TdiAddress->Address;
1531
1532 /* Send IOCTL */
1533 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1534 SockEvent,
1535 NULL,
1536 NULL,
1537 &IOSB,
1538 IOCTL_AFD_GET_SOCK_NAME,
1539 NULL,
1540 0,
1541 TdiAddress,
1542 TdiAddressSize);
1543
1544 /* Wait for return */
1545 if (Status == STATUS_PENDING)
1546 {
1547 WaitForSingleObject(SockEvent, INFINITE);
1548 Status = IOSB.Status;
1549 }
1550
1551 NtClose( SockEvent );
1552
1553 if (NT_SUCCESS(Status))
1554 {
1555 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1556 {
1557 Name->sa_family = SocketAddress->Address[0].AddressType;
1558 RtlCopyMemory (Name->sa_data,
1559 SocketAddress->Address[0].Address,
1560 SocketAddress->Address[0].AddressLength);
1561 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1562 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
1563 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1564 ((struct sockaddr_in *)Name)->sin_port));
1565 HeapFree(GlobalHeap, 0, TdiAddress);
1566 return 0;
1567 }
1568 else
1569 {
1570 HeapFree(GlobalHeap, 0, TdiAddress);
1571 *lpErrno = WSAEFAULT;
1572 return SOCKET_ERROR;
1573 }
1574 }
1575
1576 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1577 }
1578
1579
1580 INT
1581 WSPAPI
1582 WSPGetPeerName(IN SOCKET s,
1583 OUT LPSOCKADDR Name,
1584 IN OUT LPINT NameLength,
1585 OUT LPINT lpErrno)
1586 {
1587 IO_STATUS_BLOCK IOSB;
1588 ULONG TdiAddressSize;
1589 PTRANSPORT_ADDRESS SocketAddress;
1590 PSOCKET_INFORMATION Socket = NULL;
1591 NTSTATUS Status;
1592 HANDLE SockEvent;
1593
1594 Status = NtCreateEvent(&SockEvent,
1595 GENERIC_READ | GENERIC_WRITE,
1596 NULL,
1597 1,
1598 FALSE);
1599
1600 if( !NT_SUCCESS(Status) )
1601 return SOCKET_ERROR;
1602
1603 /* Get the Socket Structure associate to this Socket*/
1604 Socket = GetSocketStructure(s);
1605
1606 /* Allocate a buffer for the address */
1607 TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
1608 SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1609
1610 if ( SocketAddress == NULL )
1611 {
1612 NtClose( SockEvent );
1613 *lpErrno = WSAENOBUFS;
1614 return SOCKET_ERROR;
1615 }
1616
1617 /* Send IOCTL */
1618 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1619 SockEvent,
1620 NULL,
1621 NULL,
1622 &IOSB,
1623 IOCTL_AFD_GET_PEER_NAME,
1624 NULL,
1625 0,
1626 SocketAddress,
1627 TdiAddressSize);
1628
1629 /* Wait for return */
1630 if (Status == STATUS_PENDING)
1631 {
1632 WaitForSingleObject(SockEvent, INFINITE);
1633 Status = IOSB.Status;
1634 }
1635
1636 NtClose( SockEvent );
1637
1638 if (NT_SUCCESS(Status))
1639 {
1640 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1641 {
1642 Name->sa_family = SocketAddress->Address[0].AddressType;
1643 RtlCopyMemory (Name->sa_data,
1644 SocketAddress->Address[0].Address,
1645 SocketAddress->Address[0].AddressLength);
1646 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1647 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
1648 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1649 ((struct sockaddr_in *)Name)->sin_port));
1650 HeapFree(GlobalHeap, 0, SocketAddress);
1651 return 0;
1652 }
1653 else
1654 {
1655 HeapFree(GlobalHeap, 0, SocketAddress);
1656 *lpErrno = WSAEFAULT;
1657 return SOCKET_ERROR;
1658 }
1659 }
1660
1661 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1662 }
1663
1664 INT
1665 WSPAPI
1666 WSPIoctl(IN SOCKET Handle,
1667 IN DWORD dwIoControlCode,
1668 IN LPVOID lpvInBuffer,
1669 IN DWORD cbInBuffer,
1670 OUT LPVOID lpvOutBuffer,
1671 IN DWORD cbOutBuffer,
1672 OUT LPDWORD lpcbBytesReturned,
1673 IN LPWSAOVERLAPPED lpOverlapped,
1674 IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
1675 IN LPWSATHREADID lpThreadId,
1676 OUT LPINT lpErrno)
1677 {
1678 PSOCKET_INFORMATION Socket = NULL;
1679
1680 /* Get the Socket Structure associate to this Socket*/
1681 Socket = GetSocketStructure(Handle);
1682
1683 switch( dwIoControlCode )
1684 {
1685 case FIONBIO:
1686 if( cbInBuffer < sizeof(INT) )
1687 return SOCKET_ERROR;
1688 Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
1689 AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n", Handle, Socket->SharedData.NonBlocking));
1690 return 0;
1691 case FIONREAD:
1692 return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
1693 default:
1694 *lpErrno = WSAEINVAL;
1695 return SOCKET_ERROR;
1696 }
1697 }
1698
1699
1700 INT
1701 WSPAPI
1702 WSPGetSockOpt(IN SOCKET Handle,
1703 IN INT Level,
1704 IN INT OptionName,
1705 OUT CHAR FAR* OptionValue,
1706 IN OUT LPINT OptionLength,
1707 OUT LPINT lpErrno)
1708 {
1709 PSOCKET_INFORMATION Socket = NULL;
1710 PVOID Buffer;
1711 INT BufferSize;
1712 BOOLEAN BoolBuffer;
1713
1714 /* Get the Socket Structure associate to this Socket*/
1715 Socket = GetSocketStructure(Handle);
1716 if (Socket == NULL)
1717 {
1718 *lpErrno = WSAENOTSOCK;
1719 return SOCKET_ERROR;
1720 }
1721
1722 AFD_DbgPrint(MID_TRACE, ("Called\n"));
1723
1724 switch (Level)
1725 {
1726 case SOL_SOCKET:
1727 switch (OptionName)
1728 {
1729 case SO_TYPE:
1730 Buffer = &Socket->SharedData.SocketType;
1731 BufferSize = sizeof(INT);
1732 break;
1733
1734 case SO_RCVBUF:
1735 Buffer = &Socket->SharedData.SizeOfRecvBuffer;
1736 BufferSize = sizeof(INT);
1737 break;
1738
1739 case SO_SNDBUF:
1740 Buffer = &Socket->SharedData.SizeOfSendBuffer;
1741 BufferSize = sizeof(INT);
1742 break;
1743
1744 case SO_ACCEPTCONN:
1745 BoolBuffer = Socket->SharedData.Listening;
1746 Buffer = &BoolBuffer;
1747 BufferSize = sizeof(BOOLEAN);
1748 break;
1749
1750 case SO_BROADCAST:
1751 BoolBuffer = Socket->SharedData.Broadcast;
1752 Buffer = &BoolBuffer;
1753 BufferSize = sizeof(BOOLEAN);
1754 break;
1755
1756 case SO_DEBUG:
1757 BoolBuffer = Socket->SharedData.Debug;
1758 Buffer = &BoolBuffer;
1759 BufferSize = sizeof(BOOLEAN);
1760 break;
1761
1762 /* case SO_CONDITIONAL_ACCEPT: */
1763 case SO_DONTLINGER:
1764 case SO_DONTROUTE:
1765 case SO_ERROR:
1766 case SO_GROUP_ID:
1767 case SO_GROUP_PRIORITY:
1768 case SO_KEEPALIVE:
1769 case SO_LINGER:
1770 case SO_MAX_MSG_SIZE:
1771 case SO_OOBINLINE:
1772 case SO_PROTOCOL_INFO:
1773 case SO_REUSEADDR:
1774 AFD_DbgPrint(MID_TRACE, ("Unimplemented option (%x)\n",
1775 OptionName));
1776
1777 default:
1778 *lpErrno = WSAEINVAL;
1779 return SOCKET_ERROR;
1780 }
1781
1782 if (*OptionLength < BufferSize)
1783 {
1784 *lpErrno = WSAEFAULT;
1785 *OptionLength = BufferSize;
1786 return SOCKET_ERROR;
1787 }
1788 RtlCopyMemory(OptionValue, Buffer, BufferSize);
1789
1790 return 0;
1791
1792 case IPPROTO_TCP: /* FIXME */
1793 default:
1794 *lpErrno = WSAEINVAL;
1795 return SOCKET_ERROR;
1796 }
1797 }
1798
1799
1800 /*
1801 * FUNCTION: Initialize service provider for a client
1802 * ARGUMENTS:
1803 * wVersionRequested = Highest WinSock SPI version that the caller can use
1804 * lpWSPData = Address of WSPDATA structure to initialize
1805 * lpProtocolInfo = Pointer to structure that defines the desired protocol
1806 * UpcallTable = Pointer to upcall table of the WinSock DLL
1807 * lpProcTable = Address of procedure table to initialize
1808 * RETURNS:
1809 * Status of operation
1810 */
1811 INT
1812 WSPAPI
1813 WSPStartup(IN WORD wVersionRequested,
1814 OUT LPWSPDATA lpWSPData,
1815 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
1816 IN WSPUPCALLTABLE UpcallTable,
1817 OUT LPWSPPROC_TABLE lpProcTable)
1818
1819 {
1820 NTSTATUS Status;
1821
1822 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
1823 Status = NO_ERROR;
1824 Upcalls = UpcallTable;
1825
1826 if (Status == NO_ERROR)
1827 {
1828 lpProcTable->lpWSPAccept = WSPAccept;
1829 lpProcTable->lpWSPAddressToString = WSPAddressToString;
1830 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
1831 lpProcTable->lpWSPBind = WSPBind;
1832 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
1833 lpProcTable->lpWSPCleanup = WSPCleanup;
1834 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
1835 lpProcTable->lpWSPConnect = WSPConnect;
1836 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
1837 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
1838 lpProcTable->lpWSPEventSelect = WSPEventSelect;
1839 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
1840 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
1841 lpProcTable->lpWSPGetSockName = WSPGetSockName;
1842 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
1843 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
1844 lpProcTable->lpWSPIoctl = WSPIoctl;
1845 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
1846 lpProcTable->lpWSPListen = WSPListen;
1847 lpProcTable->lpWSPRecv = WSPRecv;
1848 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
1849 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
1850 lpProcTable->lpWSPSelect = WSPSelect;
1851 lpProcTable->lpWSPSend = WSPSend;
1852 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
1853 lpProcTable->lpWSPSendTo = WSPSendTo;
1854 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
1855 lpProcTable->lpWSPShutdown = WSPShutdown;
1856 lpProcTable->lpWSPSocket = WSPSocket;
1857 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
1858 lpWSPData->wVersion = MAKEWORD(2, 2);
1859 lpWSPData->wHighVersion = MAKEWORD(2, 2);
1860 }
1861
1862 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
1863
1864 return Status;
1865 }
1866
1867
1868 /*
1869 * FUNCTION: Cleans up service provider for a client
1870 * ARGUMENTS:
1871 * lpErrno = Address of buffer for error information
1872 * RETURNS:
1873 * 0 if successful, or SOCKET_ERROR if not
1874 */
1875 INT
1876 WSPAPI
1877 WSPCleanup(OUT LPINT lpErrno)
1878
1879 {
1880 AFD_DbgPrint(MAX_TRACE, ("\n"));
1881 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
1882 *lpErrno = NO_ERROR;
1883
1884 return 0;
1885 }
1886
1887
1888
1889 int
1890 GetSocketInformation(PSOCKET_INFORMATION Socket,
1891 ULONG AfdInformationClass,
1892 PULONG Ulong OPTIONAL,
1893 PLARGE_INTEGER LargeInteger OPTIONAL)
1894 {
1895 IO_STATUS_BLOCK IOSB;
1896 AFD_INFO InfoData;
1897 NTSTATUS Status;
1898 HANDLE SockEvent;
1899
1900 Status = NtCreateEvent(&SockEvent,
1901 GENERIC_READ | GENERIC_WRITE,
1902 NULL,
1903 1,
1904 FALSE);
1905
1906 if( !NT_SUCCESS(Status) )
1907 return -1;
1908
1909 /* Set Info Class */
1910 InfoData.InformationClass = AfdInformationClass;
1911
1912 /* Send IOCTL */
1913 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1914 SockEvent,
1915 NULL,
1916 NULL,
1917 &IOSB,
1918 IOCTL_AFD_GET_INFO,
1919 &InfoData,
1920 sizeof(InfoData),
1921 &InfoData,
1922 sizeof(InfoData));
1923
1924 /* Wait for return */
1925 if (Status == STATUS_PENDING)
1926 {
1927 WaitForSingleObject(SockEvent, INFINITE);
1928 }
1929
1930 /* Return Information */
1931 *Ulong = InfoData.Information.Ulong;
1932 if (LargeInteger != NULL)
1933 {
1934 *LargeInteger = InfoData.Information.LargeInteger;
1935 }
1936
1937 NtClose( SockEvent );
1938
1939 return 0;
1940
1941 }
1942
1943
1944 int
1945 SetSocketInformation(PSOCKET_INFORMATION Socket,
1946 ULONG AfdInformationClass,
1947 PULONG Ulong OPTIONAL,
1948 PLARGE_INTEGER LargeInteger OPTIONAL)
1949 {
1950 IO_STATUS_BLOCK IOSB;
1951 AFD_INFO InfoData;
1952 NTSTATUS Status;
1953 HANDLE SockEvent;
1954
1955 Status = NtCreateEvent(&SockEvent,
1956 GENERIC_READ | GENERIC_WRITE,
1957 NULL,
1958 1,
1959 FALSE);
1960
1961 if( !NT_SUCCESS(Status) )
1962 return -1;
1963
1964 /* Set Info Class */
1965 InfoData.InformationClass = AfdInformationClass;
1966
1967 /* Set Information */
1968 InfoData.Information.Ulong = *Ulong;
1969 if (LargeInteger != NULL)
1970 {
1971 InfoData.Information.LargeInteger = *LargeInteger;
1972 }
1973
1974 AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
1975 AfdInformationClass, *Ulong));
1976
1977 /* Send IOCTL */
1978 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1979 SockEvent,
1980 NULL,
1981 NULL,
1982 &IOSB,
1983 IOCTL_AFD_SET_INFO,
1984 &InfoData,
1985 sizeof(InfoData),
1986 NULL,
1987 0);
1988
1989 /* Wait for return */
1990 if (Status == STATUS_PENDING)
1991 {
1992 WaitForSingleObject(SockEvent, INFINITE);
1993 }
1994
1995 NtClose( SockEvent );
1996
1997 return 0;
1998
1999 }
2000
2001 PSOCKET_INFORMATION
2002 GetSocketStructure(SOCKET Handle)
2003 {
2004 ULONG i;
2005
2006 for (i=0; i<SocketCount; i++)
2007 {
2008 if (Sockets[i]->Handle == Handle)
2009 {
2010 return Sockets[i];
2011 }
2012 }
2013 return 0;
2014 }
2015
2016 int CreateContext(PSOCKET_INFORMATION Socket)
2017 {
2018 IO_STATUS_BLOCK IOSB;
2019 SOCKET_CONTEXT ContextData;
2020 NTSTATUS Status;
2021 HANDLE SockEvent;
2022
2023 Status = NtCreateEvent(&SockEvent,
2024 GENERIC_READ | GENERIC_WRITE,
2025 NULL,
2026 1,
2027 FALSE);
2028
2029 if( !NT_SUCCESS(Status) )
2030 return -1;
2031
2032 /* Create Context */
2033 ContextData.SharedData = Socket->SharedData;
2034 ContextData.SizeOfHelperData = 0;
2035 RtlCopyMemory (&ContextData.LocalAddress,
2036 Socket->LocalAddress,
2037 Socket->SharedData.SizeOfLocalAddress);
2038 RtlCopyMemory (&ContextData.RemoteAddress,
2039 Socket->RemoteAddress,
2040 Socket->SharedData.SizeOfRemoteAddress);
2041
2042 /* Send IOCTL */
2043 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
2044 SockEvent,
2045 NULL,
2046 NULL,
2047 &IOSB,
2048 IOCTL_AFD_SET_CONTEXT,
2049 &ContextData,
2050 sizeof(ContextData),
2051 NULL,
2052 0);
2053
2054 /* Wait for Completition */
2055 if (Status == STATUS_PENDING)
2056 {
2057 WaitForSingleObject(SockEvent, INFINITE);
2058 }
2059
2060 NtClose( SockEvent );
2061
2062 return 0;
2063 }
2064
2065 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
2066 {
2067 HANDLE hAsyncThread;
2068 DWORD AsyncThreadId;
2069 HANDLE AsyncEvent;
2070 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2071 NTSTATUS Status;
2072
2073 /* Check if the Thread Already Exists */
2074 if (SockAsyncThreadRefCount)
2075 {
2076 return TRUE;
2077 }
2078
2079 /* Create the Completion Port */
2080 if (!SockAsyncCompletionPort)
2081 {
2082 Status = NtCreateIoCompletion(&SockAsyncCompletionPort,
2083 IO_COMPLETION_ALL_ACCESS,
2084 NULL,
2085 2); // Allow 2 threads only
2086
2087 /* Protect Handle */
2088 HandleFlags.ProtectFromClose = TRUE;
2089 HandleFlags.Inherit = FALSE;
2090 Status = NtSetInformationObject(SockAsyncCompletionPort,
2091 ObjectHandleFlagInformation,
2092 &HandleFlags,
2093 sizeof(HandleFlags));
2094 }
2095
2096 /* Create the Async Event */
2097 Status = NtCreateEvent(&AsyncEvent,
2098 EVENT_ALL_ACCESS,
2099 NULL,
2100 NotificationEvent,
2101 FALSE);
2102
2103 /* Create the Async Thread */
2104 hAsyncThread = CreateThread(NULL,
2105 0,
2106 (LPTHREAD_START_ROUTINE)SockAsyncThread,
2107 NULL,
2108 0,
2109 &AsyncThreadId);
2110
2111 /* Close the Handle */
2112 NtClose(hAsyncThread);
2113
2114 /* Increase the Reference Count */
2115 SockAsyncThreadRefCount++;
2116 return TRUE;
2117 }
2118
2119 int SockAsyncThread(PVOID ThreadParam)
2120 {
2121 PVOID AsyncContext;
2122 PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
2123 IO_STATUS_BLOCK IOSB;
2124 NTSTATUS Status;
2125
2126 /* Make the Thread Higher Priority */
2127 SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
2128
2129 /* Do a KQUEUE/WorkItem Style Loop, thanks to IoCompletion Ports */
2130 do
2131 {
2132 Status = NtRemoveIoCompletion (SockAsyncCompletionPort,
2133 (PVOID*)&AsyncCompletionRoutine,
2134 &AsyncContext,
2135 &IOSB,
2136 NULL);
2137 /* Call the Async Function */
2138 if (NT_SUCCESS(Status))
2139 {
2140 (*AsyncCompletionRoutine)(AsyncContext, &IOSB);
2141 }
2142 else
2143 {
2144 /* It Failed, sleep for a second */
2145 Sleep(1000);
2146 }
2147 } while ((Status != STATUS_TIMEOUT));
2148
2149 /* The Thread has Ended */
2150 return 0;
2151 }
2152
2153 BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
2154 {
2155 UNICODE_STRING AfdHelper;
2156 OBJECT_ATTRIBUTES ObjectAttributes;
2157 IO_STATUS_BLOCK IoSb;
2158 NTSTATUS Status;
2159 FILE_COMPLETION_INFORMATION CompletionInfo;
2160 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2161
2162 /* First, make sure we're not already intialized */
2163 if (SockAsyncHelperAfdHandle)
2164 {
2165 return TRUE;
2166 }
2167
2168 /* Set up Handle Name and Object */
2169 RtlInitUnicodeString(&AfdHelper, L"\\Device\\Afd\\AsyncSelectHlp" );
2170 InitializeObjectAttributes(&ObjectAttributes,
2171 &AfdHelper,
2172 OBJ_INHERIT | OBJ_CASE_INSENSITIVE,
2173 NULL,
2174 NULL);
2175
2176 /* Open the Handle to AFD */
2177 Status = NtCreateFile(&SockAsyncHelperAfdHandle,
2178 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
2179 &ObjectAttributes,
2180 &IoSb,
2181 NULL,
2182 0,
2183 FILE_SHARE_READ | FILE_SHARE_WRITE,
2184 FILE_OPEN_IF,
2185 0,
2186 NULL,
2187 0);
2188
2189 /*
2190 * Now Set up the Completion Port Information
2191 * This means that whenever a Poll is finished, the routine will be executed
2192 */
2193 CompletionInfo.Port = SockAsyncCompletionPort;
2194 CompletionInfo.Key = SockAsyncSelectCompletionRoutine;
2195 Status = NtSetInformationFile(SockAsyncHelperAfdHandle,
2196 &IoSb,
2197 &CompletionInfo,
2198 sizeof(CompletionInfo),
2199 FileCompletionInformation);
2200
2201
2202 /* Protect the Handle */
2203 HandleFlags.ProtectFromClose = TRUE;
2204 HandleFlags.Inherit = FALSE;
2205 Status = NtSetInformationObject(SockAsyncCompletionPort,
2206 ObjectHandleFlagInformation,
2207 &HandleFlags,
2208 sizeof(HandleFlags));
2209
2210
2211 /* Set this variable to true so that Send/Recv/Accept will know wether to renable disabled events */
2212 SockAsyncSelectCalled = TRUE;
2213 return TRUE;
2214 }
2215
2216 VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2217 {
2218
2219 PASYNC_DATA AsyncData = Context;
2220 PSOCKET_INFORMATION Socket;
2221 ULONG x;
2222
2223 /* Get the Socket */
2224 Socket = AsyncData->ParentSocket;
2225
2226 /* Check if the Sequence Number Changed behind our back */
2227 if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber )
2228 {
2229 return;
2230 }
2231
2232 /* Check we were manually called b/c of a failure */
2233 if (!NT_SUCCESS(IoStatusBlock->Status))
2234 {
2235 /* FIXME: Perform Upcall */
2236 return;
2237 }
2238
2239 for (x = 1; x; x<<=1)
2240 {
2241 switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x)
2242 {
2243 case AFD_EVENT_RECEIVE:
2244 if (0 != (Socket->SharedData.AsyncEvents & FD_READ) &&
2245 0 == (Socket->SharedData.AsyncDisabledEvents & FD_READ))
2246 {
2247 /* Make the Notifcation */
2248 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2249 Socket->SharedData.wMsg,
2250 Socket->Handle,
2251 WSAMAKESELECTREPLY(FD_READ, 0));
2252 /* Disable this event until the next read(); */
2253 Socket->SharedData.AsyncDisabledEvents |= FD_READ;
2254 }
2255 break;
2256
2257 case AFD_EVENT_OOB_RECEIVE:
2258 if (0 != (Socket->SharedData.AsyncEvents & FD_OOB) &&
2259 0 == (Socket->SharedData.AsyncDisabledEvents & FD_OOB))
2260 {
2261 /* Make the Notifcation */
2262 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2263 Socket->SharedData.wMsg,
2264 Socket->Handle,
2265 WSAMAKESELECTREPLY(FD_OOB, 0));
2266 /* Disable this event until the next read(); */
2267 Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
2268 }
2269 break;
2270
2271 case AFD_EVENT_SEND:
2272 if (0 != (Socket->SharedData.AsyncEvents & FD_WRITE) &&
2273 0 == (Socket->SharedData.AsyncDisabledEvents & FD_WRITE))
2274 {
2275 /* Make the Notifcation */
2276 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2277 Socket->SharedData.wMsg,
2278 Socket->Handle,
2279 WSAMAKESELECTREPLY(FD_WRITE, 0));
2280 /* Disable this event until the next write(); */
2281 Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
2282 }
2283 break;
2284
2285 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2286 case AFD_EVENT_CONNECT:
2287 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
2288 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
2289 {
2290 /* Make the Notifcation */
2291 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2292 Socket->SharedData.wMsg,
2293 Socket->Handle,
2294 WSAMAKESELECTREPLY(FD_CONNECT, 0));
2295 /* Disable this event forever; */
2296 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT;
2297 }
2298 break;
2299
2300 case AFD_EVENT_ACCEPT:
2301 if (0 != (Socket->SharedData.AsyncEvents & FD_ACCEPT) &&
2302 0 == (Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT))
2303 {
2304 /* Make the Notifcation */
2305 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2306 Socket->SharedData.wMsg,
2307 Socket->Handle,
2308 WSAMAKESELECTREPLY(FD_ACCEPT, 0));
2309 /* Disable this event until the next accept(); */
2310 Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
2311 }
2312 break;
2313
2314 case AFD_EVENT_DISCONNECT:
2315 case AFD_EVENT_ABORT:
2316 case AFD_EVENT_CLOSE:
2317 if (0 != (Socket->SharedData.AsyncEvents & FD_CLOSE) &&
2318 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CLOSE))
2319 {
2320 /* Make the Notifcation */
2321 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2322 Socket->SharedData.wMsg,
2323 Socket->Handle,
2324 WSAMAKESELECTREPLY(FD_CLOSE, 0));
2325 /* Disable this event forever; */
2326 Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
2327 }
2328 break;
2329 /* FIXME: Support QOS */
2330 }
2331 }
2332
2333 /* Check if there are any events left for us to check */
2334 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2335 {
2336 return;
2337 }
2338
2339 /* Keep Polling */
2340 SockProcessAsyncSelect(Socket, AsyncData);
2341 return;
2342 }
2343
2344 VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
2345 {
2346
2347 ULONG lNetworkEvents;
2348 NTSTATUS Status;
2349
2350 /* Set up the Async Data Event Info */
2351 AsyncData->AsyncSelectInfo.Timeout.HighPart = 0x7FFFFFFF;
2352 AsyncData->AsyncSelectInfo.Timeout.LowPart = 0xFFFFFFFF;
2353 AsyncData->AsyncSelectInfo.HandleCount = 1;
2354 AsyncData->AsyncSelectInfo.Exclusive = TRUE;
2355 AsyncData->AsyncSelectInfo.Handles[0].Handle = Socket->Handle;
2356 AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
2357
2358 /* Remove unwanted events */
2359 lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
2360
2361 /* Set Events to wait for */
2362 if (lNetworkEvents & FD_READ)
2363 {
2364 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_RECEIVE;
2365 }
2366
2367 if (lNetworkEvents & FD_WRITE)
2368 {
2369 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_SEND;
2370 }
2371
2372 if (lNetworkEvents & FD_OOB)
2373 {
2374 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_OOB_RECEIVE;
2375 }
2376
2377 if (lNetworkEvents & FD_ACCEPT)
2378 {
2379 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_ACCEPT;
2380 }
2381
2382 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2383 if (lNetworkEvents & FD_CONNECT)
2384 {
2385 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_CONNECT | AFD_EVENT_CONNECT_FAIL;
2386 }
2387
2388 if (lNetworkEvents & FD_CLOSE)
2389 {
2390 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
2391 }
2392
2393 if (lNetworkEvents & FD_QOS)
2394 {
2395 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_QOS;
2396 }
2397
2398 if (lNetworkEvents & FD_GROUP_QOS)
2399 {
2400 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_GROUP_QOS;
2401 }
2402
2403 /* Send IOCTL */
2404 Status = NtDeviceIoControlFile (SockAsyncHelperAfdHandle,
2405 NULL,
2406 NULL,
2407 AsyncData,
2408 &AsyncData->IoStatusBlock,
2409 IOCTL_AFD_SELECT,
2410 &AsyncData->AsyncSelectInfo,
2411 sizeof(AsyncData->AsyncSelectInfo),
2412 &AsyncData->AsyncSelectInfo,
2413 sizeof(AsyncData->AsyncSelectInfo));
2414
2415 /* I/O Manager Won't call the completion routine, let's do it manually */
2416 if (NT_SUCCESS(Status))
2417 {
2418 return;
2419 }
2420 else
2421 {
2422 AsyncData->IoStatusBlock.Status = Status;
2423 SockAsyncSelectCompletionRoutine(AsyncData, &AsyncData->IoStatusBlock);
2424 }
2425 }
2426
2427 VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2428 {
2429 PASYNC_DATA AsyncData = Context;
2430 BOOL FreeContext = TRUE;
2431 PSOCKET_INFORMATION Socket;
2432
2433 /* Get the Socket */
2434 Socket = AsyncData->ParentSocket;
2435
2436 /* If someone closed it, stop the function */
2437 if (Socket->SharedData.State != SocketClosed)
2438 {
2439 /* Check if the Sequence Number changed by now, in which case quit */
2440 if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber)
2441 {
2442 /* Do the actuall select, if needed */
2443 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)))
2444 {
2445 SockProcessAsyncSelect(Socket, AsyncData);
2446 FreeContext = FALSE;
2447 }
2448 }
2449 }
2450
2451 /* Free the Context */
2452 if (FreeContext)
2453 {
2454 HeapFree(GetProcessHeap(), 0, AsyncData);
2455 }
2456
2457 return;
2458 }
2459
2460 VOID
2461 SockReenableAsyncSelectEvent (IN PSOCKET_INFORMATION Socket,
2462 IN ULONG Event)
2463 {
2464 PASYNC_DATA AsyncData;
2465
2466 /* Make sure the event is actually disabled */
2467 if (!(Socket->SharedData.AsyncDisabledEvents & Event))
2468 {
2469 return;
2470 }
2471
2472 /* Re-enable it */
2473 Socket->SharedData.AsyncDisabledEvents &= ~Event;
2474
2475 /* Return if no more events are being polled */
2476 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2477 {
2478 return;
2479 }
2480
2481 /* Wait on new events */
2482 AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
2483
2484 /* Create the Asynch Thread if Needed */
2485 SockCreateOrReferenceAsyncThread();
2486
2487 /* Increase the sequence number to stop anything else */
2488 Socket->SharedData.SequenceNumber++;
2489
2490 /* Set up the Async Data */
2491 AsyncData->ParentSocket = Socket;
2492 AsyncData->SequenceNumber = Socket->SharedData.SequenceNumber;
2493
2494 /* Begin Async Select by using I/O Completion */
2495 NtSetIoCompletion(SockAsyncCompletionPort,
2496 (PVOID)&SockProcessQueuedAsyncSelect,
2497 AsyncData,
2498 0,
2499 0);
2500
2501 /* All done */
2502 return;
2503 }
2504
2505 BOOL
2506 WINAPI
2507 DllMain(HANDLE hInstDll,
2508 ULONG dwReason,
2509 PVOID Reserved)
2510 {
2511
2512 switch (dwReason)
2513 {
2514 case DLL_PROCESS_ATTACH:
2515
2516 AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
2517
2518 /* Don't need thread attach notifications
2519 so disable them to improve performance */
2520 DisableThreadLibraryCalls(hInstDll);
2521
2522 /* List of DLL Helpers */
2523 InitializeListHead(&SockHelpersListHead);
2524
2525 /* Heap to use when allocating */
2526 GlobalHeap = GetProcessHeap();
2527
2528 /* Allocate Heap for 1024 Sockets, can be expanded later */
2529 Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
2530
2531 AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
2532
2533 break;
2534
2535 case DLL_THREAD_ATTACH:
2536 break;
2537
2538 case DLL_THREAD_DETACH:
2539 break;
2540
2541 case DLL_PROCESS_DETACH:
2542 break;
2543 }
2544
2545 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
2546
2547 return TRUE;
2548 }
2549
2550 /* EOF */
2551
2552