Sync to trunk head (r40091)
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
4 * FILE: misc/dllmain.c
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7 * Alex Ionescu (alex@relsoft.net)
8 * REVISIONS:
9 * CSH 01/09-2000 Created
10 * Alex 16/07/2004 - Complete Rewrite
11 */
12
13 #include <msafd.h>
14
15 #include <debug.h>
16
/*
 * Module-wide state for the msafd provider. Which other msafd source files
 * also reference these globals is not visible here (presumably declared
 * extern in msafd.h -- TODO confirm).
 */
/* Trace verbosity; presumably consulted by AFD_DbgPrint from debug.h -- TODO confirm */
17 #ifdef DBG
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
19 DWORD DebugTraceLevel = 0;
20 #endif /* DBG */
21
/* Heap passed to every HeapAlloc/HeapFree call in this file */
22 HANDLE GlobalHeap;
/* Winsock upcall table; WSPSocket notifies Winsock through lpWPUModifyIFSHandle */
23 WSPUPCALLTABLE Upcalls;
/* Not referenced in the visible part of this file; presumably resolved at startup -- TODO confirm */
24 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
/* Number of entries WSPSocket has appended to Sockets */
25 ULONG SocketCount = 0;
/* Per-process socket table indexed by SocketCount; its allocation and capacity are not visible here */
26 PSOCKET_INFORMATION *Sockets = NULL;
/* NOTE(review): presumably the list of loaded helper (WSH) DLLs -- confirm */
27 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
/* Async-select helper state; not used in the visible part of this file */
28 ULONG SockAsyncThreadRefCount;
29 HANDLE SockAsyncHelperAfdHandle;
30 HANDLE SockAsyncCompletionPort;
31 BOOLEAN SockAsyncSelectCalled;
32
33
34
35 /*
36 * FUNCTION: Creates a new socket
37 * ARGUMENTS:
38 * af = Address family
39 * type = Socket type
40 * protocol = Protocol type
41 * lpProtocolInfo = Pointer to protocol information
42 * g = Reserved
43 * dwFlags = Socket flags
44 * lpErrno = Address of buffer for error information
45 * RETURNS:
46 * Created socket, or INVALID_SOCKET if it could not be created
47 */
48 SOCKET
49 WSPAPI
50 WSPSocket(int AddressFamily,
51 int SocketType,
52 int Protocol,
53 LPWSAPROTOCOL_INFOW lpProtocolInfo,
54 GROUP g,
55 DWORD dwFlags,
56 LPINT lpErrno)
57 {
58 OBJECT_ATTRIBUTES Object;
59 IO_STATUS_BLOCK IOSB;
60 USHORT SizeOfPacket;
61 ULONG SizeOfEA;
62 PAFD_CREATE_PACKET AfdPacket;
63 HANDLE Sock;
64 PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
65 PFILE_FULL_EA_INFORMATION EABuffer = NULL;
66 PHELPER_DATA HelperData;
67 PVOID HelperDLLContext;
68 DWORD HelperEvents;
69 UNICODE_STRING TransportName;
70 UNICODE_STRING DevName;
71 LARGE_INTEGER GroupData;
72 INT Status;
73
74 AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
75 AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
76 AddressFamily, SocketType, Protocol));
77
78 /* Get Helper Data and Transport */
79 Status = SockGetTdiName (&AddressFamily,
80 &SocketType,
81 &Protocol,
82 g,
83 dwFlags,
84 &TransportName,
85 &HelperDLLContext,
86 &HelperData,
87 &HelperEvents);
88
89 /* Check for error */
90 if (Status != NO_ERROR)
91 {
92 AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
93 goto error;
94 }
95
96 /* AFD Device Name */
97 RtlInitUnicodeString(&DevName, L"\\Device\\Afd\\Endpoint");
98
99 /* Set Socket Data */
100 Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
101 RtlZeroMemory(Socket, sizeof(*Socket));
102 Socket->RefCount = 2;
103 Socket->Handle = -1;
104 Socket->SharedData.Listening = FALSE;
105 Socket->SharedData.State = SocketOpen;
106 Socket->SharedData.AddressFamily = AddressFamily;
107 Socket->SharedData.SocketType = SocketType;
108 Socket->SharedData.Protocol = Protocol;
109 Socket->HelperContext = HelperDLLContext;
110 Socket->HelperData = HelperData;
111 Socket->HelperEvents = HelperEvents;
112 Socket->LocalAddress = &Socket->WSLocalAddress;
113 Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
114 Socket->RemoteAddress = &Socket->WSRemoteAddress;
115 Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
116 Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
117 Socket->SharedData.CreateFlags = dwFlags;
118 Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
119 Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
120 Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
121 Socket->SharedData.GroupID = g;
122 Socket->SharedData.GroupType = 0;
123 Socket->SharedData.UseSAN = FALSE;
124 Socket->SharedData.NonBlocking = FALSE; /* Sockets start blocking */
125 Socket->SanData = NULL;
126
127 /* Ask alex about this */
128 if( Socket->SharedData.SocketType == SOCK_DGRAM ||
129 Socket->SharedData.SocketType == SOCK_RAW )
130 {
131 AFD_DbgPrint(MID_TRACE,("Connectionless socket\n"));
132 Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
133 }
134
135 /* Packet Size */
136 SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
137
138 /* EA Size */
139 SizeOfEA = SizeOfPacket + sizeof(FILE_FULL_EA_INFORMATION) + AFD_PACKET_COMMAND_LENGTH;
140
141 /* Set up EA Buffer */
142 EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
143 RtlZeroMemory(EABuffer, SizeOfEA);
144 EABuffer->NextEntryOffset = 0;
145 EABuffer->Flags = 0;
146 EABuffer->EaNameLength = AFD_PACKET_COMMAND_LENGTH;
147 RtlCopyMemory (EABuffer->EaName,
148 AfdCommand,
149 AFD_PACKET_COMMAND_LENGTH + 1);
150 EABuffer->EaValueLength = SizeOfPacket;
151
152 /* Set up AFD Packet */
153 AfdPacket = (PAFD_CREATE_PACKET)(EABuffer->EaName + EABuffer->EaNameLength + 1);
154 AfdPacket->SizeOfTransportName = TransportName.Length;
155 RtlCopyMemory (AfdPacket->TransportName,
156 TransportName.Buffer,
157 TransportName.Length + sizeof(WCHAR));
158 AfdPacket->GroupID = g;
159
160 /* Set up Endpoint Flags */
161 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0)
162 {
163 if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW))
164 {
165 /* Only RAW or UDP can be Connectionless */
166 goto error;
167 }
168 AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
169 }
170
171 if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0)
172 {
173 if (SocketType == SOCK_STREAM)
174 {
175 if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
176 {
177 /* The Provider doesn't actually support Message Oriented Streams */
178 goto error;
179 }
180 }
181 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MESSAGE_ORIENTED;
182 }
183
184 if (SocketType == SOCK_RAW) AfdPacket->EndpointFlags |= AFD_ENDPOINT_RAW;
185
186 if (dwFlags & (WSA_FLAG_MULTIPOINT_C_ROOT |
187 WSA_FLAG_MULTIPOINT_C_LEAF |
188 WSA_FLAG_MULTIPOINT_D_ROOT |
189 WSA_FLAG_MULTIPOINT_D_LEAF))
190 {
191 if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
192 {
193 /* The Provider doesn't actually support Multipoint */
194 goto error;
195 }
196 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
197
198 if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT)
199 {
200 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
201 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0))
202 {
203 /* The Provider doesn't support Control Planes, or you already gave a leaf */
204 goto error;
205 }
206 AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
207 }
208
209 if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT)
210 {
211 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
212 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0))
213 {
214 /* The Provider doesn't support Data Planes, or you already gave a leaf */
215 goto error;
216 }
217 AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
218 }
219 }
220
221 /* Set up Object Attributes */
222 InitializeObjectAttributes (&Object,
223 &DevName,
224 OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
225 0,
226 0);
227
228 /* Create the Socket as asynchronous. That means we have to block
229 ourselves after every call to NtDeviceIoControlFile. This is
230 because the kernel doesn't support overlapping synchronous I/O
231 requests (made from multiple threads) at this time (Sep 2005) */
232 ZwCreateFile(&Sock,
233 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
234 &Object,
235 &IOSB,
236 NULL,
237 0,
238 FILE_SHARE_READ | FILE_SHARE_WRITE, FILE_OPEN_IF,
239 0,
240 EABuffer,
241 SizeOfEA);
242
243 /* Save Handle */
244 Socket->Handle = (SOCKET)Sock;
245
246 /* XXX See if there's a structure we can reuse -- We need to do this
247 * more properly. */
248 PrevSocket = GetSocketStructure( (SOCKET)Sock );
249
250 if( PrevSocket )
251 {
252 RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
253 RtlFreeHeap( GlobalHeap, 0, Socket );
254 Socket = PrevSocket;
255 }
256
257 /* Save Group Info */
258 if (g != 0)
259 {
260 GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
261 Socket->SharedData.GroupID = GroupData.u.LowPart;
262 Socket->SharedData.GroupType = GroupData.u.HighPart;
263 }
264
265 /* Get Window Sizes and Save them */
266 GetSocketInformation (Socket,
267 AFD_INFO_SEND_WINDOW_SIZE,
268 &Socket->SharedData.SizeOfSendBuffer,
269 NULL);
270
271 GetSocketInformation (Socket,
272 AFD_INFO_RECEIVE_WINDOW_SIZE,
273 &Socket->SharedData.SizeOfRecvBuffer,
274 NULL);
275
276 /* Save in Process Sockets List */
277 Sockets[SocketCount] = Socket;
278 SocketCount ++;
279
280 /* Create the Socket Context */
281 CreateContext(Socket);
282
283 /* Notify Winsock */
284 Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
285
286 /* Return Socket Handle */
287 AFD_DbgPrint(MID_TRACE,("Success %x\n", Sock));
288
289 return (SOCKET)Sock;
290
291 error:
292 AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
293
294 if( lpErrno )
295 *lpErrno = Status;
296
297 return INVALID_SOCKET;
298 }
299
300
301 DWORD MsafdReturnWithErrno(NTSTATUS Status,
302 LPINT Errno,
303 DWORD Received,
304 LPDWORD ReturnedBytes)
305 {
306 if( ReturnedBytes )
307 *ReturnedBytes = 0;
308 if( Errno )
309 {
310 switch (Status)
311 {
312 case STATUS_CANT_WAIT:
313 *Errno = WSAEWOULDBLOCK;
314 break;
315 case STATUS_TIMEOUT:
316 *Errno = WSAETIMEDOUT;
317 break;
318 case STATUS_SUCCESS:
319 /* Return Number of bytes Read */
320 if( ReturnedBytes )
321 *ReturnedBytes = Received;
322 break;
323 case STATUS_END_OF_FILE:
324 *Errno = WSAESHUTDOWN;
325 break;
326 case STATUS_PENDING:
327 *Errno = WSA_IO_PENDING;
328 break;
329 case STATUS_BUFFER_OVERFLOW:
330 DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
331 *Errno = WSAEMSGSIZE;
332 break;
333 case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
334 case STATUS_INSUFFICIENT_RESOURCES:
335 DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
336 *Errno = WSA_NOT_ENOUGH_MEMORY;
337 break;
338 case STATUS_INVALID_CONNECTION:
339 DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
340 *Errno = WSAEAFNOSUPPORT;
341 break;
342 case STATUS_REMOTE_NOT_LISTENING:
343 DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
344 *Errno = WSAECONNRESET;
345 break;
346 case STATUS_FILE_CLOSED:
347 DbgPrint("MSAFD: STATUS_FILE_CLOSED\n");
348 *Errno = WSAENOTSOCK;
349 break;
350 case STATUS_INVALID_PARAMETER:
351 DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
352 *Errno = WSAEINVAL;
353 break;
354 case STATUS_CANCELLED:
355 DbgPrint("MSAFD: STATUS_CANCELLED\n");
356 *Errno = WSAENOTSOCK;
357 break;
358 default:
359 DbgPrint("MSAFD: Error %x is unknown\n", Status);
360 *Errno = WSAEINVAL;
361 break;
362 }
363 }
364
365 /* Success */
366 return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
367 }
368
369 /*
370 * FUNCTION: Closes an open socket
371 * ARGUMENTS:
372 * s = Socket descriptor
373 * lpErrno = Address of buffer for error information
374 * RETURNS:
375 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
376 */
377 INT
378 WSPAPI
379 WSPCloseSocket(IN SOCKET Handle,
380 OUT LPINT lpErrno)
381 {
382 IO_STATUS_BLOCK IoStatusBlock;
383 PSOCKET_INFORMATION Socket = NULL;
384 NTSTATUS Status;
385 HANDLE SockEvent;
386 AFD_DISCONNECT_INFO DisconnectInfo;
387 SOCKET_STATE OldState;
388
389 /* Create the Wait Event */
390 Status = NtCreateEvent(&SockEvent,
391 GENERIC_READ | GENERIC_WRITE,
392 NULL,
393 1,
394 FALSE);
395
396 if(!NT_SUCCESS(Status))
397 return SOCKET_ERROR;
398
399 /* Get the Socket Structure associate to this Socket*/
400 Socket = GetSocketStructure(Handle);
401
402 /* If a Close is already in Process, give up */
403 if (Socket->SharedData.State == SocketClosed)
404 {
405 *lpErrno = WSAENOTSOCK;
406 return SOCKET_ERROR;
407 }
408
409 /* Set the state to close */
410 OldState = Socket->SharedData.State;
411 Socket->SharedData.State = SocketClosed;
412
413 /* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
414 /* FIXME: Should we do this on Datagram Sockets too? */
415 if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
416 {
417 ULONG LingerWait;
418 ULONG SendsInProgress;
419 ULONG SleepWait;
420
421 /* We need to respect the timeout */
422 SleepWait = 100;
423 LingerWait = Socket->SharedData.LingerData.l_linger * 1000;
424
425 /* Loop until no more sends are pending, within the timeout */
426 while (LingerWait)
427 {
428 /* Find out how many Sends are in Progress */
429 if (GetSocketInformation(Socket,
430 AFD_INFO_SENDS_IN_PROGRESS,
431 &SendsInProgress,
432 NULL))
433 {
434 /* Bail out if anything but NO_ERROR */
435 LingerWait = 0;
436 break;
437 }
438
439 /* Bail out if no more sends are pending */
440 if (!SendsInProgress)
441 break;
442 /*
443 * We have to execute a sleep, so it's kind of like
444 * a block. If the socket is Nonblock, we cannot
445 * go on since asyncronous operation is expected
446 * and we cannot offer it
447 */
448 if (Socket->SharedData.NonBlocking)
449 {
450 Socket->SharedData.State = OldState;
451 *lpErrno = WSAEWOULDBLOCK;
452 return SOCKET_ERROR;
453 }
454
455 /* Now we can sleep, and decrement the linger wait */
456 /*
457 * FIXME: It seems Windows does some funky acceleration
458 * since the waiting seems to be longer and longer. I
459 * don't think this improves performance so much, so we
460 * wait a fixed time instead.
461 */
462 Sleep(SleepWait);
463 LingerWait -= SleepWait;
464 }
465
466 /*
467 * We have reached the timeout or sends are over.
468 * Disconnect if the timeout has been reached.
469 */
470 if (LingerWait <= 0)
471 {
472 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
473 DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
474
475 /* Send IOCTL */
476 Status = NtDeviceIoControlFile((HANDLE)Handle,
477 SockEvent,
478 NULL,
479 NULL,
480 &IoStatusBlock,
481 IOCTL_AFD_DISCONNECT,
482 &DisconnectInfo,
483 sizeof(DisconnectInfo),
484 NULL,
485 0);
486
487 /* Wait for return */
488 if (Status == STATUS_PENDING)
489 {
490 WaitForSingleObject(SockEvent, INFINITE);
491 }
492 }
493 }
494
495 /* FIXME: We should notify the Helper DLL of WSH_NOTIFY_CLOSE */
496
497 /* Cleanup Time! */
498 Socket->HelperContext = NULL;
499 Socket->SharedData.AsyncDisabledEvents = -1;
500 NtClose(Socket->TdiAddressHandle);
501 Socket->TdiAddressHandle = NULL;
502 NtClose(Socket->TdiConnectionHandle);
503 Socket->TdiConnectionHandle = NULL;
504
505 /* Close the handle */
506 NtClose((HANDLE)Handle);
507
508 return NO_ERROR;
509 }
510
511
512 /*
513 * FUNCTION: Associates a local address with a socket
514 * ARGUMENTS:
515 * s = Socket descriptor
516 * name = Pointer to local address
517 * namelen = Length of name
518 * lpErrno = Address of buffer for error information
519 * RETURNS:
520 * 0, or SOCKET_ERROR if the socket could not be bound
521 */
522 INT
523 WSPAPI
524 WSPBind(SOCKET Handle,
525 const struct sockaddr *SocketAddress,
526 int SocketAddressLength,
527 LPINT lpErrno)
528 {
529 IO_STATUS_BLOCK IOSB;
530 PAFD_BIND_DATA BindData;
531 PSOCKET_INFORMATION Socket = NULL;
532 NTSTATUS Status;
533 UCHAR BindBuffer[0x1A];
534 SOCKADDR_INFO SocketInfo;
535 HANDLE SockEvent;
536
537 Status = NtCreateEvent(&SockEvent,
538 GENERIC_READ | GENERIC_WRITE,
539 NULL,
540 1,
541 FALSE);
542
543 if( !NT_SUCCESS(Status) )
544 return -1;
545
546 /* Get the Socket Structure associate to this Socket*/
547 Socket = GetSocketStructure(Handle);
548
549 /* Dynamic Structure...ugh */
550 BindData = (PAFD_BIND_DATA)BindBuffer;
551
552 /* Set up Address in TDI Format */
553 BindData->Address.TAAddressCount = 1;
554 BindData->Address.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
555 BindData->Address.Address[0].AddressType = SocketAddress->sa_family;
556 RtlCopyMemory (BindData->Address.Address[0].Address,
557 SocketAddress->sa_data,
558 SocketAddressLength - sizeof(SocketAddress->sa_family));
559
560 /* Get Address Information */
561 Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
562 SocketAddressLength,
563 &SocketInfo);
564
565 /* Set the Share Type */
566 if (Socket->SharedData.ExclusiveAddressUse)
567 {
568 BindData->ShareType = AFD_SHARE_EXCLUSIVE;
569 }
570 else if (SocketInfo.EndpointInfo == SockaddrEndpointInfoWildcard)
571 {
572 BindData->ShareType = AFD_SHARE_WILDCARD;
573 }
574 else if (Socket->SharedData.ReuseAddresses)
575 {
576 BindData->ShareType = AFD_SHARE_REUSE;
577 }
578 else
579 {
580 BindData->ShareType = AFD_SHARE_UNIQUE;
581 }
582
583 /* Send IOCTL */
584 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
585 SockEvent,
586 NULL,
587 NULL,
588 &IOSB,
589 IOCTL_AFD_BIND,
590 BindData,
591 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
592 BindData,
593 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
594
595 /* Wait for return */
596 if (Status == STATUS_PENDING)
597 {
598 WaitForSingleObject(SockEvent, INFINITE);
599 Status = IOSB.Status;
600 }
601
602 /* Set up Socket Data */
603 Socket->SharedData.State = SocketBound;
604 Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
605
606 NtClose( SockEvent );
607
608 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
609 }
610
611 int
612 WSPAPI
613 WSPListen(SOCKET Handle,
614 int Backlog,
615 LPINT lpErrno)
616 {
617 IO_STATUS_BLOCK IOSB;
618 AFD_LISTEN_DATA ListenData;
619 PSOCKET_INFORMATION Socket = NULL;
620 HANDLE SockEvent;
621 NTSTATUS Status;
622
623 /* Get the Socket Structure associate to this Socket*/
624 Socket = GetSocketStructure(Handle);
625
626 if (Socket->SharedData.Listening)
627 return 0;
628
629 Status = NtCreateEvent(&SockEvent,
630 GENERIC_READ | GENERIC_WRITE,
631 NULL,
632 1,
633 FALSE);
634
635 if( !NT_SUCCESS(Status) )
636 return -1;
637
638 /* Set Up Listen Structure */
639 ListenData.UseSAN = FALSE;
640 ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
641 ListenData.Backlog = Backlog;
642
643 /* Send IOCTL */
644 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
645 SockEvent,
646 NULL,
647 NULL,
648 &IOSB,
649 IOCTL_AFD_START_LISTEN,
650 &ListenData,
651 sizeof(ListenData),
652 NULL,
653 0);
654
655 /* Wait for return */
656 if (Status == STATUS_PENDING)
657 {
658 WaitForSingleObject(SockEvent, INFINITE);
659 Status = IOSB.Status;
660 }
661
662 /* Set to Listening */
663 Socket->SharedData.Listening = TRUE;
664
665 NtClose( SockEvent );
666
667 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
668 }
669
670
671 int
672 WSPAPI
673 WSPSelect(int nfds,
674 fd_set *readfds,
675 fd_set *writefds,
676 fd_set *exceptfds,
677 struct timeval *timeout,
678 LPINT lpErrno)
679 {
680 IO_STATUS_BLOCK IOSB;
681 PAFD_POLL_INFO PollInfo;
682 NTSTATUS Status;
683 LONG HandleCount, OutCount = 0;
684 ULONG PollBufferSize;
685 PVOID PollBuffer;
686 ULONG i, j = 0, x;
687 HANDLE SockEvent;
688 BOOL HandleCounted;
689 LARGE_INTEGER Timeout;
690
691 /* Find out how many sockets we have, and how large the buffer needs
692 * to be */
693
694 HandleCount = ( readfds ? readfds->fd_count : 0 ) +
695 ( writefds ? writefds->fd_count : 0 ) +
696 ( exceptfds ? exceptfds->fd_count : 0 );
697
698 if( HandleCount < 0 || nfds != 0 )
699 HandleCount = nfds * 3;
700
701 PollBufferSize = sizeof(*PollInfo) + (HandleCount * sizeof(AFD_HANDLE));
702
703 AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n",
704 HandleCount, PollBufferSize));
705
706 /* Convert Timeout to NT Format */
707 if (timeout == NULL)
708 {
709 Timeout.u.LowPart = -1;
710 Timeout.u.HighPart = 0x7FFFFFFF;
711 AFD_DbgPrint(MAX_TRACE,("Infinite timeout\n"));
712 }
713 else
714 {
715 Timeout = RtlEnlargedIntegerMultiply
716 ((timeout->tv_sec * 1000) + (timeout->tv_usec / 1000), -10000);
717 /* Negative timeouts are illegal. Since the kernel represents an
718 * incremental timeout as a negative number, we check for a positive
719 * result.
720 */
721 if (Timeout.QuadPart > 0)
722 {
723 if (lpErrno) *lpErrno = WSAEINVAL;
724 return SOCKET_ERROR;
725 }
726 AFD_DbgPrint(MAX_TRACE,("Timeout: Orig %d.%06d kernel %d\n",
727 timeout->tv_sec, timeout->tv_usec,
728 Timeout.u.LowPart));
729 }
730
731 Status = NtCreateEvent(&SockEvent,
732 GENERIC_READ | GENERIC_WRITE,
733 NULL,
734 1,
735 FALSE);
736
737 if( !NT_SUCCESS(Status) )
738 return SOCKET_ERROR;
739
740 /* Allocate */
741 PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
742
743 if (!PollBuffer)
744 {
745 if (lpErrno)
746 *lpErrno = WSAEFAULT;
747 NtClose(SockEvent);
748 return SOCKET_ERROR;
749 }
750
751 PollInfo = (PAFD_POLL_INFO)PollBuffer;
752
753 RtlZeroMemory( PollInfo, PollBufferSize );
754
755 /* Number of handles for AFD to Check */
756 PollInfo->Exclusive = FALSE;
757 PollInfo->Timeout = Timeout;
758
759 if (readfds != NULL) {
760 for (i = 0; i < readfds->fd_count; i++, j++)
761 {
762 PollInfo->Handles[j].Handle = readfds->fd_array[i];
763 PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
764 AFD_EVENT_DISCONNECT |
765 AFD_EVENT_ABORT |
766 AFD_EVENT_ACCEPT;
767 }
768 }
769 if (writefds != NULL)
770 {
771 for (i = 0; i < writefds->fd_count; i++, j++)
772 {
773 PollInfo->Handles[j].Handle = writefds->fd_array[i];
774 PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
775 }
776 }
777 if (exceptfds != NULL)
778 {
779 for (i = 0; i < exceptfds->fd_count; i++, j++)
780 {
781 PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
782 PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
783 }
784 }
785
786 PollInfo->HandleCount = j;
787 PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
788
789 /* Send IOCTL */
790 Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
791 SockEvent,
792 NULL,
793 NULL,
794 &IOSB,
795 IOCTL_AFD_SELECT,
796 PollInfo,
797 PollBufferSize,
798 PollInfo,
799 PollBufferSize);
800
801 AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
802
803 /* Wait for Completition */
804 if (Status == STATUS_PENDING)
805 {
806 WaitForSingleObject(SockEvent, INFINITE);
807 }
808
809 /* Clear the Structures */
810 if( readfds )
811 FD_ZERO(readfds);
812 if( writefds )
813 FD_ZERO(writefds);
814 if( exceptfds )
815 FD_ZERO(exceptfds);
816
817 /* Loop through return structure */
818 HandleCount = PollInfo->HandleCount;
819
820 /* Return in FDSET Format */
821 for (i = 0; i < HandleCount; i++)
822 {
823 HandleCounted = FALSE;
824 for(x = 1; x; x<<=1)
825 {
826 switch (PollInfo->Handles[i].Events & x)
827 {
828 case AFD_EVENT_RECEIVE:
829 case AFD_EVENT_DISCONNECT:
830 case AFD_EVENT_ABORT:
831 case AFD_EVENT_ACCEPT:
832 case AFD_EVENT_CLOSE:
833 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
834 PollInfo->Handles[i].Events,
835 PollInfo->Handles[i].Handle));
836 if (! HandleCounted)
837 {
838 OutCount++;
839 HandleCounted = TRUE;
840 }
841 if( readfds )
842 FD_SET(PollInfo->Handles[i].Handle, readfds);
843 break;
844 case AFD_EVENT_SEND:
845 case AFD_EVENT_CONNECT:
846 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
847 PollInfo->Handles[i].Events,
848 PollInfo->Handles[i].Handle));
849 if (! HandleCounted)
850 {
851 OutCount++;
852 HandleCounted = TRUE;
853 }
854 if( writefds )
855 FD_SET(PollInfo->Handles[i].Handle, writefds);
856 break;
857 case AFD_EVENT_OOB_RECEIVE:
858 case AFD_EVENT_CONNECT_FAIL:
859 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
860 PollInfo->Handles[i].Events,
861 PollInfo->Handles[i].Handle));
862 if (! HandleCounted)
863 {
864 OutCount++;
865 HandleCounted = TRUE;
866 }
867 if( exceptfds )
868 FD_SET(PollInfo->Handles[i].Handle, exceptfds);
869 break;
870 }
871 }
872 }
873
874 HeapFree( GlobalHeap, 0, PollBuffer );
875 NtClose( SockEvent );
876
877 if( lpErrno )
878 {
879 switch( IOSB.Status )
880 {
881 case STATUS_SUCCESS:
882 case STATUS_TIMEOUT:
883 *lpErrno = 0;
884 break;
885 default:
886 *lpErrno = WSAEINVAL;
887 break;
888 }
889 AFD_DbgPrint(MID_TRACE,("*lpErrno = %x\n", *lpErrno));
890 }
891
892 AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
893
894 return OutCount;
895 }
896
897 SOCKET
898 WSPAPI
899 WSPAccept(SOCKET Handle,
900 struct sockaddr *SocketAddress,
901 int *SocketAddressLength,
902 LPCONDITIONPROC lpfnCondition,
903 DWORD_PTR dwCallbackData,
904 LPINT lpErrno)
905 {
906 IO_STATUS_BLOCK IOSB;
907 PAFD_RECEIVED_ACCEPT_DATA ListenReceiveData;
908 AFD_ACCEPT_DATA AcceptData;
909 AFD_DEFER_ACCEPT_DATA DeferData;
910 AFD_PENDING_ACCEPT_DATA PendingAcceptData;
911 PSOCKET_INFORMATION Socket = NULL;
912 NTSTATUS Status;
913 struct fd_set ReadSet;
914 struct timeval Timeout;
915 PVOID PendingData = NULL;
916 ULONG PendingDataLength = 0;
917 PVOID CalleeDataBuffer;
918 WSABUF CallerData, CalleeID, CallerID, CalleeData;
919 PSOCKADDR RemoteAddress = NULL;
920 GROUP GroupID = 0;
921 ULONG CallBack;
922 WSAPROTOCOL_INFOW ProtocolInfo;
923 SOCKET AcceptSocket;
924 UCHAR ReceiveBuffer[0x1A];
925 HANDLE SockEvent;
926
927 Status = NtCreateEvent(&SockEvent,
928 GENERIC_READ | GENERIC_WRITE,
929 NULL,
930 1,
931 FALSE);
932
933 if( !NT_SUCCESS(Status) )
934 {
935 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
936 return INVALID_SOCKET;
937 }
938
939 /* Dynamic Structure...ugh */
940 ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
941
942 /* Get the Socket Structure associate to this Socket*/
943 Socket = GetSocketStructure(Handle);
944
945 /* If this is non-blocking, make sure there's something for us to accept */
946 FD_ZERO(&ReadSet);
947 FD_SET(Socket->Handle, &ReadSet);
948 Timeout.tv_sec=0;
949 Timeout.tv_usec=0;
950
951 WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
952
953 if (ReadSet.fd_array[0] != Socket->Handle)
954 return 0;
955
956 /* Send IOCTL */
957 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
958 SockEvent,
959 NULL,
960 NULL,
961 &IOSB,
962 IOCTL_AFD_WAIT_FOR_LISTEN,
963 NULL,
964 0,
965 ListenReceiveData,
966 0xA + sizeof(*ListenReceiveData));
967
968 /* Wait for return */
969 if (Status == STATUS_PENDING)
970 {
971 WaitForSingleObject(SockEvent, INFINITE);
972 Status = IOSB.Status;
973 }
974
975 if (!NT_SUCCESS(Status))
976 {
977 NtClose( SockEvent );
978 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
979 return INVALID_SOCKET;
980 }
981
982 if (lpfnCondition != NULL)
983 {
984 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0)
985 {
986 /* Find out how much data is pending */
987 PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
988 PendingAcceptData.ReturnSize = TRUE;
989
990 /* Send IOCTL */
991 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
992 SockEvent,
993 NULL,
994 NULL,
995 &IOSB,
996 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
997 &PendingAcceptData,
998 sizeof(PendingAcceptData),
999 &PendingAcceptData,
1000 sizeof(PendingAcceptData));
1001
1002 /* Wait for return */
1003 if (Status == STATUS_PENDING)
1004 {
1005 WaitForSingleObject(SockEvent, INFINITE);
1006 Status = IOSB.Status;
1007 }
1008
1009 if (!NT_SUCCESS(Status))
1010 {
1011 NtClose( SockEvent );
1012 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1013 return INVALID_SOCKET;
1014 }
1015
1016 /* How much data to allocate */
1017 PendingDataLength = IOSB.Information;
1018
1019 if (PendingDataLength)
1020 {
1021 /* Allocate needed space */
1022 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
1023
1024 /* We want the data now */
1025 PendingAcceptData.ReturnSize = FALSE;
1026
1027 /* Send IOCTL */
1028 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1029 SockEvent,
1030 NULL,
1031 NULL,
1032 &IOSB,
1033 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1034 &PendingAcceptData,
1035 sizeof(PendingAcceptData),
1036 PendingData,
1037 PendingDataLength);
1038
1039 /* Wait for return */
1040 if (Status == STATUS_PENDING)
1041 {
1042 WaitForSingleObject(SockEvent, INFINITE);
1043 Status = IOSB.Status;
1044 }
1045
1046 if (!NT_SUCCESS(Status))
1047 {
1048 NtClose( SockEvent );
1049 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1050 return INVALID_SOCKET;
1051 }
1052 }
1053 }
1054
1055 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1056 {
1057 /* I don't support this yet */
1058 }
1059
1060 /* Build Callee ID */
1061 CalleeID.buf = (PVOID)Socket->LocalAddress;
1062 CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
1063
1064 /* Set up Address in SOCKADDR Format */
1065 RtlCopyMemory (RemoteAddress,
1066 &ListenReceiveData->Address.Address[0].AddressType,
1067 sizeof(*RemoteAddress));
1068
1069 /* Build Caller ID */
1070 CallerID.buf = (PVOID)RemoteAddress;
1071 CallerID.len = sizeof(*RemoteAddress);
1072
1073 /* Build Caller Data */
1074 CallerData.buf = PendingData;
1075 CallerData.len = PendingDataLength;
1076
1077 /* Check if socket supports Conditional Accept */
1078 if (Socket->SharedData.UseDelayedAcceptance != 0)
1079 {
1080 /* Allocate Buffer for Callee Data */
1081 CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
1082 CalleeData.buf = CalleeDataBuffer;
1083 CalleeData.len = 4096;
1084 }
1085 else
1086 {
1087 /* Nothing */
1088 CalleeData.buf = 0;
1089 CalleeData.len = 0;
1090 }
1091
1092 /* Call the Condition Function */
1093 CallBack = (lpfnCondition)(&CallerID,
1094 CallerData.buf == NULL ? NULL : &CallerData,
1095 NULL,
1096 NULL,
1097 &CalleeID,
1098 CalleeData.buf == NULL ? NULL : &CalleeData,
1099 &GroupID,
1100 dwCallbackData);
1101
1102 if (((CallBack == CF_ACCEPT) && GroupID) != 0)
1103 {
1104 /* TBD: Check for Validity */
1105 }
1106
1107 if (CallBack == CF_ACCEPT)
1108 {
1109 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1110 {
1111 /* I don't support this yet */
1112 }
1113 if (CalleeData.buf)
1114 {
1115 // SockSetConnectData Sockets(SocketID), IOCTL_AFD_SET_CONNECT_DATA, CalleeData.Buffer, CalleeData.BuffSize, 0
1116 }
1117 }
1118 else
1119 {
1120 /* Callback rejected. Build Defer Structure */
1121 DeferData.SequenceNumber = ListenReceiveData->SequenceNumber;
1122 DeferData.RejectConnection = (CallBack == CF_REJECT);
1123
1124 /* Send IOCTL */
1125 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1126 SockEvent,
1127 NULL,
1128 NULL,
1129 &IOSB,
1130 IOCTL_AFD_DEFER_ACCEPT,
1131 &DeferData,
1132 sizeof(DeferData),
1133 NULL,
1134 0);
1135
1136 /* Wait for return */
1137 if (Status == STATUS_PENDING)
1138 {
1139 WaitForSingleObject(SockEvent, INFINITE);
1140 Status = IOSB.Status;
1141 }
1142
1143 NtClose( SockEvent );
1144
1145 if (!NT_SUCCESS(Status))
1146 {
1147 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1148 return INVALID_SOCKET;
1149 }
1150
1151 if (CallBack == CF_REJECT )
1152 {
1153 *lpErrno = WSAECONNREFUSED;
1154 return INVALID_SOCKET;
1155 }
1156 else
1157 {
1158 *lpErrno = WSAECONNREFUSED;
1159 return INVALID_SOCKET;
1160 }
1161 }
1162 }
1163
1164 /* Create a new Socket */
1165 ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
1166 ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
1167 ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
1168
1169 AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
1170 Socket->SharedData.SocketType,
1171 Socket->SharedData.Protocol,
1172 &ProtocolInfo,
1173 GroupID,
1174 Socket->SharedData.CreateFlags,
1175 NULL);
1176
1177 /* Set up the Accept Structure */
1178 AcceptData.ListenHandle = AcceptSocket;
1179 AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1180
1181 /* Send IOCTL to Accept */
1182 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1183 SockEvent,
1184 NULL,
1185 NULL,
1186 &IOSB,
1187 IOCTL_AFD_ACCEPT,
1188 &AcceptData,
1189 sizeof(AcceptData),
1190 NULL,
1191 0);
1192
1193 /* Wait for return */
1194 if (Status == STATUS_PENDING)
1195 {
1196 WaitForSingleObject(SockEvent, INFINITE);
1197 Status = IOSB.Status;
1198 }
1199
1200 if (!NT_SUCCESS(Status))
1201 {
1202 WSPCloseSocket( AcceptSocket, lpErrno );
1203 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1204 return INVALID_SOCKET;
1205 }
1206
1207 /* Return Address in SOCKADDR FORMAT */
1208 if( SocketAddress )
1209 {
1210 RtlCopyMemory (SocketAddress,
1211 &ListenReceiveData->Address.Address[0].AddressType,
1212 sizeof(*RemoteAddress));
1213 if( SocketAddressLength )
1214 *SocketAddressLength = ListenReceiveData->Address.Address[0].AddressLength;
1215 }
1216
1217 NtClose( SockEvent );
1218
1219 /* Re-enable Async Event */
1220 SockReenableAsyncSelectEvent(Socket, FD_ACCEPT);
1221
1222 AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
1223
1224 *lpErrno = 0;
1225
1226 /* Return Socket */
1227 return AcceptSocket;
1228 }
1229
1230 int
1231 WSPAPI
1232 WSPConnect(SOCKET Handle,
1233 const struct sockaddr * SocketAddress,
1234 int SocketAddressLength,
1235 LPWSABUF lpCallerData,
1236 LPWSABUF lpCalleeData,
1237 LPQOS lpSQOS,
1238 LPQOS lpGQOS,
1239 LPINT lpErrno)
1240 {
1241 IO_STATUS_BLOCK IOSB;
1242 PAFD_CONNECT_INFO ConnectInfo;
1243 PSOCKET_INFORMATION Socket = NULL;
1244 NTSTATUS Status;
1245 UCHAR ConnectBuffer[0x22];
1246 ULONG ConnectDataLength;
1247 ULONG InConnectDataLength;
1248 INT BindAddressLength;
1249 PSOCKADDR BindAddress;
1250 HANDLE SockEvent;
1251
1252 Status = NtCreateEvent(&SockEvent,
1253 GENERIC_READ | GENERIC_WRITE,
1254 NULL,
1255 1,
1256 FALSE);
1257
1258 if( !NT_SUCCESS(Status) )
1259 return -1;
1260
1261 AFD_DbgPrint(MID_TRACE,("Called\n"));
1262
1263 /* Get the Socket Structure associate to this Socket*/
1264 Socket = GetSocketStructure(Handle);
1265
1266 /* Bind us First */
1267 if (Socket->SharedData.State == SocketOpen)
1268 {
1269 /* Get the Wildcard Address */
1270 BindAddressLength = Socket->HelperData->MaxWSAddressLength;
1271 BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
1272 Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
1273 BindAddress,
1274 &BindAddressLength);
1275 /* Bind it */
1276 WSPBind(Handle, BindAddress, BindAddressLength, NULL);
1277 }
1278
1279 /* Set the Connect Data */
1280 if (lpCallerData != NULL)
1281 {
1282 ConnectDataLength = lpCallerData->len;
1283 Status = NtDeviceIoControlFile((HANDLE)Handle,
1284 SockEvent,
1285 NULL,
1286 NULL,
1287 &IOSB,
1288 IOCTL_AFD_SET_CONNECT_DATA,
1289 lpCallerData->buf,
1290 ConnectDataLength,
1291 NULL,
1292 0);
1293 /* Wait for return */
1294 if (Status == STATUS_PENDING)
1295 {
1296 WaitForSingleObject(SockEvent, INFINITE);
1297 Status = IOSB.Status;
1298 }
1299 }
1300
1301 /* Dynamic Structure...ugh */
1302 ConnectInfo = (PAFD_CONNECT_INFO)ConnectBuffer;
1303
1304 /* Set up Address in TDI Format */
1305 ConnectInfo->RemoteAddress.TAAddressCount = 1;
1306 ConnectInfo->RemoteAddress.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
1307 ConnectInfo->RemoteAddress.Address[0].AddressType = SocketAddress->sa_family;
1308 RtlCopyMemory (ConnectInfo->RemoteAddress.Address[0].Address,
1309 SocketAddress->sa_data,
1310 SocketAddressLength - sizeof(SocketAddress->sa_family));
1311
1312 /*
1313 * Disable FD_WRITE and FD_CONNECT
1314 * The latter fixes a race condition where the FD_CONNECT is re-enabled
1315 * at the end of this function right after the Async Thread disables it.
1316 * This should only happen at the *next* WSPConnect
1317 */
1318 if (Socket->SharedData.AsyncEvents & FD_CONNECT)
1319 {
1320 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
1321 }
1322
1323 /* Tell AFD that we want Connection Data back, have it allocate a buffer */
1324 if (lpCalleeData != NULL)
1325 {
1326 InConnectDataLength = lpCalleeData->len;
1327 Status = NtDeviceIoControlFile((HANDLE)Handle,
1328 SockEvent,
1329 NULL,
1330 NULL,
1331 &IOSB,
1332 IOCTL_AFD_SET_CONNECT_DATA_SIZE,
1333 &InConnectDataLength,
1334 sizeof(InConnectDataLength),
1335 NULL,
1336 0);
1337
1338 /* Wait for return */
1339 if (Status == STATUS_PENDING)
1340 {
1341 WaitForSingleObject(SockEvent, INFINITE);
1342 Status = IOSB.Status;
1343 }
1344 }
1345
1346 /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
1347 ConnectInfo->Root = 0;
1348 ConnectInfo->UseSAN = FALSE;
1349 ConnectInfo->Unknown = 0;
1350
1351 /* FIXME: Handle Async Connect */
1352 if (Socket->SharedData.NonBlocking)
1353 {
1354 AFD_DbgPrint(MIN_TRACE, ("Async Connect UNIMPLEMENTED!\n"));
1355 }
1356
1357 /* Send IOCTL */
1358 Status = NtDeviceIoControlFile((HANDLE)Handle,
1359 SockEvent,
1360 NULL,
1361 NULL,
1362 &IOSB,
1363 IOCTL_AFD_CONNECT,
1364 ConnectInfo,
1365 0x22,
1366 NULL,
1367 0);
1368 /* Wait for return */
1369 if (Status == STATUS_PENDING)
1370 {
1371 WaitForSingleObject(SockEvent, INFINITE);
1372 Status = IOSB.Status;
1373 }
1374
1375 /* Get any pending connect data */
1376 if (lpCalleeData != NULL)
1377 {
1378 Status = NtDeviceIoControlFile((HANDLE)Handle,
1379 SockEvent,
1380 NULL,
1381 NULL,
1382 &IOSB,
1383 IOCTL_AFD_GET_CONNECT_DATA,
1384 NULL,
1385 0,
1386 lpCalleeData->buf,
1387 lpCalleeData->len);
1388 /* Wait for return */
1389 if (Status == STATUS_PENDING)
1390 {
1391 WaitForSingleObject(SockEvent, INFINITE);
1392 Status = IOSB.Status;
1393 }
1394 }
1395
1396 /* Re-enable Async Event */
1397 SockReenableAsyncSelectEvent(Socket, FD_WRITE);
1398
1399 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
1400 SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
1401
1402 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1403
1404 NtClose( SockEvent );
1405
1406 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1407 }
1408 int
1409 WSPAPI
1410 WSPShutdown(SOCKET Handle,
1411 int HowTo,
1412 LPINT lpErrno)
1413
1414 {
1415 IO_STATUS_BLOCK IOSB;
1416 AFD_DISCONNECT_INFO DisconnectInfo;
1417 PSOCKET_INFORMATION Socket = NULL;
1418 NTSTATUS Status;
1419 HANDLE SockEvent;
1420
1421 Status = NtCreateEvent(&SockEvent,
1422 GENERIC_READ | GENERIC_WRITE,
1423 NULL,
1424 1,
1425 FALSE);
1426
1427 if( !NT_SUCCESS(Status) )
1428 return -1;
1429
1430 AFD_DbgPrint(MID_TRACE,("Called\n"));
1431
1432 /* Get the Socket Structure associate to this Socket*/
1433 Socket = GetSocketStructure(Handle);
1434
1435 /* Set AFD Disconnect Type */
1436 switch (HowTo)
1437 {
1438 case SD_RECEIVE:
1439 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
1440 Socket->SharedData.ReceiveShutdown = TRUE;
1441 break;
1442 case SD_SEND:
1443 DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
1444 Socket->SharedData.SendShutdown = TRUE;
1445 break;
1446 case SD_BOTH:
1447 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
1448 Socket->SharedData.ReceiveShutdown = TRUE;
1449 Socket->SharedData.SendShutdown = TRUE;
1450 break;
1451 }
1452
1453 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
1454
1455 /* Send IOCTL */
1456 Status = NtDeviceIoControlFile((HANDLE)Handle,
1457 SockEvent,
1458 NULL,
1459 NULL,
1460 &IOSB,
1461 IOCTL_AFD_DISCONNECT,
1462 &DisconnectInfo,
1463 sizeof(DisconnectInfo),
1464 NULL,
1465 0);
1466
1467 /* Wait for return */
1468 if (Status == STATUS_PENDING)
1469 {
1470 WaitForSingleObject(SockEvent, INFINITE);
1471 Status = IOSB.Status;
1472 }
1473
1474 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1475
1476 NtClose( SockEvent );
1477
1478 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1479 }
1480
1481
1482 INT
1483 WSPAPI
1484 WSPGetSockName(IN SOCKET Handle,
1485 OUT LPSOCKADDR Name,
1486 IN OUT LPINT NameLength,
1487 OUT LPINT lpErrno)
1488 {
1489 IO_STATUS_BLOCK IOSB;
1490 ULONG TdiAddressSize;
1491 PTDI_ADDRESS_INFO TdiAddress;
1492 PTRANSPORT_ADDRESS SocketAddress;
1493 PSOCKET_INFORMATION Socket = NULL;
1494 NTSTATUS Status;
1495 HANDLE SockEvent;
1496
1497 Status = NtCreateEvent(&SockEvent,
1498 GENERIC_READ | GENERIC_WRITE,
1499 NULL,
1500 1,
1501 FALSE);
1502
1503 if( !NT_SUCCESS(Status) )
1504 return SOCKET_ERROR;
1505
1506 /* Get the Socket Structure associate to this Socket*/
1507 Socket = GetSocketStructure(Handle);
1508
1509 /* Allocate a buffer for the address */
1510 TdiAddressSize =
1511 sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
1512 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1513
1514 if ( TdiAddress == NULL )
1515 {
1516 NtClose( SockEvent );
1517 *lpErrno = WSAENOBUFS;
1518 return SOCKET_ERROR;
1519 }
1520
1521 SocketAddress = &TdiAddress->Address;
1522
1523 /* Send IOCTL */
1524 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1525 SockEvent,
1526 NULL,
1527 NULL,
1528 &IOSB,
1529 IOCTL_AFD_GET_SOCK_NAME,
1530 NULL,
1531 0,
1532 TdiAddress,
1533 TdiAddressSize);
1534
1535 /* Wait for return */
1536 if (Status == STATUS_PENDING)
1537 {
1538 WaitForSingleObject(SockEvent, INFINITE);
1539 Status = IOSB.Status;
1540 }
1541
1542 NtClose( SockEvent );
1543
1544 if (NT_SUCCESS(Status))
1545 {
1546 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1547 {
1548 Name->sa_family = SocketAddress->Address[0].AddressType;
1549 RtlCopyMemory (Name->sa_data,
1550 SocketAddress->Address[0].Address,
1551 SocketAddress->Address[0].AddressLength);
1552 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1553 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
1554 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1555 ((struct sockaddr_in *)Name)->sin_port));
1556 HeapFree(GlobalHeap, 0, TdiAddress);
1557 return 0;
1558 }
1559 else
1560 {
1561 HeapFree(GlobalHeap, 0, TdiAddress);
1562 *lpErrno = WSAEFAULT;
1563 return SOCKET_ERROR;
1564 }
1565 }
1566
1567 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1568 }
1569
1570
1571 INT
1572 WSPAPI
1573 WSPGetPeerName(IN SOCKET s,
1574 OUT LPSOCKADDR Name,
1575 IN OUT LPINT NameLength,
1576 OUT LPINT lpErrno)
1577 {
1578 IO_STATUS_BLOCK IOSB;
1579 ULONG TdiAddressSize;
1580 PTRANSPORT_ADDRESS SocketAddress;
1581 PSOCKET_INFORMATION Socket = NULL;
1582 NTSTATUS Status;
1583 HANDLE SockEvent;
1584
1585 Status = NtCreateEvent(&SockEvent,
1586 GENERIC_READ | GENERIC_WRITE,
1587 NULL,
1588 1,
1589 FALSE);
1590
1591 if( !NT_SUCCESS(Status) )
1592 return SOCKET_ERROR;
1593
1594 /* Get the Socket Structure associate to this Socket*/
1595 Socket = GetSocketStructure(s);
1596
1597 /* Allocate a buffer for the address */
1598 TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
1599 SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1600
1601 if ( SocketAddress == NULL )
1602 {
1603 NtClose( SockEvent );
1604 *lpErrno = WSAENOBUFS;
1605 return SOCKET_ERROR;
1606 }
1607
1608 /* Send IOCTL */
1609 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1610 SockEvent,
1611 NULL,
1612 NULL,
1613 &IOSB,
1614 IOCTL_AFD_GET_PEER_NAME,
1615 NULL,
1616 0,
1617 SocketAddress,
1618 TdiAddressSize);
1619
1620 /* Wait for return */
1621 if (Status == STATUS_PENDING)
1622 {
1623 WaitForSingleObject(SockEvent, INFINITE);
1624 Status = IOSB.Status;
1625 }
1626
1627 NtClose( SockEvent );
1628
1629 if (NT_SUCCESS(Status))
1630 {
1631 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1632 {
1633 Name->sa_family = SocketAddress->Address[0].AddressType;
1634 RtlCopyMemory (Name->sa_data,
1635 SocketAddress->Address[0].Address,
1636 SocketAddress->Address[0].AddressLength);
1637 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1638 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
1639 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1640 ((struct sockaddr_in *)Name)->sin_port));
1641 HeapFree(GlobalHeap, 0, SocketAddress);
1642 return 0;
1643 }
1644 else
1645 {
1646 HeapFree(GlobalHeap, 0, SocketAddress);
1647 *lpErrno = WSAEFAULT;
1648 return SOCKET_ERROR;
1649 }
1650 }
1651
1652 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1653 }
1654
1655 INT
1656 WSPAPI
1657 WSPIoctl(IN SOCKET Handle,
1658 IN DWORD dwIoControlCode,
1659 IN LPVOID lpvInBuffer,
1660 IN DWORD cbInBuffer,
1661 OUT LPVOID lpvOutBuffer,
1662 IN DWORD cbOutBuffer,
1663 OUT LPDWORD lpcbBytesReturned,
1664 IN LPWSAOVERLAPPED lpOverlapped,
1665 IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
1666 IN LPWSATHREADID lpThreadId,
1667 OUT LPINT lpErrno)
1668 {
1669 PSOCKET_INFORMATION Socket = NULL;
1670
1671 /* Get the Socket Structure associate to this Socket*/
1672 Socket = GetSocketStructure(Handle);
1673
1674 switch( dwIoControlCode )
1675 {
1676 case FIONBIO:
1677 if( cbInBuffer < sizeof(INT) )
1678 return SOCKET_ERROR;
1679 Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
1680 AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n", Handle, Socket->SharedData.NonBlocking));
1681 return 0;
1682 case FIONREAD:
1683 return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
1684 default:
1685 *lpErrno = WSAEINVAL;
1686 return SOCKET_ERROR;
1687 }
1688 }
1689
1690
1691 INT
1692 WSPAPI
1693 WSPGetSockOpt(IN SOCKET Handle,
1694 IN INT Level,
1695 IN INT OptionName,
1696 OUT CHAR FAR* OptionValue,
1697 IN OUT LPINT OptionLength,
1698 OUT LPINT lpErrno)
1699 {
1700 PSOCKET_INFORMATION Socket = NULL;
1701 PVOID Buffer;
1702 INT BufferSize;
1703 BOOLEAN BoolBuffer;
1704
1705 /* Get the Socket Structure associate to this Socket*/
1706 Socket = GetSocketStructure(Handle);
1707 if (Socket == NULL)
1708 {
1709 *lpErrno = WSAENOTSOCK;
1710 return SOCKET_ERROR;
1711 }
1712
1713 AFD_DbgPrint(MID_TRACE, ("Called\n"));
1714
1715 switch (Level)
1716 {
1717 case SOL_SOCKET:
1718 switch (OptionName)
1719 {
1720 case SO_TYPE:
1721 Buffer = &Socket->SharedData.SocketType;
1722 BufferSize = sizeof(INT);
1723 break;
1724
1725 case SO_RCVBUF:
1726 Buffer = &Socket->SharedData.SizeOfRecvBuffer;
1727 BufferSize = sizeof(INT);
1728 break;
1729
1730 case SO_SNDBUF:
1731 Buffer = &Socket->SharedData.SizeOfSendBuffer;
1732 BufferSize = sizeof(INT);
1733 break;
1734
1735 case SO_ACCEPTCONN:
1736 BoolBuffer = Socket->SharedData.Listening;
1737 Buffer = &BoolBuffer;
1738 BufferSize = sizeof(BOOLEAN);
1739 break;
1740
1741 case SO_BROADCAST:
1742 BoolBuffer = Socket->SharedData.Broadcast;
1743 Buffer = &BoolBuffer;
1744 BufferSize = sizeof(BOOLEAN);
1745 break;
1746
1747 case SO_DEBUG:
1748 BoolBuffer = Socket->SharedData.Debug;
1749 Buffer = &BoolBuffer;
1750 BufferSize = sizeof(BOOLEAN);
1751 break;
1752
1753 /* case SO_CONDITIONAL_ACCEPT: */
1754 case SO_DONTLINGER:
1755 case SO_DONTROUTE:
1756 case SO_ERROR:
1757 case SO_GROUP_ID:
1758 case SO_GROUP_PRIORITY:
1759 case SO_KEEPALIVE:
1760 case SO_LINGER:
1761 case SO_MAX_MSG_SIZE:
1762 case SO_OOBINLINE:
1763 case SO_PROTOCOL_INFO:
1764 case SO_REUSEADDR:
1765 AFD_DbgPrint(MID_TRACE, ("Unimplemented option (%x)\n",
1766 OptionName));
1767
1768 default:
1769 *lpErrno = WSAEINVAL;
1770 return SOCKET_ERROR;
1771 }
1772
1773 if (*OptionLength < BufferSize)
1774 {
1775 *lpErrno = WSAEFAULT;
1776 *OptionLength = BufferSize;
1777 return SOCKET_ERROR;
1778 }
1779 RtlCopyMemory(OptionValue, Buffer, BufferSize);
1780
1781 return 0;
1782
1783 case IPPROTO_TCP: /* FIXME */
1784 default:
1785 *lpErrno = WSAEINVAL;
1786 return SOCKET_ERROR;
1787 }
1788 }
1789
1790
1791 /*
1792 * FUNCTION: Initialize service provider for a client
1793 * ARGUMENTS:
1794 * wVersionRequested = Highest WinSock SPI version that the caller can use
1795 * lpWSPData = Address of WSPDATA structure to initialize
1796 * lpProtocolInfo = Pointer to structure that defines the desired protocol
1797 * UpcallTable = Pointer to upcall table of the WinSock DLL
1798 * lpProcTable = Address of procedure table to initialize
1799 * RETURNS:
1800 * Status of operation
1801 */
1802 INT
1803 WSPAPI
1804 WSPStartup(IN WORD wVersionRequested,
1805 OUT LPWSPDATA lpWSPData,
1806 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
1807 IN WSPUPCALLTABLE UpcallTable,
1808 OUT LPWSPPROC_TABLE lpProcTable)
1809
1810 {
1811 NTSTATUS Status;
1812
1813 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
1814 Status = NO_ERROR;
1815 Upcalls = UpcallTable;
1816
1817 if (Status == NO_ERROR)
1818 {
1819 lpProcTable->lpWSPAccept = WSPAccept;
1820 lpProcTable->lpWSPAddressToString = WSPAddressToString;
1821 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
1822 lpProcTable->lpWSPBind = WSPBind;
1823 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
1824 lpProcTable->lpWSPCleanup = WSPCleanup;
1825 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
1826 lpProcTable->lpWSPConnect = WSPConnect;
1827 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
1828 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
1829 lpProcTable->lpWSPEventSelect = WSPEventSelect;
1830 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
1831 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
1832 lpProcTable->lpWSPGetSockName = WSPGetSockName;
1833 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
1834 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
1835 lpProcTable->lpWSPIoctl = WSPIoctl;
1836 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
1837 lpProcTable->lpWSPListen = WSPListen;
1838 lpProcTable->lpWSPRecv = WSPRecv;
1839 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
1840 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
1841 lpProcTable->lpWSPSelect = WSPSelect;
1842 lpProcTable->lpWSPSend = WSPSend;
1843 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
1844 lpProcTable->lpWSPSendTo = WSPSendTo;
1845 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
1846 lpProcTable->lpWSPShutdown = WSPShutdown;
1847 lpProcTable->lpWSPSocket = WSPSocket;
1848 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
1849 lpWSPData->wVersion = MAKEWORD(2, 2);
1850 lpWSPData->wHighVersion = MAKEWORD(2, 2);
1851 }
1852
1853 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
1854
1855 return Status;
1856 }
1857
1858
1859 /*
1860 * FUNCTION: Cleans up service provider for a client
1861 * ARGUMENTS:
1862 * lpErrno = Address of buffer for error information
1863 * RETURNS:
1864 * 0 if successful, or SOCKET_ERROR if not
1865 */
1866 INT
1867 WSPAPI
1868 WSPCleanup(OUT LPINT lpErrno)
1869
1870 {
1871 AFD_DbgPrint(MAX_TRACE, ("\n"));
1872 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
1873 *lpErrno = NO_ERROR;
1874
1875 return 0;
1876 }
1877
1878
1879
1880 int
1881 GetSocketInformation(PSOCKET_INFORMATION Socket,
1882 ULONG AfdInformationClass,
1883 PULONG Ulong OPTIONAL,
1884 PLARGE_INTEGER LargeInteger OPTIONAL)
1885 {
1886 IO_STATUS_BLOCK IOSB;
1887 AFD_INFO InfoData;
1888 NTSTATUS Status;
1889 HANDLE SockEvent;
1890
1891 Status = NtCreateEvent(&SockEvent,
1892 GENERIC_READ | GENERIC_WRITE,
1893 NULL,
1894 1,
1895 FALSE);
1896
1897 if( !NT_SUCCESS(Status) )
1898 return -1;
1899
1900 /* Set Info Class */
1901 InfoData.InformationClass = AfdInformationClass;
1902
1903 /* Send IOCTL */
1904 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1905 SockEvent,
1906 NULL,
1907 NULL,
1908 &IOSB,
1909 IOCTL_AFD_GET_INFO,
1910 &InfoData,
1911 sizeof(InfoData),
1912 &InfoData,
1913 sizeof(InfoData));
1914
1915 /* Wait for return */
1916 if (Status == STATUS_PENDING)
1917 {
1918 WaitForSingleObject(SockEvent, INFINITE);
1919 }
1920
1921 /* Return Information */
1922 *Ulong = InfoData.Information.Ulong;
1923 if (LargeInteger != NULL)
1924 {
1925 *LargeInteger = InfoData.Information.LargeInteger;
1926 }
1927
1928 NtClose( SockEvent );
1929
1930 return 0;
1931
1932 }
1933
1934
1935 int
1936 SetSocketInformation(PSOCKET_INFORMATION Socket,
1937 ULONG AfdInformationClass,
1938 PULONG Ulong OPTIONAL,
1939 PLARGE_INTEGER LargeInteger OPTIONAL)
1940 {
1941 IO_STATUS_BLOCK IOSB;
1942 AFD_INFO InfoData;
1943 NTSTATUS Status;
1944 HANDLE SockEvent;
1945
1946 Status = NtCreateEvent(&SockEvent,
1947 GENERIC_READ | GENERIC_WRITE,
1948 NULL,
1949 1,
1950 FALSE);
1951
1952 if( !NT_SUCCESS(Status) )
1953 return -1;
1954
1955 /* Set Info Class */
1956 InfoData.InformationClass = AfdInformationClass;
1957
1958 /* Set Information */
1959 InfoData.Information.Ulong = *Ulong;
1960 if (LargeInteger != NULL)
1961 {
1962 InfoData.Information.LargeInteger = *LargeInteger;
1963 }
1964
1965 AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
1966 AfdInformationClass, *Ulong));
1967
1968 /* Send IOCTL */
1969 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1970 SockEvent,
1971 NULL,
1972 NULL,
1973 &IOSB,
1974 IOCTL_AFD_GET_INFO,
1975 &InfoData,
1976 sizeof(InfoData),
1977 NULL,
1978 0);
1979
1980 /* Wait for return */
1981 if (Status == STATUS_PENDING)
1982 {
1983 WaitForSingleObject(SockEvent, INFINITE);
1984 }
1985
1986 NtClose( SockEvent );
1987
1988 return 0;
1989
1990 }
1991
1992 PSOCKET_INFORMATION
1993 GetSocketStructure(SOCKET Handle)
1994 {
1995 ULONG i;
1996
1997 for (i=0; i<SocketCount; i++)
1998 {
1999 if (Sockets[i]->Handle == Handle)
2000 {
2001 return Sockets[i];
2002 }
2003 }
2004 return 0;
2005 }
2006
2007 int CreateContext(PSOCKET_INFORMATION Socket)
2008 {
2009 IO_STATUS_BLOCK IOSB;
2010 SOCKET_CONTEXT ContextData;
2011 NTSTATUS Status;
2012 HANDLE SockEvent;
2013
2014 Status = NtCreateEvent(&SockEvent,
2015 GENERIC_READ | GENERIC_WRITE,
2016 NULL,
2017 1,
2018 FALSE);
2019
2020 if( !NT_SUCCESS(Status) )
2021 return -1;
2022
2023 /* Create Context */
2024 ContextData.SharedData = Socket->SharedData;
2025 ContextData.SizeOfHelperData = 0;
2026 RtlCopyMemory (&ContextData.LocalAddress,
2027 Socket->LocalAddress,
2028 Socket->SharedData.SizeOfLocalAddress);
2029 RtlCopyMemory (&ContextData.RemoteAddress,
2030 Socket->RemoteAddress,
2031 Socket->SharedData.SizeOfRemoteAddress);
2032
2033 /* Send IOCTL */
2034 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
2035 SockEvent,
2036 NULL,
2037 NULL,
2038 &IOSB,
2039 IOCTL_AFD_SET_CONTEXT,
2040 &ContextData,
2041 sizeof(ContextData),
2042 NULL,
2043 0);
2044
2045 /* Wait for Completition */
2046 if (Status == STATUS_PENDING)
2047 {
2048 WaitForSingleObject(SockEvent, INFINITE);
2049 }
2050
2051 NtClose( SockEvent );
2052
2053 return 0;
2054 }
2055
2056 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
2057 {
2058 HANDLE hAsyncThread;
2059 DWORD AsyncThreadId;
2060 HANDLE AsyncEvent;
2061 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2062 NTSTATUS Status;
2063
2064 /* Check if the Thread Already Exists */
2065 if (SockAsyncThreadRefCount)
2066 {
2067 return TRUE;
2068 }
2069
2070 /* Create the Completion Port */
2071 if (!SockAsyncCompletionPort)
2072 {
2073 Status = NtCreateIoCompletion(&SockAsyncCompletionPort,
2074 IO_COMPLETION_ALL_ACCESS,
2075 NULL,
2076 2); // Allow 2 threads only
2077
2078 /* Protect Handle */
2079 HandleFlags.ProtectFromClose = TRUE;
2080 HandleFlags.Inherit = FALSE;
2081 Status = NtSetInformationObject(SockAsyncCompletionPort,
2082 ObjectHandleFlagInformation,
2083 &HandleFlags,
2084 sizeof(HandleFlags));
2085 }
2086
2087 /* Create the Async Event */
2088 Status = NtCreateEvent(&AsyncEvent,
2089 EVENT_ALL_ACCESS,
2090 NULL,
2091 NotificationEvent,
2092 FALSE);
2093
2094 /* Create the Async Thread */
2095 hAsyncThread = CreateThread(NULL,
2096 0,
2097 (LPTHREAD_START_ROUTINE)SockAsyncThread,
2098 NULL,
2099 0,
2100 &AsyncThreadId);
2101
2102 /* Close the Handle */
2103 NtClose(hAsyncThread);
2104
2105 /* Increase the Reference Count */
2106 SockAsyncThreadRefCount++;
2107 return TRUE;
2108 }
2109
2110 int SockAsyncThread(PVOID ThreadParam)
2111 {
2112 PVOID AsyncContext;
2113 PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
2114 IO_STATUS_BLOCK IOSB;
2115 NTSTATUS Status;
2116
2117 /* Make the Thread Higher Priority */
2118 SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
2119
2120 /* Do a KQUEUE/WorkItem Style Loop, thanks to IoCompletion Ports */
2121 do
2122 {
2123 Status = NtRemoveIoCompletion (SockAsyncCompletionPort,
2124 (PVOID*)&AsyncCompletionRoutine,
2125 &AsyncContext,
2126 &IOSB,
2127 NULL);
2128 /* Call the Async Function */
2129 if (NT_SUCCESS(Status))
2130 {
2131 (*AsyncCompletionRoutine)(AsyncContext, &IOSB);
2132 }
2133 else
2134 {
2135 /* It Failed, sleep for a second */
2136 Sleep(1000);
2137 }
2138 } while ((Status != STATUS_TIMEOUT));
2139
2140 /* The Thread has Ended */
2141 return 0;
2142 }
2143
2144 BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
2145 {
2146 UNICODE_STRING AfdHelper;
2147 OBJECT_ATTRIBUTES ObjectAttributes;
2148 IO_STATUS_BLOCK IoSb;
2149 NTSTATUS Status;
2150 FILE_COMPLETION_INFORMATION CompletionInfo;
2151 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2152
2153 /* First, make sure we're not already intialized */
2154 if (SockAsyncHelperAfdHandle)
2155 {
2156 return TRUE;
2157 }
2158
2159 /* Set up Handle Name and Object */
2160 RtlInitUnicodeString(&AfdHelper, L"\\Device\\Afd\\AsyncSelectHlp" );
2161 InitializeObjectAttributes(&ObjectAttributes,
2162 &AfdHelper,
2163 OBJ_INHERIT | OBJ_CASE_INSENSITIVE,
2164 NULL,
2165 NULL);
2166
2167 /* Open the Handle to AFD */
2168 Status = NtCreateFile(&SockAsyncHelperAfdHandle,
2169 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
2170 &ObjectAttributes,
2171 &IoSb,
2172 NULL,
2173 0,
2174 FILE_SHARE_READ | FILE_SHARE_WRITE,
2175 FILE_OPEN_IF,
2176 0,
2177 NULL,
2178 0);
2179
2180 /*
2181 * Now Set up the Completion Port Information
2182 * This means that whenever a Poll is finished, the routine will be executed
2183 */
2184 CompletionInfo.Port = SockAsyncCompletionPort;
2185 CompletionInfo.Key = SockAsyncSelectCompletionRoutine;
2186 Status = NtSetInformationFile(SockAsyncHelperAfdHandle,
2187 &IoSb,
2188 &CompletionInfo,
2189 sizeof(CompletionInfo),
2190 FileCompletionInformation);
2191
2192
2193 /* Protect the Handle */
2194 HandleFlags.ProtectFromClose = TRUE;
2195 HandleFlags.Inherit = FALSE;
2196 Status = NtSetInformationObject(SockAsyncCompletionPort,
2197 ObjectHandleFlagInformation,
2198 &HandleFlags,
2199 sizeof(HandleFlags));
2200
2201
2202 /* Set this variable to true so that Send/Recv/Accept will know wether to renable disabled events */
2203 SockAsyncSelectCalled = TRUE;
2204 return TRUE;
2205 }
2206
2207 VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2208 {
2209
2210 PASYNC_DATA AsyncData = Context;
2211 PSOCKET_INFORMATION Socket;
2212 ULONG x;
2213
2214 /* Get the Socket */
2215 Socket = AsyncData->ParentSocket;
2216
2217 /* Check if the Sequence Number Changed behind our back */
2218 if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber )
2219 {
2220 return;
2221 }
2222
2223 /* Check we were manually called b/c of a failure */
2224 if (!NT_SUCCESS(IoStatusBlock->Status))
2225 {
2226 /* FIXME: Perform Upcall */
2227 return;
2228 }
2229
2230 for (x = 1; x; x<<=1)
2231 {
2232 switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x)
2233 {
2234 case AFD_EVENT_RECEIVE:
2235 if (0 != (Socket->SharedData.AsyncEvents & FD_READ) &&
2236 0 == (Socket->SharedData.AsyncDisabledEvents & FD_READ))
2237 {
2238 /* Make the Notifcation */
2239 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2240 Socket->SharedData.wMsg,
2241 Socket->Handle,
2242 WSAMAKESELECTREPLY(FD_READ, 0));
2243 /* Disable this event until the next read(); */
2244 Socket->SharedData.AsyncDisabledEvents |= FD_READ;
2245 }
2246 break;
2247
2248 case AFD_EVENT_OOB_RECEIVE:
2249 if (0 != (Socket->SharedData.AsyncEvents & FD_OOB) &&
2250 0 == (Socket->SharedData.AsyncDisabledEvents & FD_OOB))
2251 {
2252 /* Make the Notifcation */
2253 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2254 Socket->SharedData.wMsg,
2255 Socket->Handle,
2256 WSAMAKESELECTREPLY(FD_OOB, 0));
2257 /* Disable this event until the next read(); */
2258 Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
2259 }
2260 break;
2261
2262 case AFD_EVENT_SEND:
2263 if (0 != (Socket->SharedData.AsyncEvents & FD_WRITE) &&
2264 0 == (Socket->SharedData.AsyncDisabledEvents & FD_WRITE))
2265 {
2266 /* Make the Notifcation */
2267 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2268 Socket->SharedData.wMsg,
2269 Socket->Handle,
2270 WSAMAKESELECTREPLY(FD_WRITE, 0));
2271 /* Disable this event until the next write(); */
2272 Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
2273 }
2274 break;
2275
2276 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2277 case AFD_EVENT_CONNECT:
2278 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
2279 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
2280 {
2281 /* Make the Notifcation */
2282 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2283 Socket->SharedData.wMsg,
2284 Socket->Handle,
2285 WSAMAKESELECTREPLY(FD_CONNECT, 0));
2286 /* Disable this event forever; */
2287 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT;
2288 }
2289 break;
2290
2291 case AFD_EVENT_ACCEPT:
2292 if (0 != (Socket->SharedData.AsyncEvents & FD_ACCEPT) &&
2293 0 == (Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT))
2294 {
2295 /* Make the Notifcation */
2296 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2297 Socket->SharedData.wMsg,
2298 Socket->Handle,
2299 WSAMAKESELECTREPLY(FD_ACCEPT, 0));
2300 /* Disable this event until the next accept(); */
2301 Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
2302 }
2303 break;
2304
2305 case AFD_EVENT_DISCONNECT:
2306 case AFD_EVENT_ABORT:
2307 case AFD_EVENT_CLOSE:
2308 if (0 != (Socket->SharedData.AsyncEvents & FD_CLOSE) &&
2309 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CLOSE))
2310 {
2311 /* Make the Notifcation */
2312 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2313 Socket->SharedData.wMsg,
2314 Socket->Handle,
2315 WSAMAKESELECTREPLY(FD_CLOSE, 0));
2316 /* Disable this event forever; */
2317 Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
2318 }
2319 break;
2320 /* FIXME: Support QOS */
2321 }
2322 }
2323
2324 /* Check if there are any events left for us to check */
2325 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2326 {
2327 return;
2328 }
2329
2330 /* Keep Polling */
2331 SockProcessAsyncSelect(Socket, AsyncData);
2332 return;
2333 }
2334
2335 VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
2336 {
2337
2338 ULONG lNetworkEvents;
2339 NTSTATUS Status;
2340
2341 /* Set up the Async Data Event Info */
2342 AsyncData->AsyncSelectInfo.Timeout.HighPart = 0x7FFFFFFF;
2343 AsyncData->AsyncSelectInfo.Timeout.LowPart = 0xFFFFFFFF;
2344 AsyncData->AsyncSelectInfo.HandleCount = 1;
2345 AsyncData->AsyncSelectInfo.Exclusive = TRUE;
2346 AsyncData->AsyncSelectInfo.Handles[0].Handle = Socket->Handle;
2347 AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
2348
2349 /* Remove unwanted events */
2350 lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
2351
2352 /* Set Events to wait for */
2353 if (lNetworkEvents & FD_READ)
2354 {
2355 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_RECEIVE;
2356 }
2357
2358 if (lNetworkEvents & FD_WRITE)
2359 {
2360 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_SEND;
2361 }
2362
2363 if (lNetworkEvents & FD_OOB)
2364 {
2365 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_OOB_RECEIVE;
2366 }
2367
2368 if (lNetworkEvents & FD_ACCEPT)
2369 {
2370 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_ACCEPT;
2371 }
2372
2373 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2374 if (lNetworkEvents & FD_CONNECT)
2375 {
2376 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_CONNECT | AFD_EVENT_CONNECT_FAIL;
2377 }
2378
2379 if (lNetworkEvents & FD_CLOSE)
2380 {
2381 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
2382 }
2383
2384 if (lNetworkEvents & FD_QOS)
2385 {
2386 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_QOS;
2387 }
2388
2389 if (lNetworkEvents & FD_GROUP_QOS)
2390 {
2391 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_GROUP_QOS;
2392 }
2393
2394 /* Send IOCTL */
2395 Status = NtDeviceIoControlFile (SockAsyncHelperAfdHandle,
2396 NULL,
2397 NULL,
2398 AsyncData,
2399 &AsyncData->IoStatusBlock,
2400 IOCTL_AFD_SELECT,
2401 &AsyncData->AsyncSelectInfo,
2402 sizeof(AsyncData->AsyncSelectInfo),
2403 &AsyncData->AsyncSelectInfo,
2404 sizeof(AsyncData->AsyncSelectInfo));
2405
2406 /* I/O Manager Won't call the completion routine, let's do it manually */
2407 if (NT_SUCCESS(Status))
2408 {
2409 return;
2410 }
2411 else
2412 {
2413 AsyncData->IoStatusBlock.Status = Status;
2414 SockAsyncSelectCompletionRoutine(AsyncData, &AsyncData->IoStatusBlock);
2415 }
2416 }
2417
2418 VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2419 {
2420 PASYNC_DATA AsyncData = Context;
2421 BOOL FreeContext = TRUE;
2422 PSOCKET_INFORMATION Socket;
2423
2424 /* Get the Socket */
2425 Socket = AsyncData->ParentSocket;
2426
2427 /* If someone closed it, stop the function */
2428 if (Socket->SharedData.State != SocketClosed)
2429 {
2430 /* Check if the Sequence Number changed by now, in which case quit */
2431 if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber)
2432 {
2433 /* Do the actuall select, if needed */
2434 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)))
2435 {
2436 SockProcessAsyncSelect(Socket, AsyncData);
2437 FreeContext = FALSE;
2438 }
2439 }
2440 }
2441
2442 /* Free the Context */
2443 if (FreeContext)
2444 {
2445 HeapFree(GetProcessHeap(), 0, AsyncData);
2446 }
2447
2448 return;
2449 }
2450
2451 VOID
2452 SockReenableAsyncSelectEvent (IN PSOCKET_INFORMATION Socket,
2453 IN ULONG Event)
2454 {
2455 PASYNC_DATA AsyncData;
2456
2457 /* Make sure the event is actually disabled */
2458 if (!(Socket->SharedData.AsyncDisabledEvents & Event))
2459 {
2460 return;
2461 }
2462
2463 /* Re-enable it */
2464 Socket->SharedData.AsyncDisabledEvents &= ~Event;
2465
2466 /* Return if no more events are being polled */
2467 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2468 {
2469 return;
2470 }
2471
2472 /* Wait on new events */
2473 AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
2474
2475 /* Create the Asynch Thread if Needed */
2476 SockCreateOrReferenceAsyncThread();
2477
2478 /* Increase the sequence number to stop anything else */
2479 Socket->SharedData.SequenceNumber++;
2480
2481 /* Set up the Async Data */
2482 AsyncData->ParentSocket = Socket;
2483 AsyncData->SequenceNumber = Socket->SharedData.SequenceNumber;
2484
2485 /* Begin Async Select by using I/O Completion */
2486 NtSetIoCompletion(SockAsyncCompletionPort,
2487 (PVOID)&SockProcessQueuedAsyncSelect,
2488 AsyncData,
2489 0,
2490 0);
2491
2492 /* All done */
2493 return;
2494 }
2495
2496 BOOL
2497 WINAPI
2498 DllMain(HANDLE hInstDll,
2499 ULONG dwReason,
2500 PVOID Reserved)
2501 {
2502
2503 switch (dwReason)
2504 {
2505 case DLL_PROCESS_ATTACH:
2506
2507 AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
2508
2509 /* Don't need thread attach notifications
2510 so disable them to improve performance */
2511 DisableThreadLibraryCalls(hInstDll);
2512
2513 /* List of DLL Helpers */
2514 InitializeListHead(&SockHelpersListHead);
2515
2516 /* Heap to use when allocating */
2517 GlobalHeap = GetProcessHeap();
2518
2519 /* Allocate Heap for 1024 Sockets, can be expanded later */
2520 Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
2521
2522 AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
2523
2524 break;
2525
2526 case DLL_THREAD_ATTACH:
2527 break;
2528
2529 case DLL_THREAD_DETACH:
2530 break;
2531
2532 case DLL_PROCESS_DETACH:
2533 break;
2534 }
2535
2536 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
2537
2538 return TRUE;
2539 }
2540
2541 /* EOF */
2542
2543