Avoid buffer overflow (bug #4693).
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
1 /*
2 * COPYRIGHT: See COPYING in the top level directory
3 * PROJECT: ReactOS Ancillary Function Driver DLL
4 * FILE: misc/dllmain.c
5 * PURPOSE: DLL entry point
6 * PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
7 * Alex Ionescu (alex@relsoft.net)
8 * REVISIONS:
9 * CSH 01/09-2000 Created
10 * Alex 16/07/2004 - Complete Rewrite
11 */
12
13 #include <msafd.h>
14
15 #include <debug.h>
16
17 #if DBG
18 //DWORD DebugTraceLevel = DEBUG_ULTRA;
19 DWORD DebugTraceLevel = 0;
20 #endif /* DBG */
21
22 HANDLE GlobalHeap;
23 WSPUPCALLTABLE Upcalls;
24 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
25 ULONG SocketCount = 0;
26 PSOCKET_INFORMATION *Sockets = NULL;
27 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
28 ULONG SockAsyncThreadRefCount;
29 HANDLE SockAsyncHelperAfdHandle;
30 HANDLE SockAsyncCompletionPort;
31 BOOLEAN SockAsyncSelectCalled;
32
33
34
35 /*
36 * FUNCTION: Creates a new socket
37 * ARGUMENTS:
38 * af = Address family
39 * type = Socket type
40 * protocol = Protocol type
41 * lpProtocolInfo = Pointer to protocol information
42 * g = Reserved
43 * dwFlags = Socket flags
44 * lpErrno = Address of buffer for error information
45 * RETURNS:
46 * Created socket, or INVALID_SOCKET if it could not be created
47 */
48 SOCKET
49 WSPAPI
50 WSPSocket(int AddressFamily,
51 int SocketType,
52 int Protocol,
53 LPWSAPROTOCOL_INFOW lpProtocolInfo,
54 GROUP g,
55 DWORD dwFlags,
56 LPINT lpErrno)
57 {
58 OBJECT_ATTRIBUTES Object;
59 IO_STATUS_BLOCK IOSB;
60 USHORT SizeOfPacket;
61 ULONG SizeOfEA;
62 PAFD_CREATE_PACKET AfdPacket;
63 HANDLE Sock;
64 PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
65 PFILE_FULL_EA_INFORMATION EABuffer = NULL;
66 PHELPER_DATA HelperData;
67 PVOID HelperDLLContext;
68 DWORD HelperEvents;
69 UNICODE_STRING TransportName;
70 UNICODE_STRING DevName;
71 LARGE_INTEGER GroupData;
72 INT Status;
73
74 AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
75 AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
76 AddressFamily, SocketType, Protocol));
77
78 /* Get Helper Data and Transport */
79 Status = SockGetTdiName (&AddressFamily,
80 &SocketType,
81 &Protocol,
82 g,
83 dwFlags,
84 &TransportName,
85 &HelperDLLContext,
86 &HelperData,
87 &HelperEvents);
88
89 /* Check for error */
90 if (Status != NO_ERROR)
91 {
92 AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
93 goto error;
94 }
95
96 /* AFD Device Name */
97 RtlInitUnicodeString(&DevName, L"\\Device\\Afd\\Endpoint");
98
99 /* Set Socket Data */
100 Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
101 if (!Socket)
102 return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
103
104 RtlZeroMemory(Socket, sizeof(*Socket));
105 Socket->RefCount = 2;
106 Socket->Handle = -1;
107 Socket->SharedData.Listening = FALSE;
108 Socket->SharedData.State = SocketOpen;
109 Socket->SharedData.AddressFamily = AddressFamily;
110 Socket->SharedData.SocketType = SocketType;
111 Socket->SharedData.Protocol = Protocol;
112 Socket->HelperContext = HelperDLLContext;
113 Socket->HelperData = HelperData;
114 Socket->HelperEvents = HelperEvents;
115 Socket->LocalAddress = &Socket->WSLocalAddress;
116 Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
117 Socket->RemoteAddress = &Socket->WSRemoteAddress;
118 Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
119 Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
120 Socket->SharedData.CreateFlags = dwFlags;
121 Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
122 Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
123 Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
124 Socket->SharedData.GroupID = g;
125 Socket->SharedData.GroupType = 0;
126 Socket->SharedData.UseSAN = FALSE;
127 Socket->SharedData.NonBlocking = FALSE; /* Sockets start blocking */
128 Socket->SanData = NULL;
129
130 /* Ask alex about this */
131 if( Socket->SharedData.SocketType == SOCK_DGRAM ||
132 Socket->SharedData.SocketType == SOCK_RAW )
133 {
134 AFD_DbgPrint(MID_TRACE,("Connectionless socket\n"));
135 Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
136 }
137
138 /* Packet Size */
139 SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
140
141 /* EA Size */
142 SizeOfEA = SizeOfPacket + sizeof(FILE_FULL_EA_INFORMATION) + AFD_PACKET_COMMAND_LENGTH;
143
144 /* Set up EA Buffer */
145 EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
146 if (!EABuffer)
147 return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
148
149 RtlZeroMemory(EABuffer, SizeOfEA);
150 EABuffer->NextEntryOffset = 0;
151 EABuffer->Flags = 0;
152 EABuffer->EaNameLength = AFD_PACKET_COMMAND_LENGTH;
153 RtlCopyMemory (EABuffer->EaName,
154 AfdCommand,
155 AFD_PACKET_COMMAND_LENGTH + 1);
156 EABuffer->EaValueLength = SizeOfPacket;
157
158 /* Set up AFD Packet */
159 AfdPacket = (PAFD_CREATE_PACKET)(EABuffer->EaName + EABuffer->EaNameLength + 1);
160 AfdPacket->SizeOfTransportName = TransportName.Length;
161 RtlCopyMemory (AfdPacket->TransportName,
162 TransportName.Buffer,
163 TransportName.Length + sizeof(WCHAR));
164 AfdPacket->GroupID = g;
165
166 /* Set up Endpoint Flags */
167 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0)
168 {
169 if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW))
170 {
171 /* Only RAW or UDP can be Connectionless */
172 goto error;
173 }
174 AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
175 }
176
177 if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0)
178 {
179 if (SocketType == SOCK_STREAM)
180 {
181 if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
182 {
183 /* The Provider doesn't actually support Message Oriented Streams */
184 goto error;
185 }
186 }
187 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MESSAGE_ORIENTED;
188 }
189
190 if (SocketType == SOCK_RAW) AfdPacket->EndpointFlags |= AFD_ENDPOINT_RAW;
191
192 if (dwFlags & (WSA_FLAG_MULTIPOINT_C_ROOT |
193 WSA_FLAG_MULTIPOINT_C_LEAF |
194 WSA_FLAG_MULTIPOINT_D_ROOT |
195 WSA_FLAG_MULTIPOINT_D_LEAF))
196 {
197 if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
198 {
199 /* The Provider doesn't actually support Multipoint */
200 goto error;
201 }
202 AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
203
204 if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT)
205 {
206 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
207 || ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0))
208 {
209 /* The Provider doesn't support Control Planes, or you already gave a leaf */
210 goto error;
211 }
212 AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
213 }
214
215 if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT)
216 {
217 if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
218 || ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0))
219 {
220 /* The Provider doesn't support Data Planes, or you already gave a leaf */
221 goto error;
222 }
223 AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
224 }
225 }
226
227 /* Set up Object Attributes */
228 InitializeObjectAttributes (&Object,
229 &DevName,
230 OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
231 0,
232 0);
233
234 /* Create the Socket as asynchronous. That means we have to block
235 ourselves after every call to NtDeviceIoControlFile. This is
236 because the kernel doesn't support overlapping synchronous I/O
237 requests (made from multiple threads) at this time (Sep 2005) */
238 Status = NtCreateFile(&Sock,
239 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
240 &Object,
241 &IOSB,
242 NULL,
243 0,
244 FILE_SHARE_READ | FILE_SHARE_WRITE,
245 FILE_OPEN_IF,
246 0,
247 EABuffer,
248 SizeOfEA);
249
250 HeapFree(GlobalHeap, 0, EABuffer);
251
252 if (Status != STATUS_SUCCESS)
253 {
254 AFD_DbgPrint(MIN_TRACE, ("Failed to open socket\n"));
255
256 HeapFree(GlobalHeap, 0, Socket);
257
258 return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
259 }
260
261 /* Save Handle */
262 Socket->Handle = (SOCKET)Sock;
263
264 /* XXX See if there's a structure we can reuse -- We need to do this
265 * more properly. */
266 PrevSocket = GetSocketStructure( (SOCKET)Sock );
267
268 if( PrevSocket )
269 {
270 RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
271 RtlFreeHeap( GlobalHeap, 0, Socket );
272 Socket = PrevSocket;
273 }
274
275 /* Save Group Info */
276 if (g != 0)
277 {
278 GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
279 Socket->SharedData.GroupID = GroupData.u.LowPart;
280 Socket->SharedData.GroupType = GroupData.u.HighPart;
281 }
282
283 /* Get Window Sizes and Save them */
284 GetSocketInformation (Socket,
285 AFD_INFO_SEND_WINDOW_SIZE,
286 &Socket->SharedData.SizeOfSendBuffer,
287 NULL);
288
289 GetSocketInformation (Socket,
290 AFD_INFO_RECEIVE_WINDOW_SIZE,
291 &Socket->SharedData.SizeOfRecvBuffer,
292 NULL);
293
294 /* Save in Process Sockets List */
295 Sockets[SocketCount] = Socket;
296 SocketCount ++;
297
298 /* Create the Socket Context */
299 CreateContext(Socket);
300
301 /* Notify Winsock */
302 Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
303
304 /* Return Socket Handle */
305 AFD_DbgPrint(MID_TRACE,("Success %x\n", Sock));
306
307 return (SOCKET)Sock;
308
309 error:
310 AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
311
312 if( Socket )
313 HeapFree(GlobalHeap, 0, Socket);
314
315 if( lpErrno )
316 *lpErrno = Status;
317
318 return INVALID_SOCKET;
319 }
320
321
322 DWORD MsafdReturnWithErrno(NTSTATUS Status,
323 LPINT Errno,
324 DWORD Received,
325 LPDWORD ReturnedBytes)
326 {
327 if( ReturnedBytes )
328 *ReturnedBytes = 0;
329 if( Errno )
330 {
331 switch (Status)
332 {
333 case STATUS_CANT_WAIT:
334 *Errno = WSAEWOULDBLOCK;
335 break;
336 case STATUS_TIMEOUT:
337 *Errno = WSAETIMEDOUT;
338 break;
339 case STATUS_SUCCESS:
340 /* Return Number of bytes Read */
341 if( ReturnedBytes )
342 *ReturnedBytes = Received;
343 break;
344 case STATUS_FILE_CLOSED:
345 case STATUS_END_OF_FILE:
346 *Errno = WSAESHUTDOWN;
347 break;
348 case STATUS_PENDING:
349 *Errno = WSA_IO_PENDING;
350 break;
351 case STATUS_BUFFER_TOO_SMALL:
352 case STATUS_BUFFER_OVERFLOW:
353 DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
354 *Errno = WSAEMSGSIZE;
355 break;
356 case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
357 case STATUS_INSUFFICIENT_RESOURCES:
358 DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
359 *Errno = WSAENOBUFS;
360 break;
361 case STATUS_INVALID_CONNECTION:
362 DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
363 *Errno = WSAEAFNOSUPPORT;
364 break;
365 case STATUS_INVALID_ADDRESS:
366 DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
367 *Errno = WSAEADDRNOTAVAIL;
368 break;
369 case STATUS_REMOTE_NOT_LISTENING:
370 DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
371 *Errno = WSAECONNREFUSED;
372 break;
373 case STATUS_NETWORK_UNREACHABLE:
374 DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
375 *Errno = WSAENETUNREACH;
376 break;
377 case STATUS_INVALID_PARAMETER:
378 DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
379 *Errno = WSAEINVAL;
380 break;
381 case STATUS_CANCELLED:
382 DbgPrint("MSAFD: STATUS_CANCELLED\n");
383 *Errno = WSA_OPERATION_ABORTED;
384 break;
385 default:
386 DbgPrint("MSAFD: Error %x is unknown\n", Status);
387 *Errno = WSAEINVAL;
388 break;
389 }
390 }
391
392 /* Success */
393 return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
394 }
395
396 /*
397 * FUNCTION: Closes an open socket
398 * ARGUMENTS:
399 * s = Socket descriptor
400 * lpErrno = Address of buffer for error information
401 * RETURNS:
402 * NO_ERROR, or SOCKET_ERROR if the socket could not be closed
403 */
404 INT
405 WSPAPI
406 WSPCloseSocket(IN SOCKET Handle,
407 OUT LPINT lpErrno)
408 {
409 IO_STATUS_BLOCK IoStatusBlock;
410 PSOCKET_INFORMATION Socket = NULL;
411 NTSTATUS Status;
412 HANDLE SockEvent;
413 AFD_DISCONNECT_INFO DisconnectInfo;
414 SOCKET_STATE OldState;
415
416 /* Create the Wait Event */
417 Status = NtCreateEvent(&SockEvent,
418 GENERIC_READ | GENERIC_WRITE,
419 NULL,
420 1,
421 FALSE);
422
423 if(!NT_SUCCESS(Status))
424 return SOCKET_ERROR;
425
426 /* Get the Socket Structure associate to this Socket*/
427 Socket = GetSocketStructure(Handle);
428
429 /* If a Close is already in Process, give up */
430 if (Socket->SharedData.State == SocketClosed)
431 {
432 NtClose(SockEvent);
433 *lpErrno = WSAENOTSOCK;
434 return SOCKET_ERROR;
435 }
436
437 /* Set the state to close */
438 OldState = Socket->SharedData.State;
439 Socket->SharedData.State = SocketClosed;
440
441 /* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
442 /* FIXME: Should we do this on Datagram Sockets too? */
443 if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
444 {
445 ULONG LingerWait;
446 ULONG SendsInProgress;
447 ULONG SleepWait;
448
449 /* We need to respect the timeout */
450 SleepWait = 100;
451 LingerWait = Socket->SharedData.LingerData.l_linger * 1000;
452
453 /* Loop until no more sends are pending, within the timeout */
454 while (LingerWait)
455 {
456 /* Find out how many Sends are in Progress */
457 if (GetSocketInformation(Socket,
458 AFD_INFO_SENDS_IN_PROGRESS,
459 &SendsInProgress,
460 NULL))
461 {
462 /* Bail out if anything but NO_ERROR */
463 LingerWait = 0;
464 break;
465 }
466
467 /* Bail out if no more sends are pending */
468 if (!SendsInProgress)
469 break;
470 /*
471 * We have to execute a sleep, so it's kind of like
472 * a block. If the socket is Nonblock, we cannot
473 * go on since asyncronous operation is expected
474 * and we cannot offer it
475 */
476 if (Socket->SharedData.NonBlocking)
477 {
478 NtClose(SockEvent);
479 Socket->SharedData.State = OldState;
480 *lpErrno = WSAEWOULDBLOCK;
481 return SOCKET_ERROR;
482 }
483
484 /* Now we can sleep, and decrement the linger wait */
485 /*
486 * FIXME: It seems Windows does some funky acceleration
487 * since the waiting seems to be longer and longer. I
488 * don't think this improves performance so much, so we
489 * wait a fixed time instead.
490 */
491 Sleep(SleepWait);
492 LingerWait -= SleepWait;
493 }
494
495 /*
496 * We have reached the timeout or sends are over.
497 * Disconnect if the timeout has been reached.
498 */
499 if (LingerWait <= 0)
500 {
501 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
502 DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
503
504 /* Send IOCTL */
505 Status = NtDeviceIoControlFile((HANDLE)Handle,
506 SockEvent,
507 NULL,
508 NULL,
509 &IoStatusBlock,
510 IOCTL_AFD_DISCONNECT,
511 &DisconnectInfo,
512 sizeof(DisconnectInfo),
513 NULL,
514 0);
515
516 /* Wait for return */
517 if (Status == STATUS_PENDING)
518 {
519 WaitForSingleObject(SockEvent, INFINITE);
520 }
521 }
522 }
523
524 /* FIXME: We should notify the Helper DLL of WSH_NOTIFY_CLOSE */
525
526 /* Cleanup Time! */
527 Socket->HelperContext = NULL;
528 Socket->SharedData.AsyncDisabledEvents = -1;
529 NtClose(Socket->TdiAddressHandle);
530 Socket->TdiAddressHandle = NULL;
531 NtClose(Socket->TdiConnectionHandle);
532 Socket->TdiConnectionHandle = NULL;
533
534 /* Close the handle */
535 NtClose((HANDLE)Handle);
536 NtClose(SockEvent);
537
538 return NO_ERROR;
539 }
540
541
542 /*
543 * FUNCTION: Associates a local address with a socket
544 * ARGUMENTS:
545 * s = Socket descriptor
546 * name = Pointer to local address
547 * namelen = Length of name
548 * lpErrno = Address of buffer for error information
549 * RETURNS:
550 * 0, or SOCKET_ERROR if the socket could not be bound
551 */
552 INT
553 WSPAPI
554 WSPBind(SOCKET Handle,
555 const struct sockaddr *SocketAddress,
556 int SocketAddressLength,
557 LPINT lpErrno)
558 {
559 IO_STATUS_BLOCK IOSB;
560 PAFD_BIND_DATA BindData;
561 PSOCKET_INFORMATION Socket = NULL;
562 NTSTATUS Status;
563 SOCKADDR_INFO SocketInfo;
564 HANDLE SockEvent;
565
566 /* See below */
567 BindData = HeapAlloc(GlobalHeap, 0, 0xA + SocketAddressLength);
568 if (!BindData)
569 {
570 return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
571 }
572
573 Status = NtCreateEvent(&SockEvent,
574 GENERIC_READ | GENERIC_WRITE,
575 NULL,
576 1,
577 FALSE);
578
579 if (!NT_SUCCESS(Status))
580 {
581 HeapFree(GlobalHeap, 0, BindData);
582 return SOCKET_ERROR;
583 }
584
585 /* Get the Socket Structure associate to this Socket*/
586 Socket = GetSocketStructure(Handle);
587
588 /* Set up Address in TDI Format */
589 BindData->Address.TAAddressCount = 1;
590 BindData->Address.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
591 BindData->Address.Address[0].AddressType = SocketAddress->sa_family;
592 RtlCopyMemory (BindData->Address.Address[0].Address,
593 SocketAddress->sa_data,
594 SocketAddressLength - sizeof(SocketAddress->sa_family));
595
596 /* Get Address Information */
597 Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
598 SocketAddressLength,
599 &SocketInfo);
600
601 /* Set the Share Type */
602 if (Socket->SharedData.ExclusiveAddressUse)
603 {
604 BindData->ShareType = AFD_SHARE_EXCLUSIVE;
605 }
606 else if (SocketInfo.EndpointInfo == SockaddrEndpointInfoWildcard)
607 {
608 BindData->ShareType = AFD_SHARE_WILDCARD;
609 }
610 else if (Socket->SharedData.ReuseAddresses)
611 {
612 BindData->ShareType = AFD_SHARE_REUSE;
613 }
614 else
615 {
616 BindData->ShareType = AFD_SHARE_UNIQUE;
617 }
618
619 /* Send IOCTL */
620 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
621 SockEvent,
622 NULL,
623 NULL,
624 &IOSB,
625 IOCTL_AFD_BIND,
626 BindData,
627 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
628 BindData,
629 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
630
631 /* Wait for return */
632 if (Status == STATUS_PENDING)
633 {
634 WaitForSingleObject(SockEvent, INFINITE);
635 Status = IOSB.Status;
636 }
637
638 /* Set up Socket Data */
639 Socket->SharedData.State = SocketBound;
640 Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
641
642 NtClose(SockEvent);
643 HeapFree(GlobalHeap, 0, BindData);
644 return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
645 }
646
647 int
648 WSPAPI
649 WSPListen(SOCKET Handle,
650 int Backlog,
651 LPINT lpErrno)
652 {
653 IO_STATUS_BLOCK IOSB;
654 AFD_LISTEN_DATA ListenData;
655 PSOCKET_INFORMATION Socket = NULL;
656 HANDLE SockEvent;
657 NTSTATUS Status;
658
659 /* Get the Socket Structure associate to this Socket*/
660 Socket = GetSocketStructure(Handle);
661
662 if (Socket->SharedData.Listening)
663 return 0;
664
665 Status = NtCreateEvent(&SockEvent,
666 GENERIC_READ | GENERIC_WRITE,
667 NULL,
668 1,
669 FALSE);
670
671 if( !NT_SUCCESS(Status) )
672 return -1;
673
674 /* Set Up Listen Structure */
675 ListenData.UseSAN = FALSE;
676 ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
677 ListenData.Backlog = Backlog;
678
679 /* Send IOCTL */
680 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
681 SockEvent,
682 NULL,
683 NULL,
684 &IOSB,
685 IOCTL_AFD_START_LISTEN,
686 &ListenData,
687 sizeof(ListenData),
688 NULL,
689 0);
690
691 /* Wait for return */
692 if (Status == STATUS_PENDING)
693 {
694 WaitForSingleObject(SockEvent, INFINITE);
695 Status = IOSB.Status;
696 }
697
698 /* Set to Listening */
699 Socket->SharedData.Listening = TRUE;
700
701 NtClose( SockEvent );
702
703 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
704 }
705
706
707 int
708 WSPAPI
709 WSPSelect(int nfds,
710 fd_set *readfds,
711 fd_set *writefds,
712 fd_set *exceptfds,
713 const LPTIMEVAL timeout,
714 LPINT lpErrno)
715 {
716 IO_STATUS_BLOCK IOSB;
717 PAFD_POLL_INFO PollInfo;
718 NTSTATUS Status;
719 LONG HandleCount, OutCount = 0;
720 ULONG PollBufferSize;
721 PVOID PollBuffer;
722 ULONG i, j = 0, x;
723 HANDLE SockEvent;
724 BOOL HandleCounted;
725 LARGE_INTEGER Timeout;
726
727 /* Find out how many sockets we have, and how large the buffer needs
728 * to be */
729
730 HandleCount = ( readfds ? readfds->fd_count : 0 ) +
731 ( writefds ? writefds->fd_count : 0 ) +
732 ( exceptfds ? exceptfds->fd_count : 0 );
733
734 if ( HandleCount == 0 )
735 {
736 AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
737 HandleCount));
738 if (lpErrno) *lpErrno = WSAEINVAL;
739 return SOCKET_ERROR;
740 }
741
742 PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
743
744 AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n",
745 HandleCount, PollBufferSize));
746
747 /* Convert Timeout to NT Format */
748 if (timeout == NULL)
749 {
750 Timeout.u.LowPart = -1;
751 Timeout.u.HighPart = 0x7FFFFFFF;
752 AFD_DbgPrint(MAX_TRACE,("Infinite timeout\n"));
753 }
754 else
755 {
756 Timeout = RtlEnlargedIntegerMultiply
757 ((timeout->tv_sec * 1000) + (timeout->tv_usec / 1000), -10000);
758 /* Negative timeouts are illegal. Since the kernel represents an
759 * incremental timeout as a negative number, we check for a positive
760 * result.
761 */
762 if (Timeout.QuadPart > 0)
763 {
764 if (lpErrno) *lpErrno = WSAEINVAL;
765 return SOCKET_ERROR;
766 }
767 AFD_DbgPrint(MAX_TRACE,("Timeout: Orig %d.%06d kernel %d\n",
768 timeout->tv_sec, timeout->tv_usec,
769 Timeout.u.LowPart));
770 }
771
772 Status = NtCreateEvent(&SockEvent,
773 GENERIC_READ | GENERIC_WRITE,
774 NULL,
775 1,
776 FALSE);
777
778 if( !NT_SUCCESS(Status) )
779 return SOCKET_ERROR;
780
781 /* Allocate */
782 PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
783
784 if (!PollBuffer)
785 {
786 if (lpErrno)
787 *lpErrno = WSAEFAULT;
788 NtClose(SockEvent);
789 return SOCKET_ERROR;
790 }
791
792 PollInfo = (PAFD_POLL_INFO)PollBuffer;
793
794 RtlZeroMemory( PollInfo, PollBufferSize );
795
796 /* Number of handles for AFD to Check */
797 PollInfo->Exclusive = FALSE;
798 PollInfo->Timeout = Timeout;
799
800 if (readfds != NULL) {
801 for (i = 0; i < readfds->fd_count; i++, j++)
802 {
803 PollInfo->Handles[j].Handle = readfds->fd_array[i];
804 PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
805 AFD_EVENT_DISCONNECT |
806 AFD_EVENT_ABORT |
807 AFD_EVENT_CLOSE |
808 AFD_EVENT_ACCEPT;
809 }
810 }
811 if (writefds != NULL)
812 {
813 for (i = 0; i < writefds->fd_count; i++, j++)
814 {
815 PollInfo->Handles[j].Handle = writefds->fd_array[i];
816 PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
817 }
818 }
819 if (exceptfds != NULL)
820 {
821 for (i = 0; i < exceptfds->fd_count; i++, j++)
822 {
823 PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
824 PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
825 }
826 }
827
828 PollInfo->HandleCount = j;
829 PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
830
831 /* Send IOCTL */
832 Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
833 SockEvent,
834 NULL,
835 NULL,
836 &IOSB,
837 IOCTL_AFD_SELECT,
838 PollInfo,
839 PollBufferSize,
840 PollInfo,
841 PollBufferSize);
842
843 AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
844
845 /* Wait for Completition */
846 if (Status == STATUS_PENDING)
847 {
848 WaitForSingleObject(SockEvent, INFINITE);
849 }
850
851 /* Clear the Structures */
852 if( readfds )
853 FD_ZERO(readfds);
854 if( writefds )
855 FD_ZERO(writefds);
856 if( exceptfds )
857 FD_ZERO(exceptfds);
858
859 /* Loop through return structure */
860 HandleCount = PollInfo->HandleCount;
861
862 /* Return in FDSET Format */
863 for (i = 0; i < HandleCount; i++)
864 {
865 HandleCounted = FALSE;
866 for(x = 1; x; x<<=1)
867 {
868 switch (PollInfo->Handles[i].Events & x)
869 {
870 case AFD_EVENT_RECEIVE:
871 case AFD_EVENT_DISCONNECT:
872 case AFD_EVENT_ABORT:
873 case AFD_EVENT_ACCEPT:
874 case AFD_EVENT_CLOSE:
875 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
876 PollInfo->Handles[i].Events,
877 PollInfo->Handles[i].Handle));
878 if (! HandleCounted)
879 {
880 OutCount++;
881 HandleCounted = TRUE;
882 }
883 if( readfds )
884 FD_SET(PollInfo->Handles[i].Handle, readfds);
885 break;
886 case AFD_EVENT_SEND:
887 case AFD_EVENT_CONNECT:
888 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
889 PollInfo->Handles[i].Events,
890 PollInfo->Handles[i].Handle));
891 if (! HandleCounted)
892 {
893 OutCount++;
894 HandleCounted = TRUE;
895 }
896 if( writefds )
897 FD_SET(PollInfo->Handles[i].Handle, writefds);
898 break;
899 case AFD_EVENT_OOB_RECEIVE:
900 case AFD_EVENT_CONNECT_FAIL:
901 AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
902 PollInfo->Handles[i].Events,
903 PollInfo->Handles[i].Handle));
904 if (! HandleCounted)
905 {
906 OutCount++;
907 HandleCounted = TRUE;
908 }
909 if( exceptfds )
910 FD_SET(PollInfo->Handles[i].Handle, exceptfds);
911 break;
912 }
913 }
914 }
915
916 HeapFree( GlobalHeap, 0, PollBuffer );
917 NtClose( SockEvent );
918
919 if( lpErrno )
920 {
921 switch( IOSB.Status )
922 {
923 case STATUS_SUCCESS:
924 case STATUS_TIMEOUT:
925 *lpErrno = 0;
926 break;
927 default:
928 *lpErrno = WSAEINVAL;
929 break;
930 }
931 AFD_DbgPrint(MID_TRACE,("*lpErrno = %x\n", *lpErrno));
932 }
933
934 AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
935
936 return OutCount;
937 }
938
939 SOCKET
940 WSPAPI
941 WSPAccept(SOCKET Handle,
942 struct sockaddr *SocketAddress,
943 int *SocketAddressLength,
944 LPCONDITIONPROC lpfnCondition,
945 DWORD dwCallbackData,
946 LPINT lpErrno)
947 {
948 IO_STATUS_BLOCK IOSB;
949 PAFD_RECEIVED_ACCEPT_DATA ListenReceiveData;
950 AFD_ACCEPT_DATA AcceptData;
951 AFD_DEFER_ACCEPT_DATA DeferData;
952 AFD_PENDING_ACCEPT_DATA PendingAcceptData;
953 PSOCKET_INFORMATION Socket = NULL;
954 NTSTATUS Status;
955 struct fd_set ReadSet;
956 struct timeval Timeout;
957 PVOID PendingData = NULL;
958 ULONG PendingDataLength = 0;
959 PVOID CalleeDataBuffer;
960 WSABUF CallerData, CalleeID, CallerID, CalleeData;
961 PSOCKADDR RemoteAddress = NULL;
962 GROUP GroupID = 0;
963 ULONG CallBack;
964 WSAPROTOCOL_INFOW ProtocolInfo;
965 SOCKET AcceptSocket;
966 UCHAR ReceiveBuffer[0x1A];
967 HANDLE SockEvent;
968
969 Status = NtCreateEvent(&SockEvent,
970 GENERIC_READ | GENERIC_WRITE,
971 NULL,
972 1,
973 FALSE);
974
975 if( !NT_SUCCESS(Status) )
976 {
977 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
978 return INVALID_SOCKET;
979 }
980
981 /* Dynamic Structure...ugh */
982 ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
983
984 /* Get the Socket Structure associate to this Socket*/
985 Socket = GetSocketStructure(Handle);
986
987 /* If this is non-blocking, make sure there's something for us to accept */
988 FD_ZERO(&ReadSet);
989 FD_SET(Socket->Handle, &ReadSet);
990 Timeout.tv_sec=0;
991 Timeout.tv_usec=0;
992
993 WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
994
995 if (ReadSet.fd_array[0] != Socket->Handle)
996 {
997 NtClose(SockEvent);
998 return 0;
999 }
1000
1001 /* Send IOCTL */
1002 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1003 SockEvent,
1004 NULL,
1005 NULL,
1006 &IOSB,
1007 IOCTL_AFD_WAIT_FOR_LISTEN,
1008 NULL,
1009 0,
1010 ListenReceiveData,
1011 0xA + sizeof(*ListenReceiveData));
1012
1013 /* Wait for return */
1014 if (Status == STATUS_PENDING)
1015 {
1016 WaitForSingleObject(SockEvent, INFINITE);
1017 Status = IOSB.Status;
1018 }
1019
1020 if (!NT_SUCCESS(Status))
1021 {
1022 NtClose( SockEvent );
1023 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1024 return INVALID_SOCKET;
1025 }
1026
1027 if (lpfnCondition != NULL)
1028 {
1029 if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0)
1030 {
1031 /* Find out how much data is pending */
1032 PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1033 PendingAcceptData.ReturnSize = TRUE;
1034
1035 /* Send IOCTL */
1036 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1037 SockEvent,
1038 NULL,
1039 NULL,
1040 &IOSB,
1041 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1042 &PendingAcceptData,
1043 sizeof(PendingAcceptData),
1044 &PendingAcceptData,
1045 sizeof(PendingAcceptData));
1046
1047 /* Wait for return */
1048 if (Status == STATUS_PENDING)
1049 {
1050 WaitForSingleObject(SockEvent, INFINITE);
1051 Status = IOSB.Status;
1052 }
1053
1054 if (!NT_SUCCESS(Status))
1055 {
1056 NtClose( SockEvent );
1057 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1058 return INVALID_SOCKET;
1059 }
1060
1061 /* How much data to allocate */
1062 PendingDataLength = IOSB.Information;
1063
1064 if (PendingDataLength)
1065 {
1066 /* Allocate needed space */
1067 PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
1068 if (!PendingData)
1069 {
1070 MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
1071 return INVALID_SOCKET;
1072 }
1073
1074 /* We want the data now */
1075 PendingAcceptData.ReturnSize = FALSE;
1076
1077 /* Send IOCTL */
1078 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1079 SockEvent,
1080 NULL,
1081 NULL,
1082 &IOSB,
1083 IOCTL_AFD_GET_PENDING_CONNECT_DATA,
1084 &PendingAcceptData,
1085 sizeof(PendingAcceptData),
1086 PendingData,
1087 PendingDataLength);
1088
1089 /* Wait for return */
1090 if (Status == STATUS_PENDING)
1091 {
1092 WaitForSingleObject(SockEvent, INFINITE);
1093 Status = IOSB.Status;
1094 }
1095
1096 if (!NT_SUCCESS(Status))
1097 {
1098 NtClose( SockEvent );
1099 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1100 return INVALID_SOCKET;
1101 }
1102 }
1103 }
1104
1105 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1106 {
1107 /* I don't support this yet */
1108 }
1109
1110 /* Build Callee ID */
1111 CalleeID.buf = (PVOID)Socket->LocalAddress;
1112 CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
1113
1114 RemoteAddress = HeapAlloc(GlobalHeap, 0, sizeof(*RemoteAddress));
1115 if (!RemoteAddress)
1116 {
1117 MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
1118 return INVALID_SOCKET;
1119 }
1120
1121 /* Set up Address in SOCKADDR Format */
1122 RtlCopyMemory (RemoteAddress,
1123 &ListenReceiveData->Address.Address[0].AddressType,
1124 sizeof(*RemoteAddress));
1125
1126 /* Build Caller ID */
1127 CallerID.buf = (PVOID)RemoteAddress;
1128 CallerID.len = sizeof(*RemoteAddress);
1129
1130 /* Build Caller Data */
1131 CallerData.buf = PendingData;
1132 CallerData.len = PendingDataLength;
1133
1134 /* Check if socket supports Conditional Accept */
1135 if (Socket->SharedData.UseDelayedAcceptance != 0)
1136 {
1137 /* Allocate Buffer for Callee Data */
1138 CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
1139 if (!CalleeDataBuffer) {
1140 MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
1141 return INVALID_SOCKET;
1142 }
1143 CalleeData.buf = CalleeDataBuffer;
1144 CalleeData.len = 4096;
1145 }
1146 else
1147 {
1148 /* Nothing */
1149 CalleeData.buf = 0;
1150 CalleeData.len = 0;
1151 }
1152
1153 /* Call the Condition Function */
1154 CallBack = (lpfnCondition)(&CallerID,
1155 CallerData.buf == NULL ? NULL : &CallerData,
1156 NULL,
1157 NULL,
1158 &CalleeID,
1159 CalleeData.buf == NULL ? NULL : &CalleeData,
1160 &GroupID,
1161 dwCallbackData);
1162
1163 if (((CallBack == CF_ACCEPT) && GroupID) != 0)
1164 {
1165 /* TBD: Check for Validity */
1166 }
1167
1168 if (CallBack == CF_ACCEPT)
1169 {
1170 if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
1171 {
1172 /* I don't support this yet */
1173 }
1174 if (CalleeData.buf)
1175 {
1176 // SockSetConnectData Sockets(SocketID), IOCTL_AFD_SET_CONNECT_DATA, CalleeData.Buffer, CalleeData.BuffSize, 0
1177 }
1178 }
1179 else
1180 {
1181 /* Callback rejected. Build Defer Structure */
1182 DeferData.SequenceNumber = ListenReceiveData->SequenceNumber;
1183 DeferData.RejectConnection = (CallBack == CF_REJECT);
1184
1185 /* Send IOCTL */
1186 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1187 SockEvent,
1188 NULL,
1189 NULL,
1190 &IOSB,
1191 IOCTL_AFD_DEFER_ACCEPT,
1192 &DeferData,
1193 sizeof(DeferData),
1194 NULL,
1195 0);
1196
1197 /* Wait for return */
1198 if (Status == STATUS_PENDING)
1199 {
1200 WaitForSingleObject(SockEvent, INFINITE);
1201 Status = IOSB.Status;
1202 }
1203
1204 NtClose( SockEvent );
1205
1206 if (!NT_SUCCESS(Status))
1207 {
1208 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1209 return INVALID_SOCKET;
1210 }
1211
1212 if (CallBack == CF_REJECT )
1213 {
1214 *lpErrno = WSAECONNREFUSED;
1215 return INVALID_SOCKET;
1216 }
1217 else
1218 {
1219 *lpErrno = WSAECONNREFUSED;
1220 return INVALID_SOCKET;
1221 }
1222 }
1223 }
1224
1225 /* Create a new Socket */
1226 ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
1227 ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
1228 ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
1229
1230 AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
1231 Socket->SharedData.SocketType,
1232 Socket->SharedData.Protocol,
1233 &ProtocolInfo,
1234 GroupID,
1235 Socket->SharedData.CreateFlags,
1236 NULL);
1237
1238 /* Set up the Accept Structure */
1239 AcceptData.ListenHandle = (HANDLE)AcceptSocket;
1240 AcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
1241
1242 /* Send IOCTL to Accept */
1243 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1244 SockEvent,
1245 NULL,
1246 NULL,
1247 &IOSB,
1248 IOCTL_AFD_ACCEPT,
1249 &AcceptData,
1250 sizeof(AcceptData),
1251 NULL,
1252 0);
1253
1254 /* Wait for return */
1255 if (Status == STATUS_PENDING)
1256 {
1257 WaitForSingleObject(SockEvent, INFINITE);
1258 Status = IOSB.Status;
1259 }
1260
1261 if (!NT_SUCCESS(Status))
1262 {
1263 NtClose(SockEvent);
1264 WSPCloseSocket( AcceptSocket, lpErrno );
1265 MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1266 return INVALID_SOCKET;
1267 }
1268
1269 /* Return Address in SOCKADDR FORMAT */
1270 if( SocketAddress )
1271 {
1272 RtlCopyMemory (SocketAddress,
1273 &ListenReceiveData->Address.Address[0].AddressType,
1274 sizeof(*RemoteAddress));
1275 if( SocketAddressLength )
1276 *SocketAddressLength = ListenReceiveData->Address.Address[0].AddressLength;
1277 }
1278
1279 NtClose( SockEvent );
1280
1281 /* Re-enable Async Event */
1282 SockReenableAsyncSelectEvent(Socket, FD_ACCEPT);
1283
1284 AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
1285
1286 *lpErrno = 0;
1287
1288 /* Return Socket */
1289 return AcceptSocket;
1290 }
1291
1292 int
1293 WSPAPI
1294 WSPConnect(SOCKET Handle,
1295 const struct sockaddr * SocketAddress,
1296 int SocketAddressLength,
1297 LPWSABUF lpCallerData,
1298 LPWSABUF lpCalleeData,
1299 LPQOS lpSQOS,
1300 LPQOS lpGQOS,
1301 LPINT lpErrno)
1302 {
1303 IO_STATUS_BLOCK IOSB;
1304 PAFD_CONNECT_INFO ConnectInfo;
1305 PSOCKET_INFORMATION Socket = NULL;
1306 NTSTATUS Status;
1307 UCHAR ConnectBuffer[0x22];
1308 ULONG ConnectDataLength;
1309 ULONG InConnectDataLength;
1310 INT BindAddressLength;
1311 PSOCKADDR BindAddress;
1312 HANDLE SockEvent;
1313
1314 Status = NtCreateEvent(&SockEvent,
1315 GENERIC_READ | GENERIC_WRITE,
1316 NULL,
1317 1,
1318 FALSE);
1319
1320 if( !NT_SUCCESS(Status) )
1321 return -1;
1322
1323 AFD_DbgPrint(MID_TRACE,("Called\n"));
1324
1325 /* Get the Socket Structure associate to this Socket*/
1326 Socket = GetSocketStructure(Handle);
1327
1328 /* Bind us First */
1329 if (Socket->SharedData.State == SocketOpen)
1330 {
1331 /* Get the Wildcard Address */
1332 BindAddressLength = Socket->HelperData->MaxWSAddressLength;
1333 BindAddress = HeapAlloc(GetProcessHeap(), 0, BindAddressLength);
1334 if (!BindAddress)
1335 {
1336 MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
1337 return INVALID_SOCKET;
1338 }
1339 Socket->HelperData->WSHGetWildcardSockaddr (Socket->HelperContext,
1340 BindAddress,
1341 &BindAddressLength);
1342 /* Bind it */
1343 WSPBind(Handle, BindAddress, BindAddressLength, NULL);
1344 }
1345
1346 /* Set the Connect Data */
1347 if (lpCallerData != NULL)
1348 {
1349 ConnectDataLength = lpCallerData->len;
1350 Status = NtDeviceIoControlFile((HANDLE)Handle,
1351 SockEvent,
1352 NULL,
1353 NULL,
1354 &IOSB,
1355 IOCTL_AFD_SET_CONNECT_DATA,
1356 lpCallerData->buf,
1357 ConnectDataLength,
1358 NULL,
1359 0);
1360 /* Wait for return */
1361 if (Status == STATUS_PENDING)
1362 {
1363 WaitForSingleObject(SockEvent, INFINITE);
1364 Status = IOSB.Status;
1365 }
1366 }
1367
1368 /* Dynamic Structure...ugh */
1369 ConnectInfo = (PAFD_CONNECT_INFO)ConnectBuffer;
1370
1371 /* Set up Address in TDI Format */
1372 ConnectInfo->RemoteAddress.TAAddressCount = 1;
1373 ConnectInfo->RemoteAddress.Address[0].AddressLength = SocketAddressLength - sizeof(SocketAddress->sa_family);
1374 ConnectInfo->RemoteAddress.Address[0].AddressType = SocketAddress->sa_family;
1375 RtlCopyMemory (ConnectInfo->RemoteAddress.Address[0].Address,
1376 SocketAddress->sa_data,
1377 SocketAddressLength - sizeof(SocketAddress->sa_family));
1378
1379 /*
1380 * Disable FD_WRITE and FD_CONNECT
1381 * The latter fixes a race condition where the FD_CONNECT is re-enabled
1382 * at the end of this function right after the Async Thread disables it.
1383 * This should only happen at the *next* WSPConnect
1384 */
1385 if (Socket->SharedData.AsyncEvents & FD_CONNECT)
1386 {
1387 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
1388 }
1389
1390 /* Tell AFD that we want Connection Data back, have it allocate a buffer */
1391 if (lpCalleeData != NULL)
1392 {
1393 InConnectDataLength = lpCalleeData->len;
1394 Status = NtDeviceIoControlFile((HANDLE)Handle,
1395 SockEvent,
1396 NULL,
1397 NULL,
1398 &IOSB,
1399 IOCTL_AFD_SET_CONNECT_DATA_SIZE,
1400 &InConnectDataLength,
1401 sizeof(InConnectDataLength),
1402 NULL,
1403 0);
1404
1405 /* Wait for return */
1406 if (Status == STATUS_PENDING)
1407 {
1408 WaitForSingleObject(SockEvent, INFINITE);
1409 Status = IOSB.Status;
1410 }
1411 }
1412
1413 /* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
1414 ConnectInfo->Root = 0;
1415 ConnectInfo->UseSAN = FALSE;
1416 ConnectInfo->Unknown = 0;
1417
1418 /* FIXME: Handle Async Connect */
1419 if (Socket->SharedData.NonBlocking)
1420 {
1421 AFD_DbgPrint(MIN_TRACE, ("Async Connect UNIMPLEMENTED!\n"));
1422 }
1423
1424 /* Send IOCTL */
1425 Status = NtDeviceIoControlFile((HANDLE)Handle,
1426 SockEvent,
1427 NULL,
1428 NULL,
1429 &IOSB,
1430 IOCTL_AFD_CONNECT,
1431 ConnectInfo,
1432 0x22,
1433 NULL,
1434 0);
1435 /* Wait for return */
1436 if (Status == STATUS_PENDING)
1437 {
1438 WaitForSingleObject(SockEvent, INFINITE);
1439 Status = IOSB.Status;
1440 }
1441
1442 /* Get any pending connect data */
1443 if (lpCalleeData != NULL)
1444 {
1445 Status = NtDeviceIoControlFile((HANDLE)Handle,
1446 SockEvent,
1447 NULL,
1448 NULL,
1449 &IOSB,
1450 IOCTL_AFD_GET_CONNECT_DATA,
1451 NULL,
1452 0,
1453 lpCalleeData->buf,
1454 lpCalleeData->len);
1455 /* Wait for return */
1456 if (Status == STATUS_PENDING)
1457 {
1458 WaitForSingleObject(SockEvent, INFINITE);
1459 Status = IOSB.Status;
1460 }
1461 }
1462
1463 /* Re-enable Async Event */
1464 SockReenableAsyncSelectEvent(Socket, FD_WRITE);
1465
1466 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
1467 SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
1468
1469 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1470
1471 NtClose( SockEvent );
1472
1473 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1474 }
1475 int
1476 WSPAPI
1477 WSPShutdown(SOCKET Handle,
1478 int HowTo,
1479 LPINT lpErrno)
1480
1481 {
1482 IO_STATUS_BLOCK IOSB;
1483 AFD_DISCONNECT_INFO DisconnectInfo;
1484 PSOCKET_INFORMATION Socket = NULL;
1485 NTSTATUS Status;
1486 HANDLE SockEvent;
1487
1488 Status = NtCreateEvent(&SockEvent,
1489 GENERIC_READ | GENERIC_WRITE,
1490 NULL,
1491 1,
1492 FALSE);
1493
1494 if( !NT_SUCCESS(Status) )
1495 return -1;
1496
1497 AFD_DbgPrint(MID_TRACE,("Called\n"));
1498
1499 /* Get the Socket Structure associate to this Socket*/
1500 Socket = GetSocketStructure(Handle);
1501
1502 /* Set AFD Disconnect Type */
1503 switch (HowTo)
1504 {
1505 case SD_RECEIVE:
1506 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
1507 Socket->SharedData.ReceiveShutdown = TRUE;
1508 break;
1509 case SD_SEND:
1510 DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
1511 Socket->SharedData.SendShutdown = TRUE;
1512 break;
1513 case SD_BOTH:
1514 DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
1515 Socket->SharedData.ReceiveShutdown = TRUE;
1516 Socket->SharedData.SendShutdown = TRUE;
1517 break;
1518 }
1519
1520 DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
1521
1522 /* Send IOCTL */
1523 Status = NtDeviceIoControlFile((HANDLE)Handle,
1524 SockEvent,
1525 NULL,
1526 NULL,
1527 &IOSB,
1528 IOCTL_AFD_DISCONNECT,
1529 &DisconnectInfo,
1530 sizeof(DisconnectInfo),
1531 NULL,
1532 0);
1533
1534 /* Wait for return */
1535 if (Status == STATUS_PENDING)
1536 {
1537 WaitForSingleObject(SockEvent, INFINITE);
1538 Status = IOSB.Status;
1539 }
1540
1541 AFD_DbgPrint(MID_TRACE,("Ending\n"));
1542
1543 NtClose( SockEvent );
1544
1545 return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
1546 }
1547
1548
1549 INT
1550 WSPAPI
1551 WSPGetSockName(IN SOCKET Handle,
1552 OUT LPSOCKADDR Name,
1553 IN OUT LPINT NameLength,
1554 OUT LPINT lpErrno)
1555 {
1556 IO_STATUS_BLOCK IOSB;
1557 ULONG TdiAddressSize;
1558 PTDI_ADDRESS_INFO TdiAddress;
1559 PTRANSPORT_ADDRESS SocketAddress;
1560 PSOCKET_INFORMATION Socket = NULL;
1561 NTSTATUS Status;
1562 HANDLE SockEvent;
1563
1564 Status = NtCreateEvent(&SockEvent,
1565 GENERIC_READ | GENERIC_WRITE,
1566 NULL,
1567 1,
1568 FALSE);
1569
1570 if( !NT_SUCCESS(Status) )
1571 return SOCKET_ERROR;
1572
1573 /* Get the Socket Structure associate to this Socket*/
1574 Socket = GetSocketStructure(Handle);
1575
1576 /* Allocate a buffer for the address */
1577 TdiAddressSize =
1578 sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
1579 TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1580
1581 if ( TdiAddress == NULL )
1582 {
1583 NtClose( SockEvent );
1584 *lpErrno = WSAENOBUFS;
1585 return SOCKET_ERROR;
1586 }
1587
1588 SocketAddress = &TdiAddress->Address;
1589
1590 /* Send IOCTL */
1591 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1592 SockEvent,
1593 NULL,
1594 NULL,
1595 &IOSB,
1596 IOCTL_AFD_GET_SOCK_NAME,
1597 NULL,
1598 0,
1599 TdiAddress,
1600 TdiAddressSize);
1601
1602 /* Wait for return */
1603 if (Status == STATUS_PENDING)
1604 {
1605 WaitForSingleObject(SockEvent, INFINITE);
1606 Status = IOSB.Status;
1607 }
1608
1609 NtClose( SockEvent );
1610
1611 if (NT_SUCCESS(Status))
1612 {
1613 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1614 {
1615 Name->sa_family = SocketAddress->Address[0].AddressType;
1616 RtlCopyMemory (Name->sa_data,
1617 SocketAddress->Address[0].Address,
1618 SocketAddress->Address[0].AddressLength);
1619 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1620 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
1621 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1622 ((struct sockaddr_in *)Name)->sin_port));
1623 HeapFree(GlobalHeap, 0, TdiAddress);
1624 return 0;
1625 }
1626 else
1627 {
1628 HeapFree(GlobalHeap, 0, TdiAddress);
1629 *lpErrno = WSAEFAULT;
1630 return SOCKET_ERROR;
1631 }
1632 }
1633
1634 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1635 }
1636
1637
1638 INT
1639 WSPAPI
1640 WSPGetPeerName(IN SOCKET s,
1641 OUT LPSOCKADDR Name,
1642 IN OUT LPINT NameLength,
1643 OUT LPINT lpErrno)
1644 {
1645 IO_STATUS_BLOCK IOSB;
1646 ULONG TdiAddressSize;
1647 PTRANSPORT_ADDRESS SocketAddress;
1648 PSOCKET_INFORMATION Socket = NULL;
1649 NTSTATUS Status;
1650 HANDLE SockEvent;
1651
1652 Status = NtCreateEvent(&SockEvent,
1653 GENERIC_READ | GENERIC_WRITE,
1654 NULL,
1655 1,
1656 FALSE);
1657
1658 if( !NT_SUCCESS(Status) )
1659 return SOCKET_ERROR;
1660
1661 /* Get the Socket Structure associate to this Socket*/
1662 Socket = GetSocketStructure(s);
1663
1664 /* Allocate a buffer for the address */
1665 TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
1666 SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
1667
1668 if ( SocketAddress == NULL )
1669 {
1670 NtClose( SockEvent );
1671 *lpErrno = WSAENOBUFS;
1672 return SOCKET_ERROR;
1673 }
1674
1675 /* Send IOCTL */
1676 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1677 SockEvent,
1678 NULL,
1679 NULL,
1680 &IOSB,
1681 IOCTL_AFD_GET_PEER_NAME,
1682 NULL,
1683 0,
1684 SocketAddress,
1685 TdiAddressSize);
1686
1687 /* Wait for return */
1688 if (Status == STATUS_PENDING)
1689 {
1690 WaitForSingleObject(SockEvent, INFINITE);
1691 Status = IOSB.Status;
1692 }
1693
1694 NtClose( SockEvent );
1695
1696 if (NT_SUCCESS(Status))
1697 {
1698 if (*NameLength >= SocketAddress->Address[0].AddressLength)
1699 {
1700 Name->sa_family = SocketAddress->Address[0].AddressType;
1701 RtlCopyMemory (Name->sa_data,
1702 SocketAddress->Address[0].Address,
1703 SocketAddress->Address[0].AddressLength);
1704 *NameLength = 2 + SocketAddress->Address[0].AddressLength;
1705 AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
1706 *NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
1707 ((struct sockaddr_in *)Name)->sin_port));
1708 HeapFree(GlobalHeap, 0, SocketAddress);
1709 return 0;
1710 }
1711 else
1712 {
1713 HeapFree(GlobalHeap, 0, SocketAddress);
1714 *lpErrno = WSAEFAULT;
1715 return SOCKET_ERROR;
1716 }
1717 }
1718
1719 return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
1720 }
1721
1722 INT
1723 WSPAPI
1724 WSPIoctl(IN SOCKET Handle,
1725 IN DWORD dwIoControlCode,
1726 IN LPVOID lpvInBuffer,
1727 IN DWORD cbInBuffer,
1728 OUT LPVOID lpvOutBuffer,
1729 IN DWORD cbOutBuffer,
1730 OUT LPDWORD lpcbBytesReturned,
1731 IN LPWSAOVERLAPPED lpOverlapped,
1732 IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
1733 IN LPWSATHREADID lpThreadId,
1734 OUT LPINT lpErrno)
1735 {
1736 PSOCKET_INFORMATION Socket = NULL;
1737
1738 /* Get the Socket Structure associate to this Socket*/
1739 Socket = GetSocketStructure(Handle);
1740
1741 switch( dwIoControlCode )
1742 {
1743 case FIONBIO:
1744 if( cbInBuffer < sizeof(INT) || IS_INTRESOURCE(lpvInBuffer) )
1745 {
1746 *lpErrno = WSAEFAULT;
1747 return SOCKET_ERROR;
1748 }
1749 Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
1750 return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
1751 case FIONREAD:
1752 if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
1753 {
1754 *lpErrno = WSAEFAULT;
1755 return SOCKET_ERROR;
1756 }
1757 return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
1758 default:
1759 *lpErrno = WSAEINVAL;
1760 return SOCKET_ERROR;
1761 }
1762 }
1763
1764
1765 INT
1766 WSPAPI
1767 WSPGetSockOpt(IN SOCKET Handle,
1768 IN INT Level,
1769 IN INT OptionName,
1770 OUT CHAR FAR* OptionValue,
1771 IN OUT LPINT OptionLength,
1772 OUT LPINT lpErrno)
1773 {
1774 PSOCKET_INFORMATION Socket = NULL;
1775 PVOID Buffer;
1776 INT BufferSize;
1777 BOOLEAN BoolBuffer;
1778
1779 /* Get the Socket Structure associate to this Socket*/
1780 Socket = GetSocketStructure(Handle);
1781 if (Socket == NULL)
1782 {
1783 *lpErrno = WSAENOTSOCK;
1784 return SOCKET_ERROR;
1785 }
1786
1787 AFD_DbgPrint(MID_TRACE, ("Called\n"));
1788
1789 switch (Level)
1790 {
1791 case SOL_SOCKET:
1792 switch (OptionName)
1793 {
1794 case SO_TYPE:
1795 Buffer = &Socket->SharedData.SocketType;
1796 BufferSize = sizeof(INT);
1797 break;
1798
1799 case SO_RCVBUF:
1800 Buffer = &Socket->SharedData.SizeOfRecvBuffer;
1801 BufferSize = sizeof(INT);
1802 break;
1803
1804 case SO_SNDBUF:
1805 Buffer = &Socket->SharedData.SizeOfSendBuffer;
1806 BufferSize = sizeof(INT);
1807 break;
1808
1809 case SO_ACCEPTCONN:
1810 BoolBuffer = Socket->SharedData.Listening;
1811 Buffer = &BoolBuffer;
1812 BufferSize = sizeof(BOOLEAN);
1813 break;
1814
1815 case SO_BROADCAST:
1816 BoolBuffer = Socket->SharedData.Broadcast;
1817 Buffer = &BoolBuffer;
1818 BufferSize = sizeof(BOOLEAN);
1819 break;
1820
1821 case SO_DEBUG:
1822 BoolBuffer = Socket->SharedData.Debug;
1823 Buffer = &BoolBuffer;
1824 BufferSize = sizeof(BOOLEAN);
1825 break;
1826
1827 /* case SO_CONDITIONAL_ACCEPT: */
1828 case SO_DONTLINGER:
1829 case SO_DONTROUTE:
1830 case SO_ERROR:
1831 case SO_GROUP_ID:
1832 case SO_GROUP_PRIORITY:
1833 case SO_KEEPALIVE:
1834 case SO_LINGER:
1835 case SO_MAX_MSG_SIZE:
1836 case SO_OOBINLINE:
1837 case SO_PROTOCOL_INFO:
1838 case SO_REUSEADDR:
1839 AFD_DbgPrint(MID_TRACE, ("Unimplemented option (%x)\n",
1840 OptionName));
1841
1842 default:
1843 *lpErrno = WSAEINVAL;
1844 return SOCKET_ERROR;
1845 }
1846
1847 if (*OptionLength < BufferSize)
1848 {
1849 *lpErrno = WSAEFAULT;
1850 *OptionLength = BufferSize;
1851 return SOCKET_ERROR;
1852 }
1853 RtlCopyMemory(OptionValue, Buffer, BufferSize);
1854
1855 return 0;
1856
1857 case IPPROTO_TCP: /* FIXME */
1858 default:
1859 *lpErrno = WSAEINVAL;
1860 return SOCKET_ERROR;
1861 }
1862 }
1863
1864
1865 /*
1866 * FUNCTION: Initialize service provider for a client
1867 * ARGUMENTS:
1868 * wVersionRequested = Highest WinSock SPI version that the caller can use
1869 * lpWSPData = Address of WSPDATA structure to initialize
1870 * lpProtocolInfo = Pointer to structure that defines the desired protocol
1871 * UpcallTable = Pointer to upcall table of the WinSock DLL
1872 * lpProcTable = Address of procedure table to initialize
1873 * RETURNS:
1874 * Status of operation
1875 */
1876 INT
1877 WSPAPI
1878 WSPStartup(IN WORD wVersionRequested,
1879 OUT LPWSPDATA lpWSPData,
1880 IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
1881 IN WSPUPCALLTABLE UpcallTable,
1882 OUT LPWSPPROC_TABLE lpProcTable)
1883
1884 {
1885 NTSTATUS Status;
1886
1887 AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
1888 Status = NO_ERROR;
1889 Upcalls = UpcallTable;
1890
1891 if (Status == NO_ERROR)
1892 {
1893 lpProcTable->lpWSPAccept = WSPAccept;
1894 lpProcTable->lpWSPAddressToString = WSPAddressToString;
1895 lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
1896 lpProcTable->lpWSPBind = WSPBind;
1897 lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
1898 lpProcTable->lpWSPCleanup = WSPCleanup;
1899 lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
1900 lpProcTable->lpWSPConnect = WSPConnect;
1901 lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
1902 lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
1903 lpProcTable->lpWSPEventSelect = WSPEventSelect;
1904 lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
1905 lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
1906 lpProcTable->lpWSPGetSockName = WSPGetSockName;
1907 lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
1908 lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
1909 lpProcTable->lpWSPIoctl = WSPIoctl;
1910 lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
1911 lpProcTable->lpWSPListen = WSPListen;
1912 lpProcTable->lpWSPRecv = WSPRecv;
1913 lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
1914 lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
1915 lpProcTable->lpWSPSelect = WSPSelect;
1916 lpProcTable->lpWSPSend = WSPSend;
1917 lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
1918 lpProcTable->lpWSPSendTo = WSPSendTo;
1919 lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
1920 lpProcTable->lpWSPShutdown = WSPShutdown;
1921 lpProcTable->lpWSPSocket = WSPSocket;
1922 lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
1923 lpWSPData->wVersion = MAKEWORD(2, 2);
1924 lpWSPData->wHighVersion = MAKEWORD(2, 2);
1925 }
1926
1927 AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
1928
1929 return Status;
1930 }
1931
1932
1933 /*
1934 * FUNCTION: Cleans up service provider for a client
1935 * ARGUMENTS:
1936 * lpErrno = Address of buffer for error information
1937 * RETURNS:
1938 * 0 if successful, or SOCKET_ERROR if not
1939 */
1940 INT
1941 WSPAPI
1942 WSPCleanup(OUT LPINT lpErrno)
1943
1944 {
1945 AFD_DbgPrint(MAX_TRACE, ("\n"));
1946 AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
1947 *lpErrno = NO_ERROR;
1948
1949 return 0;
1950 }
1951
1952
1953
1954 int
1955 GetSocketInformation(PSOCKET_INFORMATION Socket,
1956 ULONG AfdInformationClass,
1957 PULONG Ulong OPTIONAL,
1958 PLARGE_INTEGER LargeInteger OPTIONAL)
1959 {
1960 IO_STATUS_BLOCK IOSB;
1961 AFD_INFO InfoData;
1962 NTSTATUS Status;
1963 HANDLE SockEvent;
1964
1965 Status = NtCreateEvent(&SockEvent,
1966 GENERIC_READ | GENERIC_WRITE,
1967 NULL,
1968 1,
1969 FALSE);
1970
1971 if( !NT_SUCCESS(Status) )
1972 return -1;
1973
1974 /* Set Info Class */
1975 InfoData.InformationClass = AfdInformationClass;
1976
1977 /* Send IOCTL */
1978 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
1979 SockEvent,
1980 NULL,
1981 NULL,
1982 &IOSB,
1983 IOCTL_AFD_GET_INFO,
1984 &InfoData,
1985 sizeof(InfoData),
1986 &InfoData,
1987 sizeof(InfoData));
1988
1989 /* Wait for return */
1990 if (Status == STATUS_PENDING)
1991 {
1992 WaitForSingleObject(SockEvent, INFINITE);
1993 }
1994
1995 /* Return Information */
1996 if (Ulong != NULL)
1997 {
1998 *Ulong = InfoData.Information.Ulong;
1999 }
2000 if (LargeInteger != NULL)
2001 {
2002 *LargeInteger = InfoData.Information.LargeInteger;
2003 }
2004
2005 NtClose( SockEvent );
2006
2007 return 0;
2008
2009 }
2010
2011
2012 int
2013 SetSocketInformation(PSOCKET_INFORMATION Socket,
2014 ULONG AfdInformationClass,
2015 PULONG Ulong OPTIONAL,
2016 PLARGE_INTEGER LargeInteger OPTIONAL)
2017 {
2018 IO_STATUS_BLOCK IOSB;
2019 AFD_INFO InfoData;
2020 NTSTATUS Status;
2021 HANDLE SockEvent;
2022
2023 Status = NtCreateEvent(&SockEvent,
2024 GENERIC_READ | GENERIC_WRITE,
2025 NULL,
2026 1,
2027 FALSE);
2028
2029 if( !NT_SUCCESS(Status) )
2030 return -1;
2031
2032 /* Set Info Class */
2033 InfoData.InformationClass = AfdInformationClass;
2034
2035 /* Set Information */
2036 if (Ulong != NULL)
2037 {
2038 InfoData.Information.Ulong = *Ulong;
2039 }
2040 if (LargeInteger != NULL)
2041 {
2042 InfoData.Information.LargeInteger = *LargeInteger;
2043 }
2044
2045 AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
2046 AfdInformationClass, *Ulong));
2047
2048 /* Send IOCTL */
2049 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
2050 SockEvent,
2051 NULL,
2052 NULL,
2053 &IOSB,
2054 IOCTL_AFD_SET_INFO,
2055 &InfoData,
2056 sizeof(InfoData),
2057 NULL,
2058 0);
2059
2060 /* Wait for return */
2061 if (Status == STATUS_PENDING)
2062 {
2063 WaitForSingleObject(SockEvent, INFINITE);
2064 }
2065
2066 NtClose( SockEvent );
2067
2068 return 0;
2069
2070 }
2071
2072 PSOCKET_INFORMATION
2073 GetSocketStructure(SOCKET Handle)
2074 {
2075 ULONG i;
2076
2077 for (i=0; i<SocketCount; i++)
2078 {
2079 if (Sockets[i]->Handle == Handle)
2080 {
2081 return Sockets[i];
2082 }
2083 }
2084 return 0;
2085 }
2086
2087 int CreateContext(PSOCKET_INFORMATION Socket)
2088 {
2089 IO_STATUS_BLOCK IOSB;
2090 SOCKET_CONTEXT ContextData;
2091 NTSTATUS Status;
2092 HANDLE SockEvent;
2093
2094 Status = NtCreateEvent(&SockEvent,
2095 GENERIC_READ | GENERIC_WRITE,
2096 NULL,
2097 1,
2098 FALSE);
2099
2100 if( !NT_SUCCESS(Status) )
2101 return -1;
2102
2103 /* Create Context */
2104 ContextData.SharedData = Socket->SharedData;
2105 ContextData.SizeOfHelperData = 0;
2106 RtlCopyMemory (&ContextData.LocalAddress,
2107 Socket->LocalAddress,
2108 Socket->SharedData.SizeOfLocalAddress);
2109 RtlCopyMemory (&ContextData.RemoteAddress,
2110 Socket->RemoteAddress,
2111 Socket->SharedData.SizeOfRemoteAddress);
2112
2113 /* Send IOCTL */
2114 Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
2115 SockEvent,
2116 NULL,
2117 NULL,
2118 &IOSB,
2119 IOCTL_AFD_SET_CONTEXT,
2120 &ContextData,
2121 sizeof(ContextData),
2122 NULL,
2123 0);
2124
2125 /* Wait for Completition */
2126 if (Status == STATUS_PENDING)
2127 {
2128 WaitForSingleObject(SockEvent, INFINITE);
2129 }
2130
2131 NtClose( SockEvent );
2132
2133 return 0;
2134 }
2135
2136 BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
2137 {
2138 HANDLE hAsyncThread;
2139 DWORD AsyncThreadId;
2140 HANDLE AsyncEvent;
2141 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2142 NTSTATUS Status;
2143
2144 /* Check if the Thread Already Exists */
2145 if (SockAsyncThreadRefCount)
2146 {
2147 return TRUE;
2148 }
2149
2150 /* Create the Completion Port */
2151 if (!SockAsyncCompletionPort)
2152 {
2153 Status = NtCreateIoCompletion(&SockAsyncCompletionPort,
2154 IO_COMPLETION_ALL_ACCESS,
2155 NULL,
2156 2); // Allow 2 threads only
2157
2158 /* Protect Handle */
2159 HandleFlags.ProtectFromClose = TRUE;
2160 HandleFlags.Inherit = FALSE;
2161 Status = NtSetInformationObject(SockAsyncCompletionPort,
2162 ObjectHandleFlagInformation,
2163 &HandleFlags,
2164 sizeof(HandleFlags));
2165 }
2166
2167 /* Create the Async Event */
2168 Status = NtCreateEvent(&AsyncEvent,
2169 EVENT_ALL_ACCESS,
2170 NULL,
2171 NotificationEvent,
2172 FALSE);
2173
2174 /* Create the Async Thread */
2175 hAsyncThread = CreateThread(NULL,
2176 0,
2177 (LPTHREAD_START_ROUTINE)SockAsyncThread,
2178 NULL,
2179 0,
2180 &AsyncThreadId);
2181
2182 /* Close the Handle */
2183 NtClose(hAsyncThread);
2184
2185 /* Increase the Reference Count */
2186 SockAsyncThreadRefCount++;
2187 return TRUE;
2188 }
2189
2190 int SockAsyncThread(PVOID ThreadParam)
2191 {
2192 PVOID AsyncContext;
2193 PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
2194 IO_STATUS_BLOCK IOSB;
2195 NTSTATUS Status;
2196
2197 /* Make the Thread Higher Priority */
2198 SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_ABOVE_NORMAL);
2199
2200 /* Do a KQUEUE/WorkItem Style Loop, thanks to IoCompletion Ports */
2201 do
2202 {
2203 Status = NtRemoveIoCompletion (SockAsyncCompletionPort,
2204 (PVOID*)&AsyncCompletionRoutine,
2205 &AsyncContext,
2206 &IOSB,
2207 NULL);
2208 /* Call the Async Function */
2209 if (NT_SUCCESS(Status))
2210 {
2211 (*AsyncCompletionRoutine)(AsyncContext, &IOSB);
2212 }
2213 else
2214 {
2215 /* It Failed, sleep for a second */
2216 Sleep(1000);
2217 }
2218 } while ((Status != STATUS_TIMEOUT));
2219
2220 /* The Thread has Ended */
2221 return 0;
2222 }
2223
2224 BOOLEAN SockGetAsyncSelectHelperAfdHandle(VOID)
2225 {
2226 UNICODE_STRING AfdHelper;
2227 OBJECT_ATTRIBUTES ObjectAttributes;
2228 IO_STATUS_BLOCK IoSb;
2229 NTSTATUS Status;
2230 FILE_COMPLETION_INFORMATION CompletionInfo;
2231 OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
2232
2233 /* First, make sure we're not already intialized */
2234 if (SockAsyncHelperAfdHandle)
2235 {
2236 return TRUE;
2237 }
2238
2239 /* Set up Handle Name and Object */
2240 RtlInitUnicodeString(&AfdHelper, L"\\Device\\Afd\\AsyncSelectHlp" );
2241 InitializeObjectAttributes(&ObjectAttributes,
2242 &AfdHelper,
2243 OBJ_INHERIT | OBJ_CASE_INSENSITIVE,
2244 NULL,
2245 NULL);
2246
2247 /* Open the Handle to AFD */
2248 Status = NtCreateFile(&SockAsyncHelperAfdHandle,
2249 GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
2250 &ObjectAttributes,
2251 &IoSb,
2252 NULL,
2253 0,
2254 FILE_SHARE_READ | FILE_SHARE_WRITE,
2255 FILE_OPEN_IF,
2256 0,
2257 NULL,
2258 0);
2259
2260 /*
2261 * Now Set up the Completion Port Information
2262 * This means that whenever a Poll is finished, the routine will be executed
2263 */
2264 CompletionInfo.Port = SockAsyncCompletionPort;
2265 CompletionInfo.Key = SockAsyncSelectCompletionRoutine;
2266 Status = NtSetInformationFile(SockAsyncHelperAfdHandle,
2267 &IoSb,
2268 &CompletionInfo,
2269 sizeof(CompletionInfo),
2270 FileCompletionInformation);
2271
2272
2273 /* Protect the Handle */
2274 HandleFlags.ProtectFromClose = TRUE;
2275 HandleFlags.Inherit = FALSE;
2276 Status = NtSetInformationObject(SockAsyncCompletionPort,
2277 ObjectHandleFlagInformation,
2278 &HandleFlags,
2279 sizeof(HandleFlags));
2280
2281
2282 /* Set this variable to true so that Send/Recv/Accept will know wether to renable disabled events */
2283 SockAsyncSelectCalled = TRUE;
2284 return TRUE;
2285 }
2286
2287 VOID SockAsyncSelectCompletionRoutine(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2288 {
2289
2290 PASYNC_DATA AsyncData = Context;
2291 PSOCKET_INFORMATION Socket;
2292 ULONG x;
2293
2294 /* Get the Socket */
2295 Socket = AsyncData->ParentSocket;
2296
2297 /* Check if the Sequence Number Changed behind our back */
2298 if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber )
2299 {
2300 return;
2301 }
2302
2303 /* Check we were manually called b/c of a failure */
2304 if (!NT_SUCCESS(IoStatusBlock->Status))
2305 {
2306 /* FIXME: Perform Upcall */
2307 return;
2308 }
2309
2310 for (x = 1; x; x<<=1)
2311 {
2312 switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x)
2313 {
2314 case AFD_EVENT_RECEIVE:
2315 if (0 != (Socket->SharedData.AsyncEvents & FD_READ) &&
2316 0 == (Socket->SharedData.AsyncDisabledEvents & FD_READ))
2317 {
2318 /* Make the Notifcation */
2319 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2320 Socket->SharedData.wMsg,
2321 Socket->Handle,
2322 WSAMAKESELECTREPLY(FD_READ, 0));
2323 /* Disable this event until the next read(); */
2324 Socket->SharedData.AsyncDisabledEvents |= FD_READ;
2325 }
2326 break;
2327
2328 case AFD_EVENT_OOB_RECEIVE:
2329 if (0 != (Socket->SharedData.AsyncEvents & FD_OOB) &&
2330 0 == (Socket->SharedData.AsyncDisabledEvents & FD_OOB))
2331 {
2332 /* Make the Notifcation */
2333 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2334 Socket->SharedData.wMsg,
2335 Socket->Handle,
2336 WSAMAKESELECTREPLY(FD_OOB, 0));
2337 /* Disable this event until the next read(); */
2338 Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
2339 }
2340 break;
2341
2342 case AFD_EVENT_SEND:
2343 if (0 != (Socket->SharedData.AsyncEvents & FD_WRITE) &&
2344 0 == (Socket->SharedData.AsyncDisabledEvents & FD_WRITE))
2345 {
2346 /* Make the Notifcation */
2347 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2348 Socket->SharedData.wMsg,
2349 Socket->Handle,
2350 WSAMAKESELECTREPLY(FD_WRITE, 0));
2351 /* Disable this event until the next write(); */
2352 Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
2353 }
2354 break;
2355
2356 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2357 case AFD_EVENT_CONNECT:
2358 case AFD_EVENT_CONNECT_FAIL:
2359 if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
2360 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
2361 {
2362 /* Make the Notifcation */
2363 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2364 Socket->SharedData.wMsg,
2365 Socket->Handle,
2366 WSAMAKESELECTREPLY(FD_CONNECT, 0));
2367 /* Disable this event forever; */
2368 Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT;
2369 }
2370 break;
2371
2372 case AFD_EVENT_ACCEPT:
2373 if (0 != (Socket->SharedData.AsyncEvents & FD_ACCEPT) &&
2374 0 == (Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT))
2375 {
2376 /* Make the Notifcation */
2377 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2378 Socket->SharedData.wMsg,
2379 Socket->Handle,
2380 WSAMAKESELECTREPLY(FD_ACCEPT, 0));
2381 /* Disable this event until the next accept(); */
2382 Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
2383 }
2384 break;
2385
2386 case AFD_EVENT_DISCONNECT:
2387 case AFD_EVENT_ABORT:
2388 case AFD_EVENT_CLOSE:
2389 if (0 != (Socket->SharedData.AsyncEvents & FD_CLOSE) &&
2390 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CLOSE))
2391 {
2392 /* Make the Notifcation */
2393 (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
2394 Socket->SharedData.wMsg,
2395 Socket->Handle,
2396 WSAMAKESELECTREPLY(FD_CLOSE, 0));
2397 /* Disable this event forever; */
2398 Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
2399 }
2400 break;
2401 /* FIXME: Support QOS */
2402 }
2403 }
2404
2405 /* Check if there are any events left for us to check */
2406 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2407 {
2408 return;
2409 }
2410
2411 /* Keep Polling */
2412 SockProcessAsyncSelect(Socket, AsyncData);
2413 return;
2414 }
2415
2416 VOID SockProcessAsyncSelect(PSOCKET_INFORMATION Socket, PASYNC_DATA AsyncData)
2417 {
2418
2419 ULONG lNetworkEvents;
2420 NTSTATUS Status;
2421
2422 /* Set up the Async Data Event Info */
2423 AsyncData->AsyncSelectInfo.Timeout.HighPart = 0x7FFFFFFF;
2424 AsyncData->AsyncSelectInfo.Timeout.LowPart = 0xFFFFFFFF;
2425 AsyncData->AsyncSelectInfo.HandleCount = 1;
2426 AsyncData->AsyncSelectInfo.Exclusive = TRUE;
2427 AsyncData->AsyncSelectInfo.Handles[0].Handle = Socket->Handle;
2428 AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
2429
2430 /* Remove unwanted events */
2431 lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
2432
2433 /* Set Events to wait for */
2434 if (lNetworkEvents & FD_READ)
2435 {
2436 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_RECEIVE;
2437 }
2438
2439 if (lNetworkEvents & FD_WRITE)
2440 {
2441 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_SEND;
2442 }
2443
2444 if (lNetworkEvents & FD_OOB)
2445 {
2446 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_OOB_RECEIVE;
2447 }
2448
2449 if (lNetworkEvents & FD_ACCEPT)
2450 {
2451 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_ACCEPT;
2452 }
2453
2454 /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
2455 if (lNetworkEvents & FD_CONNECT)
2456 {
2457 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_CONNECT | AFD_EVENT_CONNECT_FAIL;
2458 }
2459
2460 if (lNetworkEvents & FD_CLOSE)
2461 {
2462 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_DISCONNECT | AFD_EVENT_ABORT | AFD_EVENT_CLOSE;
2463 }
2464
2465 if (lNetworkEvents & FD_QOS)
2466 {
2467 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_QOS;
2468 }
2469
2470 if (lNetworkEvents & FD_GROUP_QOS)
2471 {
2472 AsyncData->AsyncSelectInfo.Handles[0].Events |= AFD_EVENT_GROUP_QOS;
2473 }
2474
2475 /* Send IOCTL */
2476 Status = NtDeviceIoControlFile (SockAsyncHelperAfdHandle,
2477 NULL,
2478 NULL,
2479 AsyncData,
2480 &AsyncData->IoStatusBlock,
2481 IOCTL_AFD_SELECT,
2482 &AsyncData->AsyncSelectInfo,
2483 sizeof(AsyncData->AsyncSelectInfo),
2484 &AsyncData->AsyncSelectInfo,
2485 sizeof(AsyncData->AsyncSelectInfo));
2486
2487 /* I/O Manager Won't call the completion routine, let's do it manually */
2488 if (NT_SUCCESS(Status))
2489 {
2490 return;
2491 }
2492 else
2493 {
2494 AsyncData->IoStatusBlock.Status = Status;
2495 SockAsyncSelectCompletionRoutine(AsyncData, &AsyncData->IoStatusBlock);
2496 }
2497 }
2498
2499 VOID SockProcessQueuedAsyncSelect(PVOID Context, PIO_STATUS_BLOCK IoStatusBlock)
2500 {
2501 PASYNC_DATA AsyncData = Context;
2502 BOOL FreeContext = TRUE;
2503 PSOCKET_INFORMATION Socket;
2504
2505 /* Get the Socket */
2506 Socket = AsyncData->ParentSocket;
2507
2508 /* If someone closed it, stop the function */
2509 if (Socket->SharedData.State != SocketClosed)
2510 {
2511 /* Check if the Sequence Number changed by now, in which case quit */
2512 if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber)
2513 {
2514 /* Do the actuall select, if needed */
2515 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)))
2516 {
2517 SockProcessAsyncSelect(Socket, AsyncData);
2518 FreeContext = FALSE;
2519 }
2520 }
2521 }
2522
2523 /* Free the Context */
2524 if (FreeContext)
2525 {
2526 HeapFree(GetProcessHeap(), 0, AsyncData);
2527 }
2528
2529 return;
2530 }
2531
2532 VOID
2533 SockReenableAsyncSelectEvent (IN PSOCKET_INFORMATION Socket,
2534 IN ULONG Event)
2535 {
2536 PASYNC_DATA AsyncData;
2537
2538 /* Make sure the event is actually disabled */
2539 if (!(Socket->SharedData.AsyncDisabledEvents & Event))
2540 {
2541 return;
2542 }
2543
2544 /* Re-enable it */
2545 Socket->SharedData.AsyncDisabledEvents &= ~Event;
2546
2547 /* Return if no more events are being polled */
2548 if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
2549 {
2550 return;
2551 }
2552
2553 /* Wait on new events */
2554 AsyncData = HeapAlloc(GetProcessHeap(), 0, sizeof(ASYNC_DATA));
2555 if (!AsyncData) return;
2556
2557 /* Create the Asynch Thread if Needed */
2558 SockCreateOrReferenceAsyncThread();
2559
2560 /* Increase the sequence number to stop anything else */
2561 Socket->SharedData.SequenceNumber++;
2562
2563 /* Set up the Async Data */
2564 AsyncData->ParentSocket = Socket;
2565 AsyncData->SequenceNumber = Socket->SharedData.SequenceNumber;
2566
2567 /* Begin Async Select by using I/O Completion */
2568 NtSetIoCompletion(SockAsyncCompletionPort,
2569 (PVOID)&SockProcessQueuedAsyncSelect,
2570 AsyncData,
2571 0,
2572 0);
2573
2574 /* All done */
2575 return;
2576 }
2577
2578 BOOL
2579 WINAPI
2580 DllMain(HANDLE hInstDll,
2581 ULONG dwReason,
2582 PVOID Reserved)
2583 {
2584
2585 switch (dwReason)
2586 {
2587 case DLL_PROCESS_ATTACH:
2588
2589 AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
2590
2591 /* Don't need thread attach notifications
2592 so disable them to improve performance */
2593 DisableThreadLibraryCalls(hInstDll);
2594
2595 /* List of DLL Helpers */
2596 InitializeListHead(&SockHelpersListHead);
2597
2598 /* Heap to use when allocating */
2599 GlobalHeap = GetProcessHeap();
2600
2601 /* Allocate Heap for 1024 Sockets, can be expanded later */
2602 Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
2603 if (!Sockets) return FALSE;
2604
2605 AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
2606
2607 break;
2608
2609 case DLL_THREAD_ATTACH:
2610 break;
2611
2612 case DLL_THREAD_DETACH:
2613 break;
2614
2615 case DLL_PROCESS_DETACH:
2616 break;
2617 }
2618
2619 AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
2620
2621 return TRUE;
2622 }
2623
2624 /* EOF */
2625
2626