Protocol = SharedData->Protocol;
}
- if (AddressFamily == AF_UNSPEC && SocketType == 0 && Protocol == 0)
+ if (lpProtocolInfo)
+ {
+ if (lpProtocolInfo->iAddressFamily && AddressFamily <= 0)
+ AddressFamily = lpProtocolInfo->iAddressFamily;
+ if (lpProtocolInfo->iSocketType && SocketType <= 0)
+ SocketType = lpProtocolInfo->iSocketType;
+ if (lpProtocolInfo->iProtocol && Protocol <= 0)
+ Protocol = lpProtocolInfo->iProtocol;
+ }
+
+ /* FIXME: AF_NETDES should be AF_MAX */
+ if (AddressFamily < AF_UNSPEC || AddressFamily > AF_NETDES)
return WSAEINVAL;
- /* Set the defaults */
- if (AddressFamily == AF_UNSPEC)
- AddressFamily = AF_INET;
+ if (SocketType < 0 && SocketType > SOCK_SEQPACKET)
+ return WSAEINVAL;
+ if (Protocol < 0 && Protocol > IPPROTO_MAX)
+ return WSAEINVAL;
+
+ /* when no protocol and socket type are specified the first entry
+ * from WSAEnumProtocols that has the flag PFL_MATCHES_PROTOCOL_ZERO
+ * is returned */
+ if (SocketType == 0 && Protocol == 0 && lpProtocolInfo && (lpProtocolInfo->dwProviderFlags & PFL_MATCHES_PROTOCOL_ZERO) == 0)
+ return WSAEINVAL;
+
+ /* Set the defaults */
if (SocketType == 0)
{
switch (Protocol)
break;
default:
TRACE("Unknown Protocol (%d). We will try SOCK_STREAM.\n", Protocol);
- SocketType = SOCK_STREAM;
- break;
+ return WSAEINVAL;
}
}
break;
default:
TRACE("Unknown SocketType (%d). We will try IPPROTO_TCP.\n", SocketType);
- Protocol = IPPROTO_TCP;
- break;
+ return WSAEINVAL;
}
}
+ if (AddressFamily == AF_UNSPEC)
+ return WSAEINVAL;
+
/* Get Helper Data and Transport */
Status = SockGetTdiName (&AddressFamily,
&SocketType,
PAFD_POLL_INFO PollInfo;
NTSTATUS Status;
ULONG HandleCount;
- LONG OutCount = 0;
ULONG PollBufferSize;
PVOID PollBuffer;
ULONG i, j = 0, x;
HANDLE SockEvent;
- BOOL HandleCounted;
LARGE_INTEGER Timeout;
+ PSOCKET_INFORMATION Socket;
+ SOCKET Handle;
+ ULONG Events;
/* Find out how many sockets we have, and how large the buffer needs
* to be */
if (readfds != NULL) {
for (i = 0; i < readfds->fd_count; i++, j++)
{
+ Socket = GetSocketStructure(readfds->fd_array[i]);
+ if (!Socket)
+ {
+ ERR("Invalid socket handle provided in readfds %d\n", readfds->fd_array[i]);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
PollInfo->Handles[j].Handle = readfds->fd_array[i];
PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
AFD_EVENT_DISCONNECT |
AFD_EVENT_ABORT |
AFD_EVENT_CLOSE |
AFD_EVENT_ACCEPT;
+ if (Socket->SharedData->OobInline != 0)
+ PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
}
}
if (writefds != NULL)
{
for (i = 0; i < writefds->fd_count; i++, j++)
{
+ Socket = GetSocketStructure(writefds->fd_array[i]);
+ if (!Socket)
+ {
+ ERR("Invalid socket handle provided in writefds %d\n", writefds->fd_array[i]);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
PollInfo->Handles[j].Handle = writefds->fd_array[i];
- PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
+ PollInfo->Handles[j].Events = AFD_EVENT_SEND;
+ if (Socket->SharedData->NonBlocking != 0)
+ PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT;
}
}
if (exceptfds != NULL)
{
for (i = 0; i < exceptfds->fd_count; i++, j++)
{
+ Socket = GetSocketStructure(exceptfds->fd_array[i]);
+ if (!Socket)
+ {
+ TRACE("Invalid socket handle provided in exceptfds %d\n", exceptfds->fd_array[i]);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
- PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
+ PollInfo->Handles[j].Events = 0;
+ if (Socket->SharedData->OobInline == 0)
+ PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
+ if (Socket->SharedData->NonBlocking != 0)
+ PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT_FAIL;
+ if (PollInfo->Handles[j].Events == 0)
+ {
+ TRACE("No events can be checked for exceptfds %d. It is nonblocking and OOB line is disabled. Skipping it.", exceptfds->fd_array[i]);
+ j--;
+ }
}
}
/* Return in FDSET Format */
for (i = 0; i < HandleCount; i++)
{
- HandleCounted = FALSE;
+ Events = PollInfo->Handles[i].Events;
+ Handle = PollInfo->Handles[i].Handle;
for(x = 1; x; x<<=1)
{
- switch (PollInfo->Handles[i].Events & x)
+ Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ TRACE("Invalid socket handle found %d\n", Handle);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
+ switch (Events & x)
{
case AFD_EVENT_RECEIVE:
case AFD_EVENT_DISCONNECT:
case AFD_EVENT_ACCEPT:
case AFD_EVENT_CLOSE:
TRACE("Event %x on handle %x\n",
- PollInfo->Handles[i].Events,
- PollInfo->Handles[i].Handle);
- if (! HandleCounted)
- {
- OutCount++;
- HandleCounted = TRUE;
- }
+ Events,
+ Handle);
if( readfds )
- FD_SET(PollInfo->Handles[i].Handle, readfds);
+ FD_SET(Handle, readfds);
break;
case AFD_EVENT_SEND:
+ TRACE("Event %x on handle %x\n",
+ Events,
+ Handle);
+ if (writefds)
+ FD_SET(Handle, writefds);
+ break;
case AFD_EVENT_CONNECT:
TRACE("Event %x on handle %x\n",
- PollInfo->Handles[i].Events,
- PollInfo->Handles[i].Handle);
- if (! HandleCounted)
- {
- OutCount++;
- HandleCounted = TRUE;
- }
- if( writefds )
- FD_SET(PollInfo->Handles[i].Handle, writefds);
+ Events,
+ Handle);
+ if( writefds && Socket->SharedData->NonBlocking != 0 )
+ FD_SET(Handle, writefds);
break;
case AFD_EVENT_OOB_RECEIVE:
+ TRACE("Event %x on handle %x\n",
+ Events,
+ Handle);
+ if( readfds && Socket->SharedData->OobInline != 0 )
+ FD_SET(Handle, readfds);
+ if( exceptfds && Socket->SharedData->OobInline == 0 )
+ FD_SET(Handle, exceptfds);
+ break;
case AFD_EVENT_CONNECT_FAIL:
TRACE("Event %x on handle %x\n",
- PollInfo->Handles[i].Events,
- PollInfo->Handles[i].Handle);
- if (! HandleCounted)
- {
- OutCount++;
- HandleCounted = TRUE;
- }
- if( exceptfds )
- FD_SET(PollInfo->Handles[i].Handle, exceptfds);
+ Events,
+ Handle);
+ if( exceptfds && Socket->SharedData->NonBlocking != 0 )
+ FD_SET(Handle, exceptfds);
break;
}
}
TRACE("*lpErrno = %x\n", *lpErrno);
}
- TRACE("%d events\n", OutCount);
+ HandleCount = (readfds ? readfds->fd_count : 0) +
+ (writefds && writefds != readfds ? writefds->fd_count : 0) +
+ (exceptfds && exceptfds != readfds && exceptfds != writefds ? exceptfds->fd_count : 0);
+
+ TRACE("%d events\n", HandleCount);
- return OutCount;
+ return HandleCount;
}
SOCKET
return AcceptSocket;
}
+VOID
+NTAPI
+AfdConnectAPC(PVOID ApcContext,
+ PIO_STATUS_BLOCK IoStatusBlock,
+ ULONG Reserved)
+{
+ PAFDCONNECTAPCCONTEXT Context = ApcContext;
+
+ if (IoStatusBlock->Status == STATUS_SUCCESS)
+ {
+ Context->lpSocket->SharedData->State = SocketConnected;
+ Context->lpSocket->TdiConnectionHandle = (HANDLE)IoStatusBlock->Information;
+ }
+
+ if (Context->lpConnectInfo) HeapFree(GetProcessHeap(), 0, Context->lpConnectInfo);
+
+ /* Re-enable Async Event */
+ SockReenableAsyncSelectEvent(Context->lpSocket, FD_WRITE);
+
+ /* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
+ SockReenableAsyncSelectEvent(Context->lpSocket, FD_CONNECT);
+
+ if (IoStatusBlock->Status == STATUS_SUCCESS && (Context->lpSocket->HelperEvents & WSH_NOTIFY_CONNECT))
+ {
+ Context->lpSocket->HelperData->WSHNotify(Context->lpSocket->HelperContext,
+ Context->lpSocket->Handle,
+ Context->lpSocket->TdiAddressHandle,
+ Context->lpSocket->TdiConnectionHandle,
+ WSH_NOTIFY_CONNECT);
+ }
+ else if (IoStatusBlock->Status != STATUS_SUCCESS && (Context->lpSocket->HelperEvents & WSH_NOTIFY_CONNECT_ERROR))
+ {
+ Context->lpSocket->HelperData->WSHNotify(Context->lpSocket->HelperContext,
+ Context->lpSocket->Handle,
+ Context->lpSocket->TdiAddressHandle,
+ Context->lpSocket->TdiConnectionHandle,
+ WSH_NOTIFY_CONNECT_ERROR);
+ }
+ HeapFree(GlobalHeap, 0, ApcContext);
+}
int
WSPAPI
WSPConnect(SOCKET Handle,
PSOCKADDR BindAddress;
HANDLE SockEvent;
int SocketDataLength;
-
- Status = NtCreateEvent(&SockEvent,
- EVENT_ALL_ACCESS,
- NULL,
- 1,
- FALSE);
-
- if (!NT_SUCCESS(Status))
- return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
+ PVOID APCContext = NULL;
+ PVOID APCFunction = NULL;
TRACE("Called\n");
Socket = GetSocketStructure(Handle);
if (!Socket)
{
- NtClose(SockEvent);
if (lpErrno) *lpErrno = WSAENOTSOCK;
return SOCKET_ERROR;
}
+ Status = NtCreateEvent(&SockEvent,
+ EVENT_ALL_ACCESS,
+ NULL,
+ 1,
+ FALSE);
+
+ if (!NT_SUCCESS(Status))
+ return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
+
/* Bind us First */
if (Socket->SharedData->State == SocketOpen)
{
/* FIXME: Handle Async Connect */
if (Socket->SharedData->NonBlocking)
{
- ERR("Async Connect UNIMPLEMENTED!\n");
+ APCFunction = &AfdConnectAPC; // should be a private io completition function inside us
+ APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDCONNECTAPCCONTEXT));
+ if (!APCContext)
+ {
+ ERR("Not enough memory for APC Context\n");
+ return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+ }
+ ((PAFDCONNECTAPCCONTEXT)APCContext)->lpConnectInfo = ConnectInfo;
+ ((PAFDCONNECTAPCCONTEXT)APCContext)->lpSocket = Socket;
}
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Handle,
SockEvent,
- NULL,
- NULL,
+ APCFunction,
+ APCContext,
&IOSB,
IOCTL_AFD_CONNECT,
ConnectInfo,
NULL,
0);
/* Wait for return */
- if (Status == STATUS_PENDING)
+ if (Status == STATUS_PENDING && !Socket->SharedData->NonBlocking)
{
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
}
+ if (Status == STATUS_PENDING)
+ {
+ TRACE("Leaving (Pending)\n");
+ return MsafdReturnWithErrno(STATUS_CANT_WAIT, lpErrno, 0, NULL);
+ }
+
+ if (APCContext) HeapFree(GetProcessHeap(), 0, APCContext);
+
if (Status != STATUS_SUCCESS)
goto notify;