Merge trunk head (r41474)
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
4 * FILE: misc/dllmain.c
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7 * Alex Ionescu (alex@relsoft.net)
8 * REVISIONS:
9 * CSH 01/09-2000 Created
10 * Alex 16/07/2004 - Complete Rewrite
11 */
12
13 #include <msafd.h>
14
15 #include <debug.h>
16
17 #if DBG
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
19 DWORD DebugTraceLevel = 0;
20 #endif /* DBG */
21
22 HANDLE GlobalHeap;
23 WSPUPCALLTABLE Upcalls;
24 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
25 ULONG SocketCount = 0;
26 PSOCKET_INFORMATION *Sockets = NULL;
27 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
28 ULONG SockAsyncThreadRefCount;
29 HANDLE SockAsyncHelperAfdHandle;
30 HANDLE SockAsyncCompletionPort;
31 BOOLEAN SockAsyncSelectCalled;
32
33
34
35 /*
36 * FUNCTION: Creates a new socket
37 * ARGUMENTS:
38 * af = Address family
39 * type = Socket type
40 * protocol = Protocol type
41 * lpProtocolInfo = Pointer to protocol information
42 * g = Reserved
43 * dwFlags = Socket flags
44 * lpErrno = Address of buffer for error information
45 * RETURNS:
46 * Created socket, or INVALID_SOCKET if it could not be created
47 */
48 SOCKET
49 WSPAPI
50 WSPSocket(int AddressFamily,
51 int SocketType,
52 int Protocol,
53 LPWSAPROTOCOL_INFOW lpProtocolInfo,
54 GROUP g,
55 DWORD dwFlags,
56 LPINT lpErrno)
57 {
58 OBJECT_ATTRIBUTES Object;
59 IO_STATUS_BLOCK IOSB;
60 USHORT SizeOfPacket;
61 ULONG SizeOfEA;
62 PAFD_CREATE_PACKET AfdPacket;
63 HANDLE Sock;
64 PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
65 PFILE_FULL_EA_INFORMATION EABuffer = NULL;
66 PHELPER_DATA HelperData;
67 PVOID HelperDLLContext;
68 DWORD HelperEvents;
69 UNICODE_STRING TransportName;
70 UNICODE_STRING DevName;
71 LARGE_INTEGER GroupData;
72 INT Status;
73
74 AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
75 AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
76 AddressFamily, SocketType, Protocol));
77
78 /* Get Helper Data and Transport */
79 Status = SockGetTdiName (&AddressFamily,
80 &SocketType,
81 &Protocol,
82 g,
83 dwFlags,
84 &TransportName,
85 &HelperDLLContext,
86 &HelperData,
87 &HelperEvents);
88
89 /* Check for error */
90 if (Status != NO_ERROR)
91 {
92 AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
93 goto error;
94 }
95
96 /* AFD Device Name */
97 RtlInitUnicodeString(&DevName, L"\\Device\\Afd\\Endpoint");
98
99 /* Set Socket Data */
100 Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
101 RtlZeroMemory(Socket, sizeof(*Socket));
102 Socket->RefCount = 2;
103 Socket->Handle = -1;
104 Socket->SharedData.Listening = FALSE;
105 Socket->SharedData.State = SocketOpen;
106 Socket->SharedData.AddressFamily = AddressFamily;
107 Socket->SharedData.SocketType = SocketType;
108 Socket->SharedData.Protocol = Protocol;
109 Socket->HelperContext = HelperDLLContext;
110 Socket->HelperData = HelperData;
111 Socket->HelperEvents = HelperEvents;
112 Socket->LocalAddress = &Socket->WSLocalAddress;
113 Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
114 Socket->RemoteAddress = &Socket->WSRemoteAddress;
115 Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
116 Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
117 Socket->SharedData.CreateFlags = dwFlags;
118 Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
119 Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
120 Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
121 Socket->SharedData.GroupID = g;
122 Socket->SharedData.GroupType = 0;
123 Socket->SharedData.UseSAN = FALSE;
124 Socket->SharedData.NonBlocking = FALSE; /* Sockets start blocking */
125 Socket->SanData = NULL;
126
127 /* Ask alex about this */
128 if( Socket->SharedData.SocketType == SOCK_DGRAM ||
129 Socket->SharedData.SocketType == SOCK_RAW )
130 {
131 AFD_DbgPrint(MID_TRACE,("Connectionless socket\n"));
132 Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
133 }
134
135 /* Packet Size */
136 SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
137
138 /* EA Size */
139 SizeOfEA = SizeOfPacket + sizeof(FILE_FULL_EA_INFORMATION) + AFD_PACKET_COMMAND_LENGTH;
140
141 /* Set up EA Buffer */
142 EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
143 RtlZeroMemory(EABuffer, SizeOfEA);
144 EABuffer->NextEntryOffset = 0;
145 EABuffer->Flags = 0;
146 EABuffer->EaNameLength = AFD_PACKET_COMMAND_LENGTH;
147 RtlCopyMemory (EABuffer->EaName,
148 AfdCommand,
149 AFD_PACKET_COMMAND_LENGTH + 1);
150 EABuffer->EaValueLength = SizeOfPacket;
151
152 /* Set up AFD Packet */
153 AfdPacket = (PAFD_CREATE_PACKET)(EABuffer->EaName + EABuffer->EaNameLength + 1);
154 AfdPacket->SizeOfTransportName = TransportName.Length;
155 RtlCopyMemory (AfdPacket->TransportName,
156 TransportName.Buffer,
157 TransportName.Length + sizeof(WCHAR));
158 AfdPacket->GroupID = g;
159
160 /* Set up Endpoint Flags */
161 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0)
162 {
163 if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW))
164 {
165 /* Only RAW or UDP can be Connectionless */
166 goto error;
167 }
168 AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
169 }
170
171 if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0)
172 {
173 if (SocketType == SOCK_STREAM)
174 {
175 if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
176 {
177 /* The Provider doesn't actually support Message Oriented Streams */
178 goto error;
179 }
180 }
181 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MESSAGE_ORIENTED;
182 }
183
184 if (SocketType == SOCK_RAW) AfdPacket->EndpointFlags |= AFD_ENDPOINT_RAW;
185
186 if (dwFlags & (WSA_FLAG_MULTIPOINT_C_ROOT |
187 WSA_FLAG_MULTIPOINT_C_LEAF |
188 WSA_FLAG_MULTIPOINT_D_ROOT |
189 WSA_FLAG_MULTIPOINT_D_LEAF))
190 {
191 if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
192 {
193 /* The Provider doesn't actually support Multipoint */
194 goto error;
195 }
196 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
197
198 if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT)
199 {
200 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
201 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0))
202 {
203 /* The Provider doesn't support Control Planes, or you already gave a leaf */
204 goto error;
205 }
206 AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
207 }
208
209 if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT)
210 {
211 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
212 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0))
213 {
214 /* The Provider doesn't support Data Planes, or you already gave a leaf */
215 goto error;
216 }
217 AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
218 }
219 }
220
221 /* Set up Object Attributes */
222 InitializeObjectAttributes (&Object,
223 &DevName,
224 OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
225 0,
226 0);
227
228 /* Create the Socket as asynchronous. That means we have to block
229 ourselves after every call to NtDeviceIoControlFile. This is
230 because the kernel doesn't support overlapping synchronous I/O
231 requests (made from multiple threads) at this time (Sep 2005) */
232 ZwCreateFile(&Sock,
233 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
234 &Object,
235 &IOSB,
236 NULL,
237 0,
238 FILE_SHARE_READ | FILE_SHARE_WRITE, FILE_OPEN_IF,
239 0,
240 EABuffer,
241 SizeOfEA);
242
243 /* Save Handle */
244 Socket->Handle = (SOCKET)Sock;
245
246 /* XXX See if there's a structure we can reuse -- We need to do this
247 * more properly. */
248 PrevSocket = GetSocketStructure( (SOCKET)Sock );
249
250 if( PrevSocket )
251 {
252 RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
253 RtlFreeHeap( GlobalHeap, 0, Socket );
254 Socket = PrevSocket;
255 }
256
257 /* Save Group Info */
258 if (g != 0)
259 {
260 GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
261 Socket->SharedData.GroupID = GroupData.u.LowPart;
262 Socket->SharedData.GroupType = GroupData.u.HighPart;
263 }
264
265 /* Get Window Sizes and Save them */
266 GetSocketInformation (Socket,
267 AFD_INFO_SEND_WINDOW_SIZE,
268 &Socket->SharedData.SizeOfSendBuffer,
269 NULL);
270
271 GetSocketInformation (Socket,
272 AFD_INFO_RECEIVE_WINDOW_SIZE,
273 &Socket->SharedData.SizeOfRecvBuffer,
274 NULL);
275
276 /* Save in Process Sockets List */
277 Sockets[SocketCount] = Socket;
278 SocketCount ++;
279
280 /* Create the Socket Context */
281 CreateContext(Socket);
282
283 /* Notify Winsock */
284 Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
285
286 /* Return Socket Handle */
287 AFD_DbgPrint(MID_TRACE,("Success %x\n", Sock));
288
289 return (SOCKET)Sock;
290
291 error:
292 AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
293
294 if( lpErrno )
295 *lpErrno = Status;
296
297 return INVALID_SOCKET;
298 }
299
300
301 DWORD MsafdReturnWithErrno(NTSTATUS Status,
302 LPINT Errno,
303 DWORD Received,
304 LPDWORD ReturnedBytes)
305 {
306 if( ReturnedBytes )
307 *ReturnedBytes = 0;
308 if( Errno )
309 {
310 switch (Status)
311 {
312 case STATUS_CANT_WAIT:
313 *Errno = WSAEWOULDBLOCK;
314 break;
315 case STATUS_TIMEOUT:
316 *Errno = WSAETIMEDOUT;
317 break;
318 case STATUS_SUCCESS:
319 /* Return Number of bytes Read */
320 if( ReturnedBytes )
321 *ReturnedBytes = Received;
322 break;
323 case STATUS_FILE_CLOSED:
324 case STATUS_END_OF_FILE:
325 *Errno = WSAESHUTDOWN;
326 break;
327 case STATUS_PENDING:
328 *Errno = WSA_IO_PENDING;
329 break;
330 case STATUS_BUFFER_TOO_SMALL:
331 case STATUS_BUFFER_OVERFLOW:
332 DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
333 *Errno = WSAEMSGSIZE;
334 break;
335 case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
336 case STATUS_INSUFFICIENT_RESOURCES:
337 DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
338 *Errno = WSA_NOT_ENOUGH_MEMORY;
339 break;
340 case STATUS_INVALID_CONNECTION:
341 DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
342 *Errno = WSAEAFNOSUPPORT;
343 break;
344 case STATUS_REMOTE_NOT_LISTENING:
345 DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
346 *Errno = WSAECONNREFUSED;
347 break;
348 case STATUS_NETWORK_UNREACHABLE:
349 DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
350 *Errno = WSAENETUNREACH;
351 break;
352 case STATUS_INVALID_PARAMETER:
353 DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
354 *Errno = WSAEINVAL;
355 break;
356 case STATUS_CANCELLED:
357 DbgPrint("MSAFD: STATUS_CANCELLED\n");
358 *Errno = WSA_OPERATION_ABORTED;
359 break;
360 default:
361 DbgPrint("MSAFD: Error %x is unknown\n", Status);
362 *Errno = WSAEINVAL;
363 break;
364 }
365 }
366
367 /* Success */
368 return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
369 }
370
371 /*
372 * FUNCTION: Closes an open socket
373 * ARGUMENTS:
374 * s = Socket descriptor
375 * lpErrno = Address of buffer for error information
376 * RETURNS:
377 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
378 */
379 INT
380 WSPAPI
381 WSPCloseSocket(IN SOCKET Handle,
382 OUT LPINT lpErrno)
383 {
384 IO_STATUS_BLOCK IoStatusBlock;
385 PSOCKET_INFORMATION Socket = NULL;
386 NTSTATUS Status;
387 HANDLE SockEvent;
388 AFD_DISCONNECT_INFO DisconnectInfo;
389 SOCKET_STATE OldState;
390
391 /* Create the Wait Event */
392 Status = NtCreateEvent(&SockEvent,
393 GENERIC_READ | GENERIC_WRITE,
394 NULL,
395 1,
396 FALSE);
397
398 if(!NT_SUCCESS(Status))
399 return SOCKET_ERROR;
400
401 /* Get the Socket Structure associate to this Socket*/
402 Socket = GetSocketStructure(Handle);
403
404 /* If a Close is already in Process, give up */
405 if (Socket->SharedData.State == SocketClosed)
406 {
407 NtClose(SockEvent);
408 *lpErrno = WSAENOTSOCK;
409 return SOCKET_ERROR;
410 }
411
412 /* Set the state to close */
413 OldState = Socket->SharedData.State;
414 Socket->SharedData.State = SocketClosed;
415
416 /* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
417 /* FIXME: Should we do this on Datagram Sockets too? */
418 if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
419 {
420 ULONG LingerWait;
421 ULONG SendsInProgress;
422 ULONG SleepWait;
423
424 /* We need to respect the timeout */
425 SleepWait = 100;
426 LingerWait = Socket->SharedData.LingerData.l_linger * 1000;
427
428 /* Loop until no more sends are pending, within the timeout */
429 while (LingerWait)
430 {
431 /* Find out how many Sends are in Progress */
432 if (GetSocketInformation(Socket,
433 AFD_INFO_SENDS_IN_PROGRESS,
434 &SendsInProgress,
435 NULL))
436 {
437 /* Bail out if anything but NO_ERROR */
438 LingerWait = 0;
439 break;
440 }
441
442 /* Bail out if no more sends are pending */
443 if (!SendsInProgress)
444 break;
445 /*
446 * We have to execute a sleep, so it's kind of like
447 * a block. If the socket is Nonblock, we cannot
448 * go on since asyncronous operation is expected
449 * and we cannot offer it
450 */
451 if (Socket->SharedData.NonBlocking)
452 {
453 NtClose(SockEvent);
454 Socket->SharedData.State = OldState;
455 *lpErrno = WSAEWOULDBLOCK;
456 return SOCKET_ERROR;
457 }
458
459 /* Now we can sleep, and decrement the linger wait */
460 /*
461 * FIXME: It seems Windows does some funky acceleration
462 * since the waiting seems to be longer and longer. I
463 * don't think this improves performance so much, so we
464 * wait a fixed time instead.
465 */
466 Sleep(SleepWait);
467 LingerWait -= SleepWait;
468 }
469
470 /*
471 * We have reached the timeout or sends are over.
472 * Disconnect if the timeout has been reached.
473 */
474 if (LingerWait <= 0)
475 {
476 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
477 DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
478
479 /* Send IOCTL */
480 Status = NtDeviceIoControlFile((HANDLE)Handle,
481 SockEvent,
482 NULL,
483 NULL,
484 &IoStatusBlock,
485 IOCTL_AFD_DISCONNECT,
486 &DisconnectInfo,
487 sizeof(DisconnectInfo),
488 NULL,
489 0);
490
491 /* Wait for return */
492 if (Status == STATUS_PENDING)
493 {
494 WaitForSingleObject(SockEvent, INFINITE);
495 }
496 }
497 }
498
499 /* FIXME: We should notify the Helper DLL of WSH_NOTIFY_CLOSE */
500
501 /* Cleanup Time! */
502 Socket->HelperContext = NULL;
503 Socket->SharedData.AsyncDisabledEvents = -1;
504 NtClose(Socket->TdiAddressHandle);
505 Socket->TdiAddressHandle = NULL;
506 NtClose(Socket->TdiConnectionHandle);
507 Socket->TdiConnectionHandle = NULL;
508
509 /* Close the handle */
510 NtClose((HANDLE)Handle);
511 NtClose(SockEvent);
512
513 return NO_ERROR;
514 }
515
516
517 /*
518 * FUNCTION: Associates a local address with a socket
519 * ARGUMENTS:
520 * s = Socket descriptor
521 * name = Pointer to local address
522 * namelen = Length of name
523 * lpErrno = Address of buffer for error information
524 * RETURNS:
525 * 0, or SOCKET_ERROR if the socket could not be bound
526 */
527 INT
528 WSPAPI
529 WSPBind(SOCKET Handle,
530 const struct sockaddr *SocketAddress,
531 int SocketAddressLength,
532 LPINT lpErrno)
533 {
534 IO_STATUS_BLOCK IOSB;
535 PAFD_BIND_DATA BindData;
536 PSOCKET_INFORMATION Socket = NULL;
537 NTSTATUS Status;
538 UCHAR BindBuffer[0x1A];
539 SOCKADDR_INFO SocketInfo;
540 HANDLE SockEvent;
541
542 Status = NtCreateEvent(&SockEvent,
543 GENERIC_READ | GENERIC_WRITE,
544 NULL,
545 1,
546 FALSE);
547
548 if( !NT_SUCCESS(Status) )
549 return -1;
550
551 /* Get the Socket Structure associate to this Socket*/
552 Socket = GetSocketStructure(Handle);
553
554 /* Dynamic Structure...ugh */
555 BindData = (PAFD_BIND_DATA)BindBuffer;
556
557 /* Set up Address in TDI Format */
558 BindData->Address.TAAddressCount = 1;
559 BindData->Address.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
560 BindData->Address.Address[0].AddressType = SocketAddress->sa_family;
561 RtlCopyMemory (BindData->Address.Address[0].Address,
562 SocketAddress->sa_data,
563 SocketAddressLength - sizeof(SocketAddress->sa_family));
564
565 /* Get Address Information */
566 Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
567 SocketAddressLength,
568 &SocketInfo);
569
570 /* Set the Share Type */
571 if (Socket->SharedData.ExclusiveAddressUse)
572 {
573 BindData->ShareType = AFD_SHARE_EXCLUSIVE;
574 }
575 else if (SocketInfo.EndpointInfo == SockaddrEndpointInfoWildcard)
576 {
577 BindData->ShareType = AFD_SHARE_WILDCARD;
578 }
579 else if (Socket->SharedData.ReuseAddresses)
580 {
581 BindData->ShareType = AFD_SHARE_REUSE;
582 }
583 else
584 {
585 BindData->ShareType = AFD_SHARE_UNIQUE;
586 }
587
588 /* Send IOCTL */
589 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
590 SockEvent,
591 NULL,
592 NULL,
593 &IOSB,
594 IOCTL_AFD_BIND,
595 BindData,
596 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
597 BindData,
598 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
599
600 /* Wait for return */
601 if (Status == STATUS_PENDING)
602 {
603 WaitForSingleObject(SockEvent, INFINITE);
604 Status = IOSB.Status;
605 }
606
607 /* Set up Socket Data */
608 Socket->SharedData.State = SocketBound;
609 Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
610
611 NtClose( SockEvent );
612
613 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
614 }
615
616 int
617 WSPAPI
618 WSPListen(SOCKET Handle,
619 int Backlog,
620 LPINT lpErrno)
621 {
622 IO_STATUS_BLOCK IOSB;
623 AFD_LISTEN_DATA ListenData;
624 PSOCKET_INFORMATION Socket = NULL;
625 HANDLE SockEvent;
626 NTSTATUS Status;
627
628 /* Get the Socket Structure associate to this Socket*/
629 Socket = GetSocketStructure(Handle);
630
631 if (Socket->SharedData.Listening)
632 return 0;
633
634 Status = NtCreateEvent(&SockEvent,
635 GENERIC_READ | GENERIC_WRITE,
636 NULL,
637 1,
638 FALSE);
639
640 if( !NT_SUCCESS(Status) )
641 return -1;
642
643 /* Set Up Listen Structure */
644 ListenData.UseSAN = FALSE;
645 ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
646 ListenData.Backlog = Backlog;
647
648 /* Send IOCTL */
649 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
650 SockEvent,
651 NULL,
652 NULL,
653 &IOSB,
654 IOCTL_AFD_START_LISTEN,
655 &ListenData,
656 sizeof(ListenData),
657 NULL,
658 0);
659
660 /* Wait for return */
661 if (Status == STATUS_PENDING)
662 {
663 WaitForSingleObject(SockEvent, INFINITE);
664 Status = IOSB.Status;
665 }
666
667 /* Set to Listening */
668 Socket->SharedData.Listening = TRUE;
669
670 NtClose( SockEvent );
671
672 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
673 }
674
675
676 int
677 WSPAPI
678 WSPSelect(int nfds,
679 fd_set *readfds,
680 fd_set *writefds,
681 fd_set *exceptfds,
682 struct timeval *timeout,
683 LPINT lpErrno)
684 {
685 IO_STATUS_BLOCK IOSB;
686 PAFD_POLL_INFO PollInfo;
687 NTSTATUS Status;
688 LONG HandleCount, OutCount = 0;
689 ULONG PollBufferSize;
690 PVOID PollBuffer;
691 ULONG i, j = 0, x;
692 HANDLE SockEvent;
693 BOOL HandleCounted;
694 LARGE_INTEGER Timeout;
695
696 /* Find out how many sockets we have, and how large the buffer needs
697 * to be */
698
699 HandleCount = ( readfds ? readfds->fd_count : 0 ) +
700 ( writefds ? writefds->fd_count : 0 ) +
701 ( exceptfds ? exceptfds->fd_count : 0 );
702
703 if ( HandleCount == 0 )
704 {
705 AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
706 HandleCount));
707 if (lpErrno) *lpErrno = WSAEINVAL;
708 return SOCKET_ERROR;
709 }
710
711 PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
712
713 AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n",
714 HandleCount, PollBufferSize));
715
716 /* Convert Timeout to NT Format */
717 if (timeout == NULL)
718 {
719 Timeout.u.LowPart = -1;
720 Timeout.u.HighPart = 0x7FFFFFFF;
721 AFD_DbgPrint(MAX_TRACE,("Infinite timeout\n"));
722 }
723 else
724 {
725 Timeout = RtlEnlargedIntegerMultiply
726 ((timeout->tv_sec * 1000) + (timeout->tv_usec / 1000), -10000);
727 /* Negative timeouts are illegal. Since the kernel represents an
728 * incremental timeout as a negative number, we check for a positive
729 * result.
730 */
731 if (Timeout.QuadPart > 0)
732 {
733 if (lpErrno) *lpErrno = WSAEINVAL;
734 return SOCKET_ERROR;
735 }
736 AFD_DbgPrint(MAX_TRACE,("Timeout: Orig %d.%06d kernel %d\n",
737 timeout->tv_sec, timeout->tv_usec,
738 Timeout.u.LowPart));
739 }
740
741 Status = NtCreateEvent(&SockEvent,
742 GENERIC_READ | GENERIC_WRITE,
743 NULL,
744 1,
745 FALSE);
746
747 if( !NT_SUCCESS(Status) )
748 return SOCKET_ERROR;
749
750 /* Allocate */
751 PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
752
753 if (!PollBuffer)
754 {
755 if (lpErrno)
756 *lpErrno = WSAEFAULT;
757 NtClose(SockEvent);
758 return SOCKET_ERROR;
759 }
760
761 PollInfo = (PAFD_POLL_INFO)PollBuffer;
762
763 RtlZeroMemory( PollInfo, PollBufferSize );
764
765 /* Number of handles for AFD to Check */
766 PollInfo->Exclusive = FALSE;
767 PollInfo->Timeout = Timeout;
768
769 if (readfds != NULL) {
770 for (i = 0; i < readfds->fd_count; i++, j++)
771 {
772 PollInfo->Handles[j].Handle = readfds->fd_array[i];
773 PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
774 AFD_EVENT_DISCONNECT |
775 AFD_EVENT_ABORT |
776 AFD_EVENT_ACCEPT;
777 }
778 }
779 if (writefds != NULL)
780 {
781 for (i = 0; i < writefds->fd_count; i++, j++)
782 {
783 PollInfo->Handles[j].Handle = writefds->fd_array[i];
784 PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
785 }
786 }
787 if (exceptfds != NULL)
788 {
789 for (i = 0; i < exceptfds->fd_count; i++, j++)
790 {
791 PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
792 PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
793 }
794 }
795
796 PollInfo->HandleCount = j;
797 PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
798
799 /* Send IOCTL */
800 Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
801 SockEvent,
802 NULL,
803 NULL,
804 &IOSB,
805 IOCTL_AFD_SELECT,
806 PollInfo,
807 PollBufferSize,
808 PollInfo,
809 PollBufferSize);
810
811 AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
812
813 /* Wait for Completition */
814 if (Status == STATUS_PENDING)
815 {
816 WaitForSingleObject(SockEvent, INFINITE);
817 }
818
819 /* Clear the Structures */
820 if( readfds )
821 FD_ZERO(readfds);
822 if( writefds )
823 FD_ZERO(writefds);
824 if( exceptfds )
825 FD_ZERO(exceptfds);
826
827 /* Loop through return structure */
828 HandleCount = PollInfo->HandleCount;
829
830 /* Return in FDSET Format */
831 for (i = 0; i < HandleCount; i++)
832 {
833 HandleCounted = FALSE;
834 for(x = 1; x; x<<=1)
835 {
836 switch (PollInfo->Handles[i].Events & x)
837 {
838 case AFD_EVENT_RECEIVE:
839 case AFD_EVENT_DISCONNECT:
840 case AFD_EVENT_ABORT:
841 case AFD_EVENT_ACCEPT:
842 case AFD_EVENT_CLOSE:
843 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
844 PollInfo->Handles[i].Events,
845 PollInfo->Handles[i].Handle));
846 if (! HandleCounted)
847 {
848 OutCount++;
849 HandleCounted = TRUE;
850 }
851 if( readfds )
852 FD_SET(PollInfo->Handles[i].Handle, readfds);
853 break;
854 case AFD_EVENT_SEND:
855 case AFD_EVENT_CONNECT:
856 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
857 PollInfo->Handles[i].Events,
858 PollInfo->Handles[i].Handle));
859 if (! HandleCounted)
860 {
861 OutCount++;
862 HandleCounted = TRUE;
863 }
864 if( writefds )
865 FD_SET(PollInfo->Handles[i].Handle, writefds);
866 break;
867 case AFD_EVENT_OOB_RECEIVE:
868 case AFD_EVENT_CONNECT_FAIL:
869 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
870 PollInfo->Handles[i].Events,
871 PollInfo->Handles[i].Handle));
872 if (! HandleCounted)
873 {
874 OutCount++;
875 HandleCounted = TRUE;
876 }
877 if( exceptfds )
878 FD_SET(PollInfo->Handles[i].Handle, exceptfds);
879 break;
880 }
881 }
882 }
883
884 HeapFree( GlobalHeap, 0, PollBuffer );
885 NtClose( SockEvent );
886
887 if( lpErrno )
888 {
889 switch( IOSB.Status )
890 {
891 case STATUS_SUCCESS:
892 case STATUS_TIMEOUT:
893 *lpErrno = 0;
894 break;
895 default:
896 *lpErrno = WSAEINVAL;
897 break;
898 }
899 AFD_DbgPrint(MID_TRACE,("*lpErrno = %x\n", *lpErrno));
900 }
901
902 AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
903
904 return OutCount;
905 }
906
907 SOCKET
908 WSPAPI
909 WSPAccept(SOCKET Handle,
910 struct sockaddr *SocketAddress,
911 int *SocketAddressLength,
912 LPCONDITIONPROC lpfnCondition,
913 DWORD dwCallbackData,
914 LPINT lpErrno)
915 {
916 IO_STATUS_BLOCK IOSB;
917 PAFD_RECEIVED_ACCEPT_DATA ListenReceiveData;
918 AFD_ACCEPT_DATA AcceptData;
919 AFD_DEFER_ACCEPT_DATA DeferData;
920 AFD_PENDING_ACCEPT_DATA PendingAcceptData;
921 PSOCKET_INFORMATION Socket = NULL;
922 NTSTATUS Status;
923 struct fd_set ReadSet;
924 struct timeval Timeout;
925 PVOID PendingData = NULL;
926 ULONG PendingDataLength = 0;
927 PVOID CalleeDataBuffer;
928 WSABUF CallerData, CalleeID, CallerID, CalleeData;
929 PSOCKADDR RemoteAddress = NULL;
930 GROUP GroupID = 0;
931 ULONG CallBack;
932 WSAPROTOCOL_INFOW ProtocolInfo;
933 SOCKET AcceptSocket;
934 UCHAR ReceiveBuffer[0x1A];
935 HANDLE SockEvent;
936
937 Status = NtCreateEvent(&SockEvent,
938 GENERIC_READ | GENERIC_WRITE,
939 NULL,
940 1,
941 FALSE);
942
943 if( !NT_SUCCESS(Status) )
944 {
945 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
946 return INVALID_SOCKET;
947 }
948
949 /* Dynamic Structure...ugh */
950 ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
951
952 /* Get the Socket Structure associate to this Socket*/
953 Socket = GetSocketStructure(Handle);
954
955 /* If this is non-blocking, make sure there's something for us to accept */
956 FD_ZERO(&ReadSet);
957 FD_SET(Socket->Handle, &ReadSet);
958 Timeout.tv_sec=0;
959 Timeout.tv_usec=0;
960
961 WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
962
963 if (ReadSet.fd_array[0] != Socket->Handle)
964 {
965 NtClose(SockEvent);
966 return 0;
967 }
968
969 /* Send IOCTL */
970 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
971 SockEvent,
972 NULL,
973 NULL,
974 &IOSB,
975 IOCTL_AFD_WAIT_FOR_LISTEN,
976 NULL,
977 0,
978 ListenReceiveData,
979 0xA + sizeof(*ListenReceiveData));
980
981 /* Wait for return */
982 if (Status == STATUS_PENDING)
983 {
984 WaitForSingleObject(SockEvent, INFINITE);
985 Status = IOSB.Status;
986 }
987
988 if (!NT_SUCCESS(Status))
989 {
990 NtClose( SockEvent );
991 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
992 return INVALID_SOCKET;
993 }
994
995 if (lpfnCondition != NULL)
996 {
997 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0)
998 {
999 /* Find out how much data is pending */
1000 PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1001 PendingAcceptData.ReturnSize = TRUE;
1002
1003 /* Send IOCTL */
1004 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1005 SockEvent,
1006 NULL,
1007 NULL,
1008 &IOSB,
1009 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1010 &PendingAcceptData,
1011 sizeof(PendingAcceptData),
1012 &PendingAcceptData,
1013 sizeof(PendingAcceptData));
1014
1015 /* Wait for return */
1016 if (Status == STATUS_PENDING)
1017 {
1018 WaitForSingleObject(SockEvent, INFINITE);
1019 Status = IOSB.Status;
1020 }
1021
1022 if (!NT_SUCCESS(Status))
1023 {
1024 NtClose( SockEvent );
1025 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1026 return INVALID_SOCKET;
1027 }
1028
1029 /* How much data to allocate */
1030 PendingDataLength = IOSB.Information;
1031
1032 if (PendingDataLength)
1033 {
1034 /* Allocate needed space */
1035 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
1036
1037 /* We want the data now */
1038 PendingAcceptData.ReturnSize = FALSE;
1039
1040 /* Send IOCTL */
1041 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1042 SockEvent,
1043 NULL,
1044 NULL,
1045 &IOSB,
1046 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1047 &PendingAcceptData,
1048 sizeof(PendingAcceptData),
1049 PendingData,
1050 PendingDataLength);
1051
1052 /* Wait for return */
1053 if (Status == STATUS_PENDING)
1054 {
1055 WaitForSingleObject(SockEvent, INFINITE);
1056 Status = IOSB.Status;
1057 }
1058
1059 if (!NT_SUCCESS(Status))
1060 {
1061 NtClose( SockEvent );
1062 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1063 return INVALID_SOCKET;
1064 }
1065 }
1066 }
1067
1068 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1069 {
1070 /* I don't support this yet */
1071 }
1072
1073 /* Build Callee ID */
1074 CalleeID.buf = (PVOID)Socket->LocalAddress;
1075 CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
1076
1077 /* Set up Address in SOCKADDR Format */
1078 RtlCopyMemory (RemoteAddress,
1079 &ListenReceiveData->Address.Address[0].AddressType,
1080 sizeof(*RemoteAddress));
1081
1082 /* Build Caller ID */
1083 CallerID.buf = (PVOID)RemoteAddress;
1084 CallerID.len = sizeof(*RemoteAddress);
1085
1086 /* Build Caller Data */
1087 CallerData.buf = PendingData;
1088 CallerData.len = PendingDataLength;
1089
1090 /* Check if socket supports Conditional Accept */
1091 if (Socket->SharedData.UseDelayedAcceptance != 0)
1092 {
1093 /* Allocate Buffer for Callee Data */
1094 CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
1095 CalleeData.buf = CalleeDataBuffer;
1096 CalleeData.len = 4096;
1097 }
1098 else
1099 {
1100 /* Nothing */
1101 CalleeData.buf = 0;
1102 CalleeData.len = 0;
1103 }
1104
1105 /* Call the Condition Function */
1106 CallBack = (lpfnCondition)(&CallerID,
1107 CallerData.buf == NULL ? NULL : &CallerData,
1108 NULL,
1109 NULL,
1110 &CalleeID,
1111 CalleeData.buf == NULL ? NULL : &CalleeData,
1112 &GroupID,
1113 dwCallbackData);
1114
1115 if (((CallBack == CF_ACCEPT) && GroupID) != 0)
1116 {
1117 /* TBD: Check for Validity */
1118 }
1119
1120 if (CallBack == CF_ACCEPT)
1121 {
1122 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1123 {
1124 /* I don't support this yet */
1125 }
1126 if (CalleeData.buf)
1127 {
1128 // SockSetConnectData Sockets(SocketID), IOCTL_AFD_SET_CONNECT_DATA, CalleeData.Buffer, CalleeData.BuffSize, 0
1129 }
1130 }
1131 else
1132 {
1133 /* Callback rejected. Build Defer Structure */
1134 DeferData.SequenceNumber = ListenReceiveData->SequenceNumber;
1135 DeferData.RejectConnection = (CallBack == CF_REJECT);
1136
1137 /* Send IOCTL */
1138 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1139 SockEvent,
1140 NULL,
1141 NULL,
1142 &IOSB,
1143 IOCTL_AFD_DEFER_ACCEPT,
1144 &DeferData,
1145 sizeof(DeferData),
1146 NULL,
1147 0);
1148
1149 /* Wait for return */
1150 if (Status == STATUS_PENDING)
1151 {
1152 WaitForSingleObject(SockEvent, INFINITE);
1153 Status = IOSB.Status;
1154 }
1155
1156 NtClose( SockEvent );
1157
1158 if (!NT_SUCCESS(Status))
1159 {
1160 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1161 return INVALID_SOCKET;
1162 }
1163
1164 if (CallBack == CF_REJECT )
1165 {
1166 *lpErrno = WSAECONNREFUSED;
1167 return INVALID_SOCKET;
1168 }
1169 else
1170 {
1171 *lpErrno = WSAECONNREFUSED;
1172 return INVALID_SOCKET;
1173 }
1174 }
1175 }
1176
1177 /* Create a new Socket */
1178 ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
1179 ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
1180 ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
1181
1182 AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
1183 Socket->SharedData.SocketType,
1184 Socket->SharedData.Protocol,
1185 &ProtocolInfo,
1186 GroupID,
1187 Socket->SharedData.CreateFlags,
1188 NULL);
1189
1190 /* Set up the Accept Structure */
1191 AcceptData.ListenHandle = AcceptSocket;
1192 AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1193
1194 /* Send IOCTL to Accept */
1195 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1196 SockEvent,
1197 NULL,
1198 NULL,
1199 &IOSB,
1200 IOCTL_AFD_ACCEPT,
1201 &AcceptData,
1202 sizeof(AcceptData),
1203 NULL,
1204 0);
1205
1206 /* Wait for return */
1207 if (Status == STATUS_PENDING)
1208 {
1209 WaitForSingleObject(SockEvent, INFINITE);
1210 Status = IOSB.Status;
1211 }
1212
1213 if (!NT_SUCCESS(Status))
1214 {
1215 NtClose(SockEvent);
1216 WSPCloseSocket( AcceptSocket, lpErrno );
1217 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1218 return INVALID_SOCKET;
1219 }
1220
1221 /* Return Address in SOCKADDR FORMAT */
1222 if( SocketAddress )
1223 {
1224 RtlCopyMemory (SocketAddress,
1225 &ListenReceiveData->Address.Address[0].AddressType,
1226 sizeof(*RemoteAddress));
1227 if( SocketAddressLength )
1228 *SocketAddressLength = ListenReceiveData->Address.Address[0].AddressLength;
1229 }
1230
1231 NtClose( SockEvent );
1232
1233 /* Re-enable Async Event */
1234 SockReenableAsyncSelectEvent(Socket, FD_ACCEPT);
1235
1236 AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
1237
1238 *lpErrno = 0;
1239
1240 /* Return Socket */
1241 return AcceptSocket;
1242 }
1243
1244 int
1245 WSPAPI
1246 WSPConnect(SOCKET Handle,
1247 const struct sockaddr * SocketAddress,
1248 int SocketAddressLength,
1249 LPWSABUF lpCallerData,
1250 LPWSABUF lpCalleeData,
1251 LPQOS lpSQOS,
1252 LPQOS lpGQOS,
1253 LPINT lpErrno)
1254 {
1255 IO_STATUS_BLOCK IOSB;
1256 PAFD_CONNECT_INFO ConnectInfo;
1257 PSOCKET_INFORMATION Socket = NULL;
1258 NTSTATUS Status;
1259 UCHAR ConnectBuffer[0x22];
1260 ULONG ConnectDataLength;
1261 ULONG InConnectDataLength;
1262 INT BindAddressLength;
1263 PSOCKADDR BindAddress;
1264 HANDLE SockEvent;
1265
1266 Status = NtCreateEvent(&SockEvent,
1267 GENERIC_READ | GENERIC_WRITE,
1268 NULL,
1269 1,
1270 FALSE);
1271
1272 if( !NT_SUCCESS(Status) )
1273 return -1;
1274
1275 AFD_DbgPrint(MID_TRACE,("Called\n"));
1276
1277 /* Get the Socket Structure associate to this Socket*/
1278 Socket = GetSocketStructure(Handle);
1279
1280 /* Bind us First */
1281 if (Socket->SharedData.State == SocketOpen)
1282 {
1283 /* Get the Wildcard Address */
1284 BindAddressLength = Socket->HelperData->MaxWSAddressLength;
1285 BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
1286 Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
1287 BindAddress,
1288 &BindAddressLength);
1289 /* Bind it */
1290 WSPBind(Handle, BindAddress, BindAddressLength, NULL);
1291 }
1292
1293 /* Set the Connect Data */
1294 if (lpCallerData != NULL)
1295 {
1296 ConnectDataLength = lpCallerData->len;
1297 Status = NtDeviceIoControlFile((HANDLE)Handle,
1298 SockEvent,
1299 NULL,
1300 NULL,
1301 &IOSB,
1302 IOCTL_AFD_SET_CONNECT_DATA,
1303 lpCallerData->buf,
1304 ConnectDataLength,
1305 NULL,
1306 0);
1307 /* Wait for return */
1308 if (Status == STATUS_PENDING)
1309 {
1310 WaitForSingleObject(SockEvent, INFINITE);
1311 Status = IOSB.Status;
1312 }
1313 }
1314
1315 /* Dynamic Structure...ugh */
1316 ConnectInfo = (PAFD_CONNECT_INFO)ConnectBuffer;
1317
1318 /* Set up Address in TDI Format */
1319 ConnectInfo->RemoteAddress.TAAddressCount = 1;
1320 ConnectInfo->RemoteAddress.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
1321 ConnectInfo->RemoteAddress.Address[0].AddressType = SocketAddress->sa_family;
1322 RtlCopyMemory (ConnectInfo->RemoteAddress.Address[0].Address,
1323 SocketAddress->sa_data,
1324 SocketAddressLength - sizeof(SocketAddress->sa_family));
1325
1326 /*
1327 * Disable FD_WRITE and FD_CONNECT
1328 * The latter fixes a race condition where the FD_CONNECT is re-enabled
1329 * at the end of this function right after the Async Thread disables it.
1330 * This should only happen at the *next* WSPConnect
1331 */
1332 if (Socket->SharedData.AsyncEvents & FD_CONNECT)
1333 {
1334 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
1335 }
1336
1337 /* Tell AFD that we want Connection Data back, have it allocate a buffer */
1338 if (lpCalleeData != NULL)
1339 {
1340 InConnectDataLength = lpCalleeData->len;
1341 Status = NtDeviceIoControlFile((HANDLE)Handle,
1342 SockEvent,
1343 NULL,
1344 NULL,
1345 &IOSB,
1346 IOCTL_AFD_SET_CONNECT_DATA_SIZE,
1347 &InConnectDataLength,
1348 sizeof(InConnectDataLength),
1349 NULL,
1350 0);
1351
1352 /* Wait for return */
1353 if (Status == STATUS_PENDING)
1354 {
1355 WaitForSingleObject(SockEvent, INFINITE);
1356 Status = IOSB.Status;
1357 }
1358 }
1359
1360 /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
1361 ConnectInfo->Root = 0;
1362 ConnectInfo->UseSAN = FALSE;
1363 ConnectInfo->Unknown = 0;
1364
1365 /* FIXME: Handle Async Connect */
1366 if (Socket->SharedData.NonBlocking)
1367 {
1368 AFD_DbgPrint(MIN_TRACE, ("Async Connect UNIMPLEMENTED!\n"));
1369 }
1370
1371 /* Send IOCTL */
1372 Status = NtDeviceIoControlFile((HANDLE)Handle,
1373 SockEvent,
1374 NULL,
1375 NULL,
1376 &IOSB,
1377 IOCTL_AFD_CONNECT,
1378 ConnectInfo,
1379 0x22,
1380 NULL,
1381 0);
1382 /* Wait for return */
1383 if (Status == STATUS_PENDING)
1384 {
1385 WaitForSingleObject(SockEvent, INFINITE);
1386 Status = IOSB.Status;
1387 }
1388
1389 /* Get any pending connect data */
1390 if (lpCalleeData != NULL)
1391 {
1392 Status = NtDeviceIoControlFile((HANDLE)Handle,
1393 SockEvent,
1394 NULL,
1395 NULL,
1396 &IOSB,
1397 IOCTL_AFD_GET_CONNECT_DATA,
1398 NULL,
1399 0,
1400 lpCalleeData->buf,
1401 lpCalleeData->len);
1402 /* Wait for return */
1403 if (Status == STATUS_PENDING)
1404 {
1405 WaitForSingleObject(SockEvent, INFINITE);
1406 Status = IOSB.Status;
1407 }
1408 }
1409
1410 /* Re-enable Async Event */
1411 SockReenableAsyncSelectEvent(Socket, FD_WRITE);
1412
1413 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
1414 SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
1415
1416 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1417
1418 NtClose( SockEvent );
1419
1420 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1421 }
1422 int
1423 WSPAPI
1424 WSPShutdown(SOCKET Handle,
1425 int HowTo,
1426 LPINT lpErrno)
1427
1428 {
1429 IO_STATUS_BLOCK IOSB;
1430 AFD_DISCONNECT_INFO DisconnectInfo;
1431 PSOCKET_INFORMATION Socket = NULL;
1432 NTSTATUS Status;
1433 HANDLE SockEvent;
1434
1435 Status = NtCreateEvent(&SockEvent,
1436 GENERIC_READ | GENERIC_WRITE,
1437 NULL,
1438 1,
1439 FALSE);
1440
1441 if( !NT_SUCCESS(Status) )
1442 return -1;
1443
1444 AFD_DbgPrint(MID_TRACE,("Called\n"));
1445
1446 /* Get the Socket Structure associate to this Socket*/
1447 Socket = GetSocketStructure(Handle);
1448
1449 /* Set AFD Disconnect Type */
1450 switch (HowTo)
1451 {
1452 case SD_RECEIVE:
1453 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
1454 Socket->SharedData.ReceiveShutdown = TRUE;
1455 break;
1456 case SD_SEND:
1457 DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
1458 Socket->SharedData.SendShutdown = TRUE;
1459 break;
1460 case SD_BOTH:
1461 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
1462 Socket->SharedData.ReceiveShutdown = TRUE;
1463 Socket->SharedData.SendShutdown = TRUE;
1464 break;
1465 }
1466
1467 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
1468
1469 /* Send IOCTL */
1470 Status = NtDeviceIoControlFile((HANDLE)Handle,
1471 SockEvent,
1472 NULL,
1473 NULL,
1474 &IOSB,
1475 IOCTL_AFD_DISCONNECT,
1476 &DisconnectInfo,
1477 sizeof(DisconnectInfo),
1478 NULL,
1479 0);
1480
1481 /* Wait for return */
1482 if (Status == STATUS_PENDING)
1483 {
1484 WaitForSingleObject(SockEvent, INFINITE);
1485 Status = IOSB.Status;
1486 }
1487
1488 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1489
1490 NtClose( SockEvent );
1491
1492 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1493 }
1494
1495
1496 INT
1497 WSPAPI
1498 WSPGetSockName(IN SOCKET Handle,
1499 OUT LPSOCKADDR Name,
1500 IN OUT LPINT NameLength,
1501 OUT LPINT lpErrno)
1502 {
1503 IO_STATUS_BLOCK IOSB;
1504 ULONG TdiAddressSize;
1505 PTDI_ADDRESS_INFO TdiAddress;
1506 PTRANSPORT_ADDRESS SocketAddress;
1507 PSOCKET_INFORMATION Socket = NULL;
1508 NTSTATUS Status;
1509 HANDLE SockEvent;
1510
1511 Status = NtCreateEvent(&SockEvent,
1512 GENERIC_READ | GENERIC_WRITE,
1513 NULL,
1514 1,
1515 FALSE);
1516
1517 if( !NT_SUCCESS(Status) )
1518 return SOCKET_ERROR;
1519
1520 /* Get the Socket Structure associate to this Socket*/
1521 Socket = GetSocketStructure(Handle);
1522
1523 /* Allocate a buffer for the address */
1524 TdiAddressSize =
1525 sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
1526 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1527
1528 if ( TdiAddress == NULL )
1529 {
1530 NtClose( SockEvent );
1531 *lpErrno = WSAENOBUFS;
1532 return SOCKET_ERROR;
1533 }
1534
1535 SocketAddress = &TdiAddress->Address;
1536
1537 /* Send IOCTL */
1538 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1539 SockEvent,
1540 NULL,
1541 NULL,
1542 &IOSB,
1543 IOCTL_AFD_GET_SOCK_NAME,
1544 NULL,
1545 0,
1546 TdiAddress,
1547 TdiAddressSize);
1548
1549 /* Wait for return */
1550 if (Status == STATUS_PENDING)
1551 {
1552 WaitForSingleObject(SockEvent, INFINITE);
1553 Status = IOSB.Status;
1554 }
1555
1556 NtClose( SockEvent );
1557
1558 if (NT_SUCCESS(Status))
1559 {
1560 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1561 {
1562 Name->sa_family = SocketAddress->Address[0].AddressType;
1563 RtlCopyMemory (Name->sa_data,
1564 SocketAddress->Address[0].Address,
1565 SocketAddress->Address[0].AddressLength);
1566 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1567 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
1568 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1569 ((struct sockaddr_in *)Name)->sin_port));
1570 HeapFree(GlobalHeap, 0, TdiAddress);
1571 return 0;
1572 }
1573 else
1574 {
1575 HeapFree(GlobalHeap, 0, TdiAddress);
1576 *lpErrno = WSAEFAULT;
1577 return SOCKET_ERROR;
1578 }
1579 }
1580
1581 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1582 }
1583
1584
1585 INT
1586 WSPAPI
1587 WSPGetPeerName(IN SOCKET s,
1588 OUT LPSOCKADDR Name,
1589 IN OUT LPINT NameLength,
1590 OUT LPINT lpErrno)
1591 {
1592 IO_STATUS_BLOCK IOSB;
1593 ULONG TdiAddressSize;
1594 PTRANSPORT_ADDRESS SocketAddress;
1595 PSOCKET_INFORMATION Socket = NULL;
1596 NTSTATUS Status;
1597 HANDLE SockEvent;
1598
1599 Status = NtCreateEvent(&SockEvent,
1600 GENERIC_READ | GENERIC_WRITE,
1601 NULL,
1602 1,
1603 FALSE);
1604
1605 if( !NT_SUCCESS(Status) )
1606 return SOCKET_ERROR;
1607
1608 /* Get the Socket Structure associate to this Socket*/
1609 Socket = GetSocketStructure(s);
1610
1611 /* Allocate a buffer for the address */
1612 TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
1613 SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1614
1615 if ( SocketAddress == NULL )
1616 {
1617 NtClose( SockEvent );
1618 *lpErrno = WSAENOBUFS;
1619 return SOCKET_ERROR;
1620 }
1621
1622 /* Send IOCTL */
1623 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1624 SockEvent,
1625 NULL,
1626 NULL,
1627 &IOSB,
1628 IOCTL_AFD_GET_PEER_NAME,
1629 NULL,
1630 0,
1631 SocketAddress,
1632 TdiAddressSize);
1633
1634 /* Wait for return */
1635 if (Status == STATUS_PENDING)
1636 {
1637 WaitForSingleObject(SockEvent, INFINITE);
1638 Status = IOSB.Status;
1639 }
1640
1641 NtClose( SockEvent );
1642
1643 if (NT_SUCCESS(Status))
1644 {
1645 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1646 {
1647 Name->sa_family = SocketAddress->Address[0].AddressType;
1648 RtlCopyMemory (Name->sa_data,
1649 SocketAddress->Address[0].Address,
1650 SocketAddress->Address[0].AddressLength);
1651 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1652 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
1653 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1654 ((struct sockaddr_in *)Name)->sin_port));
1655 HeapFree(GlobalHeap, 0, SocketAddress);
1656 return 0;
1657 }
1658 else
1659 {
1660 HeapFree(GlobalHeap, 0, SocketAddress);
1661 *lpErrno = WSAEFAULT;
1662 return SOCKET_ERROR;
1663 }
1664 }
1665
1666 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1667 }
1668
1669 INT
1670 WSPAPI
1671 WSPIoctl(IN SOCKET Handle,
1672 IN DWORD dwIoControlCode,
1673 IN LPVOID lpvInBuffer,
1674 IN DWORD cbInBuffer,
1675 OUT LPVOID lpvOutBuffer,
1676 IN DWORD cbOutBuffer,
1677 OUT LPDWORD lpcbBytesReturned,
1678 IN LPWSAOVERLAPPED lpOverlapped,
1679 IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
1680 IN LPWSATHREADID lpThreadId,
1681 OUT LPINT lpErrno)
1682 {
1683 PSOCKET_INFORMATION Socket = NULL;
1684
1685 /* Get the Socket Structure associate to this Socket*/
1686 Socket = GetSocketStructure(Handle);
1687
1688 switch( dwIoControlCode )
1689 {
1690 case FIONBIO:
1691 if( cbInBuffer < sizeof(INT) )
1692 return SOCKET_ERROR;
1693 Socket->SharedData.NonBlocking = *((PINT)lpvInBuffer) ? 1 : 0;
1694 AFD_DbgPrint(MID_TRACE,("[%x] Set nonblocking %d\n", Handle, Socket->SharedData.NonBlocking));
1695 return 0;
1696 case FIONREAD:
1697 return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
1698 default:
1699 *lpErrno = WSAEINVAL;
1700 return SOCKET_ERROR;
1701 }
1702 }
1703
1704
1705 INT
1706 WSPAPI
1707 WSPGetSockOpt(IN SOCKET Handle,
1708 IN INT Level,
1709 IN INT OptionName,
1710 OUT CHAR FAR* OptionValue,
1711 IN OUT LPINT OptionLength,
1712 OUT LPINT lpErrno)
1713 {
1714 PSOCKET_INFORMATION Socket = NULL;
1715 PVOID Buffer;
1716 INT BufferSize;
1717 BOOLEAN BoolBuffer;
1718
1719 /* Get the Socket Structure associate to this Socket*/
1720 Socket = GetSocketStructure(Handle);
1721 if (Socket == NULL)
1722 {
1723 *lpErrno = WSAENOTSOCK;
1724 return SOCKET_ERROR;
1725 }
1726
1727 AFD_DbgPrint(MID_TRACE, ("Called\n"));
1728
1729 switch (Level)
1730 {
1731 case SOL_SOCKET:
1732 switch (OptionName)
1733 {
1734 case SO_TYPE:
1735 Buffer = &Socket->SharedData.SocketType;
1736 BufferSize = sizeof(INT);
1737 break;
1738
1739 case SO_RCVBUF:
1740 Buffer = &Socket->SharedData.SizeOfRecvBuffer;
1741 BufferSize = sizeof(INT);
1742 break;
1743
1744 case SO_SNDBUF:
1745 Buffer = &Socket->SharedData.SizeOfSendBuffer;
1746 BufferSize = sizeof(INT);
1747 break;
1748
1749 case SO_ACCEPTCONN:
1750 BoolBuffer = Socket->SharedData.Listening;
1751 Buffer = &BoolBuffer;
1752 BufferSize = sizeof(BOOLEAN);
1753 break;
1754
1755 case SO_BROADCAST:
1756 BoolBuffer = Socket->SharedData.Broadcast;
1757 Buffer = &BoolBuffer;
1758 BufferSize = sizeof(BOOLEAN);
1759 break;
1760
1761 case SO_DEBUG:
1762 BoolBuffer = Socket->SharedData.Debug;
1763 Buffer = &BoolBuffer;
1764 BufferSize = sizeof(BOOLEAN);
1765 break;
1766
1767 /* case SO_CONDITIONAL_ACCEPT: */
1768 case SO_DONTLINGER:
1769 case SO_DONTROUTE:
1770 case SO_ERROR:
1771 case SO_GROUP_ID:
1772 case SO_GROUP_PRIORITY:
1773 case SO_KEEPALIVE:
1774 case SO_LINGER:
1775 case SO_MAX_MSG_SIZE:
1776 case SO_OOBINLINE:
1777 case SO_PROTOCOL_INFO:
1778 case SO_REUSEADDR:
1779 AFD_DbgPrint(MID_TRACE, ("Unimplemented option (%x)\n",
1780 OptionName));
1781
1782 default:
1783 *lpErrno = WSAEINVAL;
1784 return SOCKET_ERROR;
1785 }
1786
1787 if (*OptionLength < BufferSize)
1788 {
1789 *lpErrno = WSAEFAULT;
1790 *OptionLength = BufferSize;
1791 return SOCKET_ERROR;
1792 }
1793 RtlCopyMemory(OptionValue, Buffer, BufferSize);
1794
1795 return 0;
1796
1797 case IPPROTO_TCP: /* FIXME */
1798 default:
1799 *lpErrno = WSAEINVAL;
1800 return SOCKET_ERROR;
1801 }
1802 }
1803
1804
1805 /*
1806 * FUNCTION: Initialize service provider for a client
1807 * ARGUMENTS:
1808 * wVersionRequested = Highest WinSock SPI version that the caller can use
1809 * lpWSPData = Address of WSPDATA structure to initialize
1810 * lpProtocolInfo = Pointer to structure that defines the desired protocol
1811 * UpcallTable = Pointer to upcall table of the WinSock DLL
1812 * lpProcTable = Address of procedure table to initialize
1813 * RETURNS:
1814 * Status of operation
1815 */
1816 INT
1817 WSPAPI
1818 WSPStartup(IN WORD wVersionRequested,
1819 OUT LPWSPDATA lpWSPData,
1820 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
1821 IN WSPUPCALLTABLE UpcallTable,
1822 OUT LPWSPPROC_TABLE lpProcTable)
1823
1824 {
1825 NTSTATUS Status;
1826
1827 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
1828 Status = NO_ERROR;
1829 Upcalls = UpcallTable;
1830
1831 if (Status == NO_ERROR)
1832 {
1833 lpProcTable->lpWSPAccept = WSPAccept;
1834 lpProcTable->lpWSPAddressToString = WSPAddressToString;
1835 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
1836 lpProcTable->lpWSPBind = WSPBind;
1837 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
1838 lpProcTable->lpWSPCleanup = WSPCleanup;
1839 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
1840 lpProcTable->lpWSPConnect = WSPConnect;
1841 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
1842 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
1843 lpProcTable->lpWSPEventSelect = WSPEventSelect;
1844 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
1845 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
1846 lpProcTable->lpWSPGetSockName = WSPGetSockName;
1847 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
1848 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
1849 lpProcTable->lpWSPIoctl = WSPIoctl;
1850 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
1851 lpProcTable->lpWSPListen = WSPListen;
1852 lpProcTable->lpWSPRecv = WSPRecv;
1853 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
1854 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
1855 lpProcTable->lpWSPSelect = WSPSelect;
1856 lpProcTable->lpWSPSend = WSPSend;
1857 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
1858 lpProcTable->lpWSPSendTo = WSPSendTo;
1859 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
1860 lpProcTable->lpWSPShutdown = WSPShutdown;
1861 lpProcTable->lpWSPSocket = WSPSocket;
1862 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
1863 lpWSPData->wVersion = MAKEWORD(2, 2);
1864 lpWSPData->wHighVersion = MAKEWORD(2, 2);
1865 }
1866
1867 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
1868
1869 return Status;
1870 }
1871
1872
1873 /*
1874 * FUNCTION: Cleans up service provider for a client
1875 * ARGUMENTS:
1876 * lpErrno = Address of buffer for error information
1877 * RETURNS:
1878 * 0 if successful, or SOCKET_ERROR if not
1879 */
1880 INT
1881 WSPAPI
1882 WSPCleanup(OUT LPINT lpErrno)
1883
1884 {
1885 AFD_DbgPrint(MAX_TRACE, ("\n"));
1886 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
1887 *lpErrno = NO_ERROR;
1888
1889 return 0;
1890 }
1891
1892
1893
1894 int
1895 GetSocketInformation(PSOCKET_INFORMATION Socket,
1896 ULONG AfdInformationClass,
1897 PULONG Ulong OPTIONAL,
1898 PLARGE_INTEGER LargeInteger OPTIONAL)
1899 {
1900 IO_STATUS_BLOCK IOSB;
1901 AFD_INFO InfoData;
1902 NTSTATUS Status;
1903 HANDLE SockEvent;
1904
1905 Status = NtCreateEvent(&SockEvent,
1906 GENERIC_READ | GENERIC_WRITE,
1907 NULL,
1908 1,
1909 FALSE);
1910
1911 if( !NT_SUCCESS(Status) )
1912 return -1;
1913
1914 /* Set Info Class */
1915 InfoData.InformationClass = AfdInformationClass;
1916
1917 /* Send IOCTL */
1918 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1919 SockEvent,
1920 NULL,
1921 NULL,
1922 &IOSB,
1923 IOCTL_AFD_GET_INFO,
1924 &InfoData,
1925 sizeof(InfoData),
1926 &InfoData,
1927 sizeof(InfoData));
1928
1929 /* Wait for return */
1930 if (Status == STATUS_PENDING)
1931 {
1932 WaitForSingleObject(SockEvent, INFINITE);
1933 }
1934
1935 /* Return Information */
1936 *Ulong = InfoData.Information.Ulong;
1937 if (LargeInteger != NULL)
1938 {
1939 *LargeInteger = InfoData.Information.LargeInteger;
1940 }
1941
1942 NtClose( SockEvent );
1943
1944 return 0;
1945
1946 }
1947
1948
1949 int
1950 SetSocketInformation(PSOCKET_INFORMATION Socket,
1951 ULONG AfdInformationClass,
1952 PULONG Ulong OPTIONAL,
1953 PLARGE_INTEGER LargeInteger OPTIONAL)
1954 {
1955 IO_STATUS_BLOCK IOSB;
1956 AFD_INFO InfoData;
1957 NTSTATUS Status;
1958 HANDLE SockEvent;
1959
1960 Status = NtCreateEvent(&SockEvent,
1961 GENERIC_READ | GENERIC_WRITE,
1962 NULL,
1963 1,
1964 FALSE);
1965
1966 if( !NT_SUCCESS(Status) )
1967 return -1;
1968
1969 /* Set Info Class */
1970 InfoData.InformationClass = AfdInformationClass;
1971
1972 /* Set Information */
1973 InfoData.Information.Ulong = *Ulong;
1974 if (LargeInteger != NULL)
1975 {
1976 InfoData.Information.LargeInteger = *LargeInteger;
1977 }
1978
1979 AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
1980 AfdInformationClass, *Ulong));
1981
1982 /* Send IOCTL */
1983 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1984 SockEvent,
1985 NULL,
1986 NULL,
1987 &IOSB,
1988 IOCTL_AFD_SET_INFO,
1989 &InfoData,
1990 sizeof(InfoData),
1991 NULL,
1992 0);
1993
1994 /* Wait for return */
1995 if (Status == STATUS_PENDING)
1996 {
1997 WaitForSingleObject(SockEvent, INFINITE);
1998 }
1999
2000 NtClose( SockEvent );
2001
2002 return 0;
2003
2004 }
2005
2006 PSOCKET_INFORMATION
2007 GetSocketStructure(SOCKET Handle)
2008 {
2009 ULONG i;
2010
2011 for (i=0; i<SocketCount; i++)
2012 {
2013 if (Sockets[i]->Handle == Handle)
2014 {
2015 return Sockets[i];
2016 }
2017 }
2018 return 0;
2019 }
2020
2021 int CreateContext(PSOCKET_INFORMATION Socket)
2022 {
2023 IO_STATUS_BLOCK IOSB;
2024 SOCKET_CONTEXT ContextData;
2025 NTSTATUS Status;
2026 HANDLE SockEvent;
2027
2028 Status = NtCreateEvent(&SockEvent,
2029 GENERIC_READ | GENERIC_WRITE,
2030 NULL,
2031 1,
2032 FALSE);
2033
2034 if( !NT_SUCCESS(Status) )
2035 return -1;
2036
2037 /* Create Context */
2038 ContextData.SharedData = Socket->SharedData;
2039 ContextData.SizeOfHelperData = 0;
2040 RtlCopyMemory (&ContextData.LocalAddress,
2041 Socket->LocalAddress,
2042 Socket->SharedData.SizeOfLocalAddress);
2043 RtlCopyMemory (&ContextData.RemoteAddress,
2044 Socket->RemoteAddress,
2045 Socket->SharedData.SizeOfRemoteAddress);
2046
2047 /* Send IOCTL */
2048 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
2049 SockEvent,
2050 NULL,
2051 NULL,
2052 &IOSB,
2053 IOCTL_AFD_SET_CONTEXT,
2054 &ContextData,
2055 sizeof(ContextData),
2056 NULL,
2057 0);
2058
2059 /* Wait for Completition */
2060 if (Status == STATUS_PENDING)
2061 {
2062 WaitForSingleObject(SockEvent, INFINITE);
2063 }
2064
2065 NtClose( SockEvent );
2066
2067 return 0;
2068 }
2069
2070 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
2071 {
2072 HANDLE hAsyncThread;
2073 DWORD AsyncThreadId;
2074 HANDLE AsyncEvent;
2075 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2076 NTSTATUS Status;
2077
2078 /* Check if the Thread Already Exists */
2079 if (SockAsyncThreadRefCount)
2080 {
2081 return TRUE;
2082 }
2083
2084 /* Create the Completion Port */
2085 if (!SockAsyncCompletionPort)
2086 {
2087 Status = NtCreateIoCompletion(&SockAsyncCompletionPort,
2088 IO_COMPLETION_ALL_ACCESS,
2089 NULL,
2090 2); // Allow 2 threads only
2091
2092 /* Protect Handle */
2093 HandleFlags.ProtectFromClose = TRUE;
2094 HandleFlags.Inherit = FALSE;
2095 Status = NtSetInformationObject(SockAsyncCompletionPort,
2096 ObjectHandleFlagInformation,
2097 &HandleFlags,
2098 sizeof(HandleFlags));
2099 }
2100
2101 /* Create the Async Event */
2102 Status = NtCreateEvent(&AsyncEvent,
2103 EVENT_ALL_ACCESS,
2104 NULL,
2105 NotificationEvent,
2106 FALSE);
2107
2108 /* Create the Async Thread */
2109 hAsyncThread = CreateThread(NULL,
2110 0,
2111 (LPTHREAD_START_ROUTINE)SockAsyncThread,
2112 NULL,
2113 0,
2114 &AsyncThreadId);
2115
2116 /* Close the Handle */
2117 NtClose(hAsyncThread);
2118
2119 /* Increase the Reference Count */
2120 SockAsyncThreadRefCount++;
2121 return TRUE;
2122 }
2123
2124 int SockAsyncThread(PVOID ThreadParam)
2125 {
2126 PVOID AsyncContext;
2127 PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
2128 IO_STATUS_BLOCK IOSB;
2129 NTSTATUS Status;
2130
2131 /* Make the Thread Higher Priority */
2132 SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
2133
2134 /* Do a KQUEUE/WorkItem Style Loop, thanks to IoCompletion Ports */
2135 do
2136 {
2137 Status = NtRemoveIoCompletion (SockAsyncCompletionPort,
2138 (PVOID*)&AsyncCompletionRoutine,
2139 &AsyncContext,
2140 &IOSB,
2141 NULL);
2142 /* Call the Async Function */
2143 if (NT_SUCCESS(Status))
2144 {
2145 (*AsyncCompletionRoutine)(AsyncContext, &IOSB);
2146 }
2147 else
2148 {
2149 /* It Failed, sleep for a second */
2150 Sleep(1000);
2151 }
2152 } while ((Status != STATUS_TIMEOUT));
2153
2154 /* The Thread has Ended */
2155 return 0;
2156 }
2157
2158 BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
2159 {
2160 UNICODE_STRING AfdHelper;
2161 OBJECT_ATTRIBUTES ObjectAttributes;
2162 IO_STATUS_BLOCK IoSb;
2163 NTSTATUS Status;
2164 FILE_COMPLETION_INFORMATION CompletionInfo;
2165 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2166
2167 /* First, make sure we're not already intialized */
2168 if (SockAsyncHelperAfdHandle)
2169 {
2170 return TRUE;
2171 }
2172
2173 /* Set up Handle Name and Object */
2174 RtlInitUnicodeString(&AfdHelper, L"\\Device\\Afd\\AsyncSelectHlp" );
2175 InitializeObjectAttributes(&ObjectAttributes,
2176 &AfdHelper,
2177 OBJ_INHERIT | OBJ_CASE_INSENSITIVE,
2178 NULL,
2179 NULL);
2180
2181 /* Open the Handle to AFD */
2182 Status = NtCreateFile(&SockAsyncHelperAfdHandle,
2183 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
2184 &ObjectAttributes,
2185 &IoSb,
2186 NULL,
2187 0,
2188 FILE_SHARE_READ | FILE_SHARE_WRITE,
2189 FILE_OPEN_IF,
2190 0,
2191 NULL,
2192 0);
2193
2194 /*
2195 * Now Set up the Completion Port Information
2196 * This means that whenever a Poll is finished, the routine will be executed
2197 */
2198 CompletionInfo.Port = SockAsyncCompletionPort;
2199 CompletionInfo.Key = SockAsyncSelectCompletionRoutine;
2200 Status = NtSetInformationFile(SockAsyncHelperAfdHandle,
2201 &IoSb,
2202 &CompletionInfo,
2203 sizeof(CompletionInfo),
2204 FileCompletionInformation);
2205
2206
2207 /* Protect the Handle */
2208 HandleFlags.ProtectFromClose = TRUE;
2209 HandleFlags.Inherit = FALSE;
2210 Status = NtSetInformationObject(SockAsyncCompletionPort,
2211 ObjectHandleFlagInformation,
2212 &HandleFlags,
2213 sizeof(HandleFlags));
2214
2215
2216 /* Set this variable to true so that Send/Recv/Accept will know wether to renable disabled events */
2217 SockAsyncSelectCalled = TRUE;
2218 return TRUE;
2219 }
2220
2221 VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2222 {
2223
2224 PASYNC_DATA AsyncData = Context;
2225 PSOCKET_INFORMATION Socket;
2226 ULONG x;
2227
2228 /* Get the Socket */
2229 Socket = AsyncData->ParentSocket;
2230
2231 /* Check if the Sequence Number Changed behind our back */
2232 if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber )
2233 {
2234 return;
2235 }
2236
2237 /* Check we were manually called b/c of a failure */
2238 if (!NT_SUCCESS(IoStatusBlock->Status))
2239 {
2240 /* FIXME: Perform Upcall */
2241 return;
2242 }
2243
2244 for (x = 1; x; x<<=1)
2245 {
2246 switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x)
2247 {
2248 case AFD_EVENT_RECEIVE:
2249 if (0 != (Socket->SharedData.AsyncEvents & FD_READ) &&
2250 0 == (Socket->SharedData.AsyncDisabledEvents & FD_READ))
2251 {
2252 /* Make the Notifcation */
2253 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2254 Socket->SharedData.wMsg,
2255 Socket->Handle,
2256 WSAMAKESELECTREPLY(FD_READ, 0));
2257 /* Disable this event until the next read(); */
2258 Socket->SharedData.AsyncDisabledEvents |= FD_READ;
2259 }
2260 break;
2261
2262 case AFD_EVENT_OOB_RECEIVE:
2263 if (0 != (Socket->SharedData.AsyncEvents & FD_OOB) &&
2264 0 == (Socket->SharedData.AsyncDisabledEvents & FD_OOB))
2265 {
2266 /* Make the Notifcation */
2267 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2268 Socket->SharedData.wMsg,
2269 Socket->Handle,
2270 WSAMAKESELECTREPLY(FD_OOB, 0));
2271 /* Disable this event until the next read(); */
2272 Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
2273 }
2274 break;
2275
2276 case AFD_EVENT_SEND:
2277 if (0 != (Socket->SharedData.AsyncEvents & FD_WRITE) &&
2278 0 == (Socket->SharedData.AsyncDisabledEvents & FD_WRITE))
2279 {
2280 /* Make the Notifcation */
2281 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2282 Socket->SharedData.wMsg,
2283 Socket->Handle,
2284 WSAMAKESELECTREPLY(FD_WRITE, 0));
2285 /* Disable this event until the next write(); */
2286 Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
2287 }
2288 break;
2289
2290 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2291 case AFD_EVENT_CONNECT:
2292 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
2293 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
2294 {
2295 /* Make the Notifcation */
2296 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2297 Socket->SharedData.wMsg,
2298 Socket->Handle,
2299 WSAMAKESELECTREPLY(FD_CONNECT, 0));
2300 /* Disable this event forever; */
2301 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT;
2302 }
2303 break;
2304
2305 case AFD_EVENT_ACCEPT:
2306 if (0 != (Socket->SharedData.AsyncEvents & FD_ACCEPT) &&
2307 0 == (Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT))
2308 {
2309 /* Make the Notifcation */
2310 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2311 Socket->SharedData.wMsg,
2312 Socket->Handle,
2313 WSAMAKESELECTREPLY(FD_ACCEPT, 0));
2314 /* Disable this event until the next accept(); */
2315 Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
2316 }
2317 break;
2318
2319 case AFD_EVENT_DISCONNECT:
2320 case AFD_EVENT_ABORT:
2321 case AFD_EVENT_CLOSE:
2322 if (0 != (Socket->SharedData.AsyncEvents & FD_CLOSE) &&
2323 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CLOSE))
2324 {
2325 /* Make the Notifcation */
2326 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2327 Socket->SharedData.wMsg,
2328 Socket->Handle,
2329 WSAMAKESELECTREPLY(FD_CLOSE, 0));
2330 /* Disable this event forever; */
2331 Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
2332 }
2333 break;
2334 /* FIXME: Support QOS */
2335 }
2336 }
2337
2338 /* Check if there are any events left for us to check */
2339 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2340 {
2341 return;
2342 }
2343
2344 /* Keep Polling */
2345 SockProcessAsyncSelect(Socket, AsyncData);
2346 return;
2347 }
2348
2349 VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
2350 {
2351
2352 ULONG lNetworkEvents;
2353 NTSTATUS Status;
2354
2355 /* Set up the Async Data Event Info */
2356 AsyncData->AsyncSelectInfo.Timeout.HighPart = 0x7FFFFFFF;
2357 AsyncData->AsyncSelectInfo.Timeout.LowPart = 0xFFFFFFFF;
2358 AsyncData->AsyncSelectInfo.HandleCount = 1;
2359 AsyncData->AsyncSelectInfo.Exclusive = TRUE;
2360 AsyncData->AsyncSelectInfo.Handles[0].Handle = Socket->Handle;
2361 AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
2362
2363 /* Remove unwanted events */
2364 lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
2365
2366 /* Set Events to wait for */
2367 if (lNetworkEvents & FD_READ)
2368 {
2369 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_RECEIVE;
2370 }
2371
2372 if (lNetworkEvents & FD_WRITE)
2373 {
2374 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_SEND;
2375 }
2376
2377 if (lNetworkEvents & FD_OOB)
2378 {
2379 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_OOB_RECEIVE;
2380 }
2381
2382 if (lNetworkEvents & FD_ACCEPT)
2383 {
2384 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_ACCEPT;
2385 }
2386
2387 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2388 if (lNetworkEvents & FD_CONNECT)
2389 {
2390 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_CONNECT | AFD_EVENT_CONNECT_FAIL;
2391 }
2392
2393 if (lNetworkEvents & FD_CLOSE)
2394 {
2395 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT;
2396 }
2397
2398 if (lNetworkEvents & FD_QOS)
2399 {
2400 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_QOS;
2401 }
2402
2403 if (lNetworkEvents & FD_GROUP_QOS)
2404 {
2405 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_GROUP_QOS;
2406 }
2407
2408 /* Send IOCTL */
2409 Status = NtDeviceIoControlFile (SockAsyncHelperAfdHandle,
2410 NULL,
2411 NULL,
2412 AsyncData,
2413 &AsyncData->IoStatusBlock,
2414 IOCTL_AFD_SELECT,
2415 &AsyncData->AsyncSelectInfo,
2416 sizeof(AsyncData->AsyncSelectInfo),
2417 &AsyncData->AsyncSelectInfo,
2418 sizeof(AsyncData->AsyncSelectInfo));
2419
2420 /* I/O Manager Won't call the completion routine, let's do it manually */
2421 if (NT_SUCCESS(Status))
2422 {
2423 return;
2424 }
2425 else
2426 {
2427 AsyncData->IoStatusBlock.Status = Status;
2428 SockAsyncSelectCompletionRoutine(AsyncData, &AsyncData->IoStatusBlock);
2429 }
2430 }
2431
2432 VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2433 {
2434 PASYNC_DATA AsyncData = Context;
2435 BOOL FreeContext = TRUE;
2436 PSOCKET_INFORMATION Socket;
2437
2438 /* Get the Socket */
2439 Socket = AsyncData->ParentSocket;
2440
2441 /* If someone closed it, stop the function */
2442 if (Socket->SharedData.State != SocketClosed)
2443 {
2444 /* Check if the Sequence Number changed by now, in which case quit */
2445 if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber)
2446 {
2447 /* Do the actuall select, if needed */
2448 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)))
2449 {
2450 SockProcessAsyncSelect(Socket, AsyncData);
2451 FreeContext = FALSE;
2452 }
2453 }
2454 }
2455
2456 /* Free the Context */
2457 if (FreeContext)
2458 {
2459 HeapFree(GetProcessHeap(), 0, AsyncData);
2460 }
2461
2462 return;
2463 }
2464
2465 VOID
2466 SockReenableAsyncSelectEvent (IN PSOCKET_INFORMATION Socket,
2467 IN ULONG Event)
2468 {
2469 PASYNC_DATA AsyncData;
2470
2471 /* Make sure the event is actually disabled */
2472 if (!(Socket->SharedData.AsyncDisabledEvents & Event))
2473 {
2474 return;
2475 }
2476
2477 /* Re-enable it */
2478 Socket->SharedData.AsyncDisabledEvents &= ~Event;
2479
2480 /* Return if no more events are being polled */
2481 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2482 {
2483 return;
2484 }
2485
2486 /* Wait on new events */
2487 AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
2488
2489 /* Create the Asynch Thread if Needed */
2490 SockCreateOrReferenceAsyncThread();
2491
2492 /* Increase the sequence number to stop anything else */
2493 Socket->SharedData.SequenceNumber++;
2494
2495 /* Set up the Async Data */
2496 AsyncData->ParentSocket = Socket;
2497 AsyncData->SequenceNumber = Socket->SharedData.SequenceNumber;
2498
2499 /* Begin Async Select by using I/O Completion */
2500 NtSetIoCompletion(SockAsyncCompletionPort,
2501 (PVOID)&SockProcessQueuedAsyncSelect,
2502 AsyncData,
2503 0,
2504 0);
2505
2506 /* All done */
2507 return;
2508 }
2509
2510 BOOL
2511 WINAPI
2512 DllMain(HANDLE hInstDll,
2513 ULONG dwReason,
2514 PVOID Reserved)
2515 {
2516
2517 switch (dwReason)
2518 {
2519 case DLL_PROCESS_ATTACH:
2520
2521 AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
2522
2523 /* Don't need thread attach notifications
2524 so disable them to improve performance */
2525 DisableThreadLibraryCalls(hInstDll);
2526
2527 /* List of DLL Helpers */
2528 InitializeListHead(&SockHelpersListHead);
2529
2530 /* Heap to use when allocating */
2531 GlobalHeap = GetProcessHeap();
2532
2533 /* Allocate Heap for 1024 Sockets, can be expanded later */
2534 Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
2535
2536 AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
2537
2538 break;
2539
2540 case DLL_THREAD_ATTACH:
2541 break;
2542
2543 case DLL_THREAD_DETACH:
2544 break;
2545
2546 case DLL_PROCESS_DETACH:
2547 break;
2548 }
2549
2550 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
2551
2552 return TRUE;
2553 }
2554
2555 /* EOF */
2556
2557