/*
* COPYRIGHT: See COPYING in the top level directory
* PROJECT: ReactOS Ancillary Function Driver DLL
- * FILE: misc/dllmain.c
+ * FILE: dll/win32/msafd/misc/dllmain.c
* PURPOSE: DLL entry point
* PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
* Alex Ionescu (alex@relsoft.net)
#include <msafd.h>
#include <winuser.h>
-
-#if DBG
-//DWORD DebugTraceLevel = DEBUG_ULTRA;
-DWORD DebugTraceLevel = 0;
-#endif /* DBG */
+#include <wchar.h>
HANDLE GlobalHeap;
WSPUPCALLTABLE Upcalls;
+DWORD CatalogEntryId; /* CatalogEntryId for upcalls */
LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
PSOCKET_INFORMATION SocketListHead = NULL;
CRITICAL_SECTION SocketListLock;
UNICODE_STRING DevName;
LARGE_INTEGER GroupData;
INT Status;
+ PSOCK_SHARED_INFO SharedData = NULL;
+
+ TRACE("Creating Socket, getting TDI Name - AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
+ AddressFamily, SocketType, Protocol);
+
+ if (lpProtocolInfo && lpProtocolInfo->dwServiceFlags3 != 0 && lpProtocolInfo->dwServiceFlags4 != 0)
+ {
+ /* Duplpicating socket from different process */
+ if ((HANDLE)lpProtocolInfo->dwServiceFlags3 == INVALID_HANDLE_VALUE)
+ {
+ Status = WSAEINVAL;
+ goto error;
+ }
+ if ((HANDLE)lpProtocolInfo->dwServiceFlags4 == INVALID_HANDLE_VALUE)
+ {
+ Status = WSAEINVAL;
+ goto error;
+ }
+ SharedData = MapViewOfFile((HANDLE)lpProtocolInfo->dwServiceFlags3,
+ FILE_MAP_ALL_ACCESS,
+ 0,
+ 0,
+ sizeof(SOCK_SHARED_INFO));
+ if (!SharedData)
+ {
+ Status = WSAEINVAL;
+ goto error;
+ }
+ InterlockedIncrement(&SharedData->RefCount);
+ AddressFamily = SharedData->AddressFamily;
+ SocketType = SharedData->SocketType;
+ Protocol = SharedData->Protocol;
+ }
+
+ if (AddressFamily == AF_UNSPEC && SocketType == 0 && Protocol == 0)
+ {
+ Status = WSAEINVAL;
+ goto error;
+ }
+
+ /* Set the defaults */
+ if (AddressFamily == AF_UNSPEC)
+ AddressFamily = AF_INET;
+
+ if (SocketType == 0)
+ {
+ switch (Protocol)
+ {
+ case IPPROTO_TCP:
+ SocketType = SOCK_STREAM;
+ break;
+ case IPPROTO_UDP:
+ SocketType = SOCK_DGRAM;
+ break;
+ case IPPROTO_RAW:
+ SocketType = SOCK_RAW;
+ break;
+ default:
+ TRACE("Unknown Protocol (%d). We will try SOCK_STREAM.\n", Protocol);
+ SocketType = SOCK_STREAM;
+ break;
+ }
+ }
- AFD_DbgPrint(MAX_TRACE, ("Creating Socket, getting TDI Name\n"));
- AFD_DbgPrint(MAX_TRACE, ("AddressFamily (%d) SocketType (%d) Protocol (%d).\n",
- AddressFamily, SocketType, Protocol));
+ if (Protocol == 0)
+ {
+ switch (SocketType)
+ {
+ case SOCK_STREAM:
+ Protocol = IPPROTO_TCP;
+ break;
+ case SOCK_DGRAM:
+ Protocol = IPPROTO_UDP;
+ break;
+ case SOCK_RAW:
+ Protocol = IPPROTO_RAW;
+ break;
+ default:
+ TRACE("Unknown SocketType (%d). We will try IPPROTO_TCP.\n", SocketType);
+ Protocol = IPPROTO_TCP;
+ break;
+ }
+ }
/* Get Helper Data and Transport */
Status = SockGetTdiName (&AddressFamily,
/* Check for error */
if (Status != NO_ERROR)
{
- AFD_DbgPrint(MID_TRACE,("SockGetTdiName: Status %x\n", Status));
+ ERR("SockGetTdiName: Status %x\n", Status);
goto error;
}
/* Set Socket Data */
Socket = HeapAlloc(GlobalHeap, 0, sizeof(*Socket));
if (!Socket)
- return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
-
+ {
+ Status = WSAENOBUFS;
+ goto error;
+ }
RtlZeroMemory(Socket, sizeof(*Socket));
- Socket->RefCount = 2;
- Socket->Handle = -1;
- Socket->SharedData.Listening = FALSE;
- Socket->SharedData.State = SocketOpen;
- Socket->SharedData.AddressFamily = AddressFamily;
- Socket->SharedData.SocketType = SocketType;
- Socket->SharedData.Protocol = Protocol;
+ if (SharedData)
+ {
+ Socket->SharedData = SharedData;
+ Socket->SharedDataHandle = (HANDLE)lpProtocolInfo->dwServiceFlags3;
+ Sock = (HANDLE)lpProtocolInfo->dwServiceFlags4;
+ Socket->Handle = (SOCKET)lpProtocolInfo->dwServiceFlags4;
+ }
+ else
+ {
+ Socket->SharedDataHandle = INVALID_HANDLE_VALUE;
+ Socket->SharedData = HeapAlloc(GlobalHeap, 0, sizeof(*Socket->SharedData));
+ if (!Socket->SharedData)
+ {
+ Status = WSAENOBUFS;
+ goto error;
+ }
+ RtlZeroMemory(Socket->SharedData, sizeof(*Socket->SharedData));
+ Socket->SharedData->State = SocketOpen;
+ Socket->SharedData->RefCount = 1L;
+ Socket->SharedData->Listening = FALSE;
+ Socket->SharedData->AddressFamily = AddressFamily;
+ Socket->SharedData->SocketType = SocketType;
+ Socket->SharedData->Protocol = Protocol;
+ Socket->SharedData->SizeOfLocalAddress = HelperData->MaxWSAddressLength;
+ Socket->SharedData->SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
+ Socket->SharedData->UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
+ Socket->SharedData->CreateFlags = dwFlags;
+ Socket->SharedData->ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
+ Socket->SharedData->ProviderFlags = lpProtocolInfo->dwProviderFlags;
+ Socket->SharedData->UseSAN = FALSE;
+ Socket->SharedData->NonBlocking = FALSE; /* Sockets start blocking */
+ Socket->SharedData->RecvTimeout = INFINITE;
+ Socket->SharedData->SendTimeout = INFINITE;
+ Socket->SharedData->OobInline = FALSE;
+
+ /* Ask alex about this */
+ if( Socket->SharedData->SocketType == SOCK_DGRAM ||
+ Socket->SharedData->SocketType == SOCK_RAW )
+ {
+ TRACE("Connectionless socket\n");
+ Socket->SharedData->ServiceFlags1 |= XP1_CONNECTIONLESS;
+ }
+ Socket->Handle = INVALID_SOCKET;
+ }
+
Socket->HelperContext = HelperDLLContext;
Socket->HelperData = HelperData;
Socket->HelperEvents = HelperEvents;
- Socket->LocalAddress = &Socket->WSLocalAddress;
- Socket->SharedData.SizeOfLocalAddress = HelperData->MaxWSAddressLength;
- Socket->RemoteAddress = &Socket->WSRemoteAddress;
- Socket->SharedData.SizeOfRemoteAddress = HelperData->MaxWSAddressLength;
- Socket->SharedData.UseDelayedAcceptance = HelperData->UseDelayedAcceptance;
- Socket->SharedData.CreateFlags = dwFlags;
- Socket->SharedData.CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
- Socket->SharedData.ServiceFlags1 = lpProtocolInfo->dwServiceFlags1;
- Socket->SharedData.ProviderFlags = lpProtocolInfo->dwProviderFlags;
- Socket->SharedData.GroupID = g;
- Socket->SharedData.GroupType = 0;
- Socket->SharedData.UseSAN = FALSE;
- Socket->SharedData.NonBlocking = FALSE; /* Sockets start blocking */
+ Socket->LocalAddress = &Socket->SharedData->WSLocalAddress;
+ Socket->RemoteAddress = &Socket->SharedData->WSRemoteAddress;
Socket->SanData = NULL;
-
- /* Ask alex about this */
- if( Socket->SharedData.SocketType == SOCK_DGRAM ||
- Socket->SharedData.SocketType == SOCK_RAW )
- {
- AFD_DbgPrint(MID_TRACE,("Connectionless socket\n"));
- Socket->SharedData.ServiceFlags1 |= XP1_CONNECTIONLESS;
- }
+ RtlCopyMemory(&Socket->ProtocolInfo, lpProtocolInfo, sizeof(Socket->ProtocolInfo));
+ if (SharedData)
+ goto ok;
/* Packet Size */
SizeOfPacket = TransportName.Length + sizeof(AFD_CREATE_PACKET) + sizeof(WCHAR);
/* Set up EA Buffer */
EABuffer = HeapAlloc(GlobalHeap, 0, SizeOfEA);
if (!EABuffer)
- return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+ {
+ Status = WSAENOBUFS;
+ goto error;
+ }
RtlZeroMemory(EABuffer, SizeOfEA);
EABuffer->NextEntryOffset = 0;
AfdPacket->GroupID = g;
/* Set up Endpoint Flags */
- if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECTIONLESS) != 0)
+ if ((Socket->SharedData->ServiceFlags1 & XP1_CONNECTIONLESS) != 0)
{
if ((SocketType != SOCK_DGRAM) && (SocketType != SOCK_RAW))
{
/* Only RAW or UDP can be Connectionless */
+ Status = WSAEINVAL;
goto error;
}
AfdPacket->EndpointFlags |= AFD_ENDPOINT_CONNECTIONLESS;
}
- if ((Socket->SharedData.ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0)
+ if ((Socket->SharedData->ServiceFlags1 & XP1_MESSAGE_ORIENTED) != 0)
{
if (SocketType == SOCK_STREAM)
{
- if ((Socket->SharedData.ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
+ if ((Socket->SharedData->ServiceFlags1 & XP1_PSEUDO_STREAM) == 0)
{
/* The Provider doesn't actually support Message Oriented Streams */
+ Status = WSAEINVAL;
goto error;
}
}
WSA_FLAG_MULTIPOINT_D_ROOT |
WSA_FLAG_MULTIPOINT_D_LEAF))
{
- if ((Socket->SharedData.ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
+ if ((Socket->SharedData->ServiceFlags1 & XP1_SUPPORT_MULTIPOINT) == 0)
{
/* The Provider doesn't actually support Multipoint */
+ Status = WSAEINVAL;
goto error;
}
AfdPacket->EndpointFlags |= AFD_ENDPOINT_MULTIPOINT;
if (dwFlags & WSA_FLAG_MULTIPOINT_C_ROOT)
{
- if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
+ if (((Socket->SharedData->ServiceFlags1 & XP1_MULTIPOINT_CONTROL_PLANE) == 0)
|| ((dwFlags & WSA_FLAG_MULTIPOINT_C_LEAF) != 0))
{
/* The Provider doesn't support Control Planes, or you already gave a leaf */
+ Status = WSAEINVAL;
goto error;
}
AfdPacket->EndpointFlags |= AFD_ENDPOINT_C_ROOT;
if (dwFlags & WSA_FLAG_MULTIPOINT_D_ROOT)
{
- if (((Socket->SharedData.ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
+ if (((Socket->SharedData->ServiceFlags1 & XP1_MULTIPOINT_DATA_PLANE) == 0)
|| ((dwFlags & WSA_FLAG_MULTIPOINT_D_LEAF) != 0))
{
/* The Provider doesn't support Data Planes, or you already gave a leaf */
+ Status = WSAEINVAL;
goto error;
}
AfdPacket->EndpointFlags |= AFD_ENDPOINT_D_ROOT;
HeapFree(GlobalHeap, 0, EABuffer);
- if (Status != STATUS_SUCCESS)
+ if (!NT_SUCCESS(Status))
{
- AFD_DbgPrint(MIN_TRACE, ("Failed to open socket\n"));
-
- HeapFree(GlobalHeap, 0, Socket);
-
- return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
+ ERR("Failed to open socket. Status 0x%08x\n", Status);
+ Status = TranslateNtStatusError(Status);
+ goto error;
}
/* Save Handle */
/* Save Group Info */
if (g != 0)
{
- GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, NULL, NULL, &GroupData);
- Socket->SharedData.GroupID = GroupData.u.LowPart;
- Socket->SharedData.GroupType = GroupData.u.HighPart;
+ GetSocketInformation(Socket,
+ AFD_INFO_GROUP_ID_TYPE,
+ NULL,
+ NULL,
+ &GroupData,
+ NULL,
+ NULL);
+ Socket->SharedData->GroupID = GroupData.u.LowPart;
+ Socket->SharedData->GroupType = GroupData.u.HighPart;
}
/* Get Window Sizes and Save them */
GetSocketInformation (Socket,
AFD_INFO_SEND_WINDOW_SIZE,
NULL,
- &Socket->SharedData.SizeOfSendBuffer,
+ &Socket->SharedData->SizeOfSendBuffer,
+ NULL,
+ NULL,
NULL);
GetSocketInformation (Socket,
AFD_INFO_RECEIVE_WINDOW_SIZE,
NULL,
- &Socket->SharedData.SizeOfRecvBuffer,
+ &Socket->SharedData->SizeOfRecvBuffer,
+ NULL,
+ NULL,
NULL);
+ok:
/* Save in Process Sockets List */
EnterCriticalSection(&SocketListLock);
CreateContext(Socket);
/* Notify Winsock */
- Upcalls.lpWPUModifyIFSHandle(1, (SOCKET)Sock, lpErrno);
+ Upcalls.lpWPUModifyIFSHandle(Socket->ProtocolInfo.dwCatalogEntryId, (SOCKET)Sock, lpErrno);
/* Return Socket Handle */
- AFD_DbgPrint(MID_TRACE,("Success %x\n", Sock));
+ TRACE("Success %x\n", Sock);
return (SOCKET)Sock;
error:
- AFD_DbgPrint(MID_TRACE,("Ending %x\n", Status));
+ ERR("Ending %x\n", Status);
+
+ if( SharedData )
+ {
+ UnmapViewOfFile(SharedData);
+ NtClose((HANDLE)lpProtocolInfo->dwServiceFlags3);
+ }
+ else
+ {
+ if( Socket && Socket->SharedData )
+ HeapFree(GlobalHeap, 0, Socket->SharedData);
+ }
if( Socket )
HeapFree(GlobalHeap, 0, Socket);
+ if( EABuffer )
+ HeapFree(GlobalHeap, 0, EABuffer);
+
if( lpErrno )
*lpErrno = Status;
return INVALID_SOCKET;
}
+
+INT
+WSPAPI
+WSPDuplicateSocket(
+ IN SOCKET Handle,
+ IN DWORD dwProcessId,
+ OUT LPWSAPROTOCOL_INFOW lpProtocolInfo,
+ OUT LPINT lpErrno)
+{
+ HANDLE hProcess, hDuplicatedSharedData, hDuplicatedHandle;
+ PSOCKET_INFORMATION Socket;
+ PSOCK_SHARED_INFO pSharedData, pOldSharedData;
+ BOOL bDuplicated;
+
+ if (Handle == INVALID_SOCKET)
+ return MsafdReturnWithErrno(STATUS_INVALID_PARAMETER, lpErrno, 0, NULL);
+ Socket = GetSocketStructure(Handle);
+ if( !Socket )
+ {
+ if( lpErrno )
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+ if ( !(hProcess = OpenProcess(PROCESS_DUP_HANDLE, FALSE, dwProcessId)) )
+ return MsafdReturnWithErrno(STATUS_INVALID_PARAMETER, lpErrno, 0, NULL);
+
+ /* It is a not yet duplicated socket, so map the memory, copy the SharedData and free heap */
+ if( Socket->SharedDataHandle == INVALID_HANDLE_VALUE )
+ {
+ Socket->SharedDataHandle = CreateFileMapping(INVALID_HANDLE_VALUE,
+ NULL,
+ PAGE_READWRITE | SEC_COMMIT,
+ 0,
+ (sizeof(SOCK_SHARED_INFO) + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1),
+ NULL);
+ if( Socket->SharedDataHandle == INVALID_HANDLE_VALUE )
+ return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+ pSharedData = MapViewOfFile(Socket->SharedDataHandle,
+ FILE_MAP_ALL_ACCESS,
+ 0,
+ 0,
+ sizeof(SOCK_SHARED_INFO));
+
+ RtlCopyMemory(pSharedData, Socket->SharedData, sizeof(SOCK_SHARED_INFO));
+ pOldSharedData = Socket->SharedData;
+ Socket->SharedData = pSharedData;
+ HeapFree(GlobalHeap, 0, pOldSharedData);
+ }
+ /* Duplicate the handles for the new process */
+ bDuplicated = DuplicateHandle(GetCurrentProcess(),
+ Socket->SharedDataHandle,
+ hProcess,
+ (LPHANDLE)&hDuplicatedSharedData,
+ 0,
+ FALSE,
+ DUPLICATE_SAME_ACCESS);
+ if (!bDuplicated)
+ {
+ NtClose(hProcess);
+ return MsafdReturnWithErrno(STATUS_ACCESS_DENIED, lpErrno, 0, NULL);
+ }
+ bDuplicated = DuplicateHandle(GetCurrentProcess(),
+ (HANDLE)Socket->Handle,
+ hProcess,
+ (LPHANDLE)&hDuplicatedHandle,
+ 0,
+ FALSE,
+ DUPLICATE_SAME_ACCESS);
+ NtClose(hProcess);
+ if( !bDuplicated )
+ return MsafdReturnWithErrno(STATUS_ACCESS_DENIED, lpErrno, 0, NULL);
+
+
+ if (!lpProtocolInfo)
+ return MsafdReturnWithErrno(STATUS_ACCESS_VIOLATION, lpErrno, 0, NULL);
+
+ RtlCopyMemory(lpProtocolInfo, &Socket->ProtocolInfo, sizeof(*lpProtocolInfo));
+
+ lpProtocolInfo->iAddressFamily = Socket->SharedData->AddressFamily;
+ lpProtocolInfo->iProtocol = Socket->SharedData->Protocol;
+ lpProtocolInfo->iSocketType = Socket->SharedData->SocketType;
+ lpProtocolInfo->dwServiceFlags3 = (DWORD)hDuplicatedSharedData;
+ lpProtocolInfo->dwServiceFlags4 = (DWORD)hDuplicatedHandle;
+
+ if( lpErrno )
+ *lpErrno = NO_ERROR;
+
+ return NO_ERROR;
+}
+
INT
TranslateNtStatusError(NTSTATUS Status)
{
return NO_ERROR;
case STATUS_FILE_CLOSED:
+ return WSAECONNRESET;
+
case STATUS_END_OF_FILE:
return WSAESHUTDOWN;
case STATUS_BUFFER_TOO_SMALL:
case STATUS_BUFFER_OVERFLOW:
- DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
return WSAEMSGSIZE;
case STATUS_NO_MEMORY:
case STATUS_INSUFFICIENT_RESOURCES:
- DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
return WSAENOBUFS;
case STATUS_INVALID_CONNECTION:
- DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
+ return WSAENOTCONN;
+
+ case STATUS_PROTOCOL_NOT_SUPPORTED:
return WSAEAFNOSUPPORT;
case STATUS_INVALID_ADDRESS:
- DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
return WSAEADDRNOTAVAIL;
case STATUS_REMOTE_NOT_LISTENING:
- DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
+ case STATUS_REMOTE_DISCONNECT:
return WSAECONNREFUSED;
case STATUS_NETWORK_UNREACHABLE:
- DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
return WSAENETUNREACH;
case STATUS_INVALID_PARAMETER:
- DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
return WSAEINVAL;
case STATUS_CANCELLED:
- DbgPrint("MSAFD: STATUS_CANCELLED\n");
return WSA_OPERATION_ABORTED;
case STATUS_ADDRESS_ALREADY_EXISTS:
- DbgPrint("MSAFD: STATUS_ADDRESS_ALREADY_EXISTS\n");
return WSAEADDRINUSE;
case STATUS_LOCAL_DISCONNECT:
- DbgPrint("MSAFD: STATUS_LOCAL_DISCONNECT\n");
return WSAECONNABORTED;
- case STATUS_REMOTE_DISCONNECT:
- DbgPrint("MSAFD: STATUS_REMOTE_DISCONNECT\n");
- return WSAECONNRESET;
-
case STATUS_ACCESS_VIOLATION:
- DbgPrint("MSAFD: STATUS_ACCESS_VIOLATION\n");
return WSAEFAULT;
+ case STATUS_ACCESS_DENIED:
+ return WSAEACCES;
+
+ case STATUS_NOT_IMPLEMENTED:
+ return WSAEOPNOTSUPP;
+
default:
- DbgPrint("MSAFD: Unhandled NTSTATUS value: 0x%x\n", Status);
+ ERR("MSAFD: Unhandled NTSTATUS value: 0x%x\n", Status);
return WSAENETDOWN;
}
}
AFD_DISCONNECT_INFO DisconnectInfo;
SOCKET_STATE OldState;
LONG LingerWait = -1;
+ DWORD References;
+
+ /* Get the Socket Structure associate to this Socket*/
+ Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Create the Wait Event */
Status = NtCreateEvent(&SockEvent,
FALSE);
if(!NT_SUCCESS(Status))
- return SOCKET_ERROR;
-
- /* Get the Socket Structure associate to this Socket*/
- Socket = GetSocketStructure(Handle);
- if (!Socket)
{
- NtClose(SockEvent);
- *lpErrno = WSAENOTSOCK;
- return SOCKET_ERROR;
+ ERR("NtCreateEvent failed: 0x%08x", Status);
+ return SOCKET_ERROR;
}
if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
if (Status)
{
if (lpErrno) *lpErrno = Status;
+ ERR("WSHNotify failed. Error 0x%#x", Status);
NtClose(SockEvent);
return SOCKET_ERROR;
}
}
/* If a Close is already in Process, give up */
- if (Socket->SharedData.State == SocketClosed)
+ if (Socket->SharedData->State == SocketClosed)
{
+ WARN("Socket is closing.\n");
NtClose(SockEvent);
- *lpErrno = WSAENOTSOCK;
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
return SOCKET_ERROR;
}
+ /* Decrement reference count on SharedData */
+ References = InterlockedDecrement(&Socket->SharedData->RefCount);
+ if (References)
+ goto ok;
+
/* Set the state to close */
- OldState = Socket->SharedData.State;
- Socket->SharedData.State = SocketClosed;
+ OldState = Socket->SharedData->State;
+ Socket->SharedData->State = SocketClosed;
/* If SO_LINGER is ON and the Socket is connected, we need to disconnect */
/* FIXME: Should we do this on Datagram Sockets too? */
- if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
+ if ((OldState == SocketConnected) && (Socket->SharedData->LingerData.l_onoff))
{
ULONG SendsInProgress;
ULONG SleepWait;
/* We need to respect the timeout */
SleepWait = 100;
- LingerWait = Socket->SharedData.LingerData.l_linger * 1000;
+ LingerWait = Socket->SharedData->LingerData.l_linger * 1000;
/* Loop until no more sends are pending, within the timeout */
while (LingerWait)
AFD_INFO_SENDS_IN_PROGRESS,
NULL,
&SendsInProgress,
+ NULL,
+ NULL,
NULL))
{
/* Bail out if anything but NO_ERROR */
/*
* We have to execute a sleep, so it's kind of like
* a block. If the socket is Nonblock, we cannot
- * go on since asyncronous operation is expected
+ * go on since asynchronous operation is expected
* and we cannot offer it
*/
- if (Socket->SharedData.NonBlocking)
+ if (Socket->SharedData->NonBlocking)
{
+ WARN("Would block!\n");
NtClose(SockEvent);
- Socket->SharedData.State = OldState;
- *lpErrno = WSAEWOULDBLOCK;
+ Socket->SharedData->State = OldState;
+ if (lpErrno) *lpErrno = WSAEWOULDBLOCK;
return SOCKET_ERROR;
}
DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
DisconnectInfo.DisconnectType = LingerWait < 0 ? AFD_DISCONNECT_SEND : AFD_DISCONNECT_ABORT;
- if (((DisconnectInfo.DisconnectType & AFD_DISCONNECT_SEND) && (!Socket->SharedData.SendShutdown)) ||
- ((DisconnectInfo.DisconnectType & AFD_DISCONNECT_ABORT) && (!Socket->SharedData.ReceiveShutdown)))
+ if (((DisconnectInfo.DisconnectType & AFD_DISCONNECT_SEND) && (!Socket->SharedData->SendShutdown)) ||
+ ((DisconnectInfo.DisconnectType & AFD_DISCONNECT_ABORT) && (!Socket->SharedData->ReceiveShutdown)))
{
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Handle,
/* Cleanup Time! */
Socket->HelperContext = NULL;
- Socket->SharedData.AsyncDisabledEvents = -1;
+ Socket->SharedData->AsyncDisabledEvents = -1;
NtClose(Socket->TdiAddressHandle);
Socket->TdiAddressHandle = NULL;
NtClose(Socket->TdiConnectionHandle);
Socket->TdiConnectionHandle = NULL;
-
+ok:
EnterCriticalSection(&SocketListLock);
if (SocketListHead == Socket)
{
NtClose((HANDLE)Handle);
NtClose(SockEvent);
+ if( Socket->SharedDataHandle != INVALID_HANDLE_VALUE )
+ {
+ /* It is a duplicated socket, so unmap the memory */
+ UnmapViewOfFile(Socket->SharedData);
+ NtClose(Socket->SharedDataHandle);
+ Socket->SharedData = NULL;
+ }
+ if( !References && Socket->SharedData )
+ {
+ HeapFree(GlobalHeap, 0, Socket->SharedData);
+ }
HeapFree(GlobalHeap, 0, Socket);
return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
}
SOCKADDR_INFO SocketInfo;
HANDLE SockEvent;
- /* See below */
- BindData = HeapAlloc(GlobalHeap, 0, 0xA + SocketAddressLength);
- if (!BindData)
+ /* Get the Socket Structure associate to this Socket*/
+ Socket = GetSocketStructure(Handle);
+ if (!Socket)
{
- return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+ if (Socket->SharedData->State != SocketOpen)
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+ if (!SocketAddress || SocketAddressLength < Socket->SharedData->SizeOfLocalAddress)
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+
+ /* Get Address Information */
+ Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
+ SocketAddressLength,
+ &SocketInfo);
+
+ if (SocketInfo.AddressInfo == SockaddrAddressInfoBroadcast && !Socket->SharedData->Broadcast)
+ {
+ if (lpErrno) *lpErrno = WSAEADDRNOTAVAIL;
+ return SOCKET_ERROR;
}
Status = NtCreateEvent(&SockEvent,
if (!NT_SUCCESS(Status))
{
- HeapFree(GlobalHeap, 0, BindData);
return SOCKET_ERROR;
}
- /* Get the Socket Structure associate to this Socket*/
- Socket = GetSocketStructure(Handle);
- if (!Socket)
+ /* See below */
+ BindData = HeapAlloc(GlobalHeap, 0, 0xA + SocketAddressLength);
+ if (!BindData)
{
- HeapFree(GlobalHeap, 0, BindData);
- *lpErrno = WSAENOTSOCK;
- return SOCKET_ERROR;
+ return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
}
/* Set up Address in TDI Format */
SocketAddress->sa_data,
SocketAddressLength - sizeof(SocketAddress->sa_family));
- /* Get Address Information */
- Socket->HelperData->WSHGetSockaddrType ((PSOCKADDR)SocketAddress,
- SocketAddressLength,
- &SocketInfo);
-
/* Set the Share Type */
- if (Socket->SharedData.ExclusiveAddressUse)
+ if (Socket->SharedData->ExclusiveAddressUse)
{
BindData->ShareType = AFD_SHARE_EXCLUSIVE;
}
{
BindData->ShareType = AFD_SHARE_WILDCARD;
}
- else if (Socket->SharedData.ReuseAddresses)
+ else if (Socket->SharedData->ReuseAddresses)
{
BindData->ShareType = AFD_SHARE_REUSE;
}
&IOSB,
IOCTL_AFD_BIND,
BindData,
- 0xA + Socket->SharedData.SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
+ 0xA + Socket->SharedData->SizeOfLocalAddress, /* Can't figure out a way to calculate this in C*/
BindData,
- 0xA + Socket->SharedData.SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
+ 0xA + Socket->SharedData->SizeOfLocalAddress); /* Can't figure out a way to calculate this C */
/* Wait for return */
if (Status == STATUS_PENDING)
NtClose( SockEvent );
HeapFree(GlobalHeap, 0, BindData);
+ Socket->SharedData->SocketLastError = TranslateNtStatusError(Status);
if (Status != STATUS_SUCCESS)
return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
/* Set up Socket Data */
- Socket->SharedData.State = SocketBound;
+ Socket->SharedData->State = SocketBound;
Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
if (Socket->HelperEvents & WSH_NOTIFY_BIND)
Socket = GetSocketStructure(Handle);
if (!Socket)
{
- *lpErrno = WSAENOTSOCK;
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
return SOCKET_ERROR;
}
- if (Socket->SharedData.Listening)
- return 0;
+ if (Socket->SharedData->Listening)
+ return NO_ERROR;
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
FALSE);
if( !NT_SUCCESS(Status) )
- return -1;
+ return SOCKET_ERROR;
/* Set Up Listen Structure */
ListenData.UseSAN = FALSE;
- ListenData.UseDelayedAcceptance = Socket->SharedData.UseDelayedAcceptance;
+ ListenData.UseDelayedAcceptance = Socket->SharedData->UseDelayedAcceptance;
ListenData.Backlog = Backlog;
/* Send IOCTL */
NtClose( SockEvent );
+ Socket->SharedData->SocketLastError = TranslateNtStatusError(Status);
if (Status != STATUS_SUCCESS)
return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
/* Set to Listening */
- Socket->SharedData.Listening = TRUE;
+ Socket->SharedData->Listening = TRUE;
if (Socket->HelperEvents & WSH_NOTIFY_LISTEN)
{
PAFD_POLL_INFO PollInfo;
NTSTATUS Status;
ULONG HandleCount;
- LONG OutCount = 0;
ULONG PollBufferSize;
PVOID PollBuffer;
ULONG i, j = 0, x;
HANDLE SockEvent;
- BOOL HandleCounted;
LARGE_INTEGER Timeout;
+ PSOCKET_INFORMATION Socket;
+ SOCKET Handle;
+ ULONG Events;
+ fd_set selectfds;
/* Find out how many sockets we have, and how large the buffer needs
* to be */
+ FD_ZERO(&selectfds);
+ if (readfds != NULL)
+ {
+ for (i = 0; i < readfds->fd_count; i++)
+ {
+ FD_SET(readfds->fd_array[i], &selectfds);
+ }
+ }
+ if (writefds != NULL)
+ {
+ for (i = 0; i < writefds->fd_count; i++)
+ {
+ FD_SET(writefds->fd_array[i], &selectfds);
+ }
+ }
+ if (exceptfds != NULL)
+ {
+ for (i = 0; i < exceptfds->fd_count; i++)
+ {
+ FD_SET(exceptfds->fd_array[i], &selectfds);
+ }
+ }
- HandleCount = ( readfds ? readfds->fd_count : 0 ) +
- ( writefds ? writefds->fd_count : 0 ) +
- ( exceptfds ? exceptfds->fd_count : 0 );
+ HandleCount = selectfds.fd_count;
if ( HandleCount == 0 )
{
- AFD_DbgPrint(MAX_TRACE,("HandleCount: %u. Return SOCKET_ERROR\n",
- HandleCount));
+ WARN("No handles! Returning SOCKET_ERROR\n", HandleCount);
if (lpErrno) *lpErrno = WSAEINVAL;
return SOCKET_ERROR;
}
PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
- AFD_DbgPrint(MID_TRACE,("HandleCount: %u BufferSize: %u\n",
- HandleCount, PollBufferSize));
+ TRACE("HandleCount: %u BufferSize: %u\n", HandleCount, PollBufferSize);
/* Convert Timeout to NT Format */
if (timeout == NULL)
{
Timeout.u.LowPart = -1;
Timeout.u.HighPart = 0x7FFFFFFF;
- AFD_DbgPrint(MAX_TRACE,("Infinite timeout\n"));
+ TRACE("Infinite timeout\n");
}
else
{
if (Timeout.QuadPart > 0)
{
if (lpErrno) *lpErrno = WSAEINVAL;
- return SOCKET_ERROR;
+ return SOCKET_ERROR;
}
- AFD_DbgPrint(MAX_TRACE,("Timeout: Orig %d.%06d kernel %d\n",
+ TRACE("Timeout: Orig %d.%06d kernel %d\n",
timeout->tv_sec, timeout->tv_usec,
- Timeout.u.LowPart));
+ Timeout.u.LowPart);
}
Status = NtCreateEvent(&SockEvent,
1,
FALSE);
- if( !NT_SUCCESS(Status) )
+ if(!NT_SUCCESS(Status))
+ {
+ if (lpErrno)
+ *lpErrno = WSAEFAULT;
+
+ ERR("NtCreateEvent failed, 0x%08x\n", Status);
return SOCKET_ERROR;
+ }
/* Allocate */
PollBuffer = HeapAlloc(GlobalHeap, 0, PollBufferSize);
PollInfo->Exclusive = FALSE;
PollInfo->Timeout = Timeout;
+ for (i = 0; i < selectfds.fd_count; i++)
+ {
+ PollInfo->Handles[i].Handle = selectfds.fd_array[i];
+ }
if (readfds != NULL) {
- for (i = 0; i < readfds->fd_count; i++, j++)
+ for (i = 0; i < readfds->fd_count; i++)
{
- PollInfo->Handles[j].Handle = readfds->fd_array[i];
- PollInfo->Handles[j].Events = AFD_EVENT_RECEIVE |
- AFD_EVENT_DISCONNECT |
- AFD_EVENT_ABORT |
- AFD_EVENT_CLOSE |
- AFD_EVENT_ACCEPT;
+ for (j = 0; j < HandleCount; j++)
+ {
+ if (PollInfo->Handles[j].Handle == readfds->fd_array[i])
+ break;
+ }
+ if (j >= HandleCount)
+ {
+ ERR("Error while counting readfds %ld > %ld\n", j, HandleCount);
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
+ Socket = GetSocketStructure(readfds->fd_array[i]);
+ if (!Socket)
+ {
+ ERR("Invalid socket handle provided in readfds %d\n", readfds->fd_array[i]);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
+ PollInfo->Handles[j].Events |= AFD_EVENT_RECEIVE |
+ AFD_EVENT_DISCONNECT |
+ AFD_EVENT_ABORT |
+ AFD_EVENT_CLOSE |
+ AFD_EVENT_ACCEPT;
+ //if (Socket->SharedData->OobInline != 0)
+ // PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
}
}
if (writefds != NULL)
{
- for (i = 0; i < writefds->fd_count; i++, j++)
+ for (i = 0; i < writefds->fd_count; i++)
{
+ for (j = 0; j < HandleCount; j++)
+ {
+ if (PollInfo->Handles[j].Handle == writefds->fd_array[i])
+ break;
+ }
+ if (j >= HandleCount)
+ {
+ ERR("Error while counting writefds %ld > %ld\n", j, HandleCount);
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
+ Socket = GetSocketStructure(writefds->fd_array[i]);
+ if (!Socket)
+ {
+ ERR("Invalid socket handle provided in writefds %d\n", writefds->fd_array[i]);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
PollInfo->Handles[j].Handle = writefds->fd_array[i];
- PollInfo->Handles[j].Events = AFD_EVENT_SEND | AFD_EVENT_CONNECT;
+ PollInfo->Handles[j].Events |= AFD_EVENT_SEND;
+ if (Socket->SharedData->NonBlocking != 0)
+ PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT;
}
}
if (exceptfds != NULL)
{
- for (i = 0; i < exceptfds->fd_count; i++, j++)
+ for (i = 0; i < exceptfds->fd_count; i++)
{
+ for (j = 0; j < HandleCount; j++)
+ {
+ if (PollInfo->Handles[j].Handle == exceptfds->fd_array[i])
+ break;
+ }
+ if (j > HandleCount)
+ {
+ ERR("Error while counting exceptfds %ld > %ld\n", j, HandleCount);
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
+ Socket = GetSocketStructure(exceptfds->fd_array[i]);
+ if (!Socket)
+ {
+ TRACE("Invalid socket handle provided in exceptfds %d\n", exceptfds->fd_array[i]);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
PollInfo->Handles[j].Handle = exceptfds->fd_array[i];
- PollInfo->Handles[j].Events = AFD_EVENT_OOB_RECEIVE | AFD_EVENT_CONNECT_FAIL;
+ if (Socket->SharedData->OobInline == 0)
+ PollInfo->Handles[j].Events |= AFD_EVENT_OOB_RECEIVE;
+ if (Socket->SharedData->NonBlocking != 0)
+ PollInfo->Handles[j].Events |= AFD_EVENT_CONNECT_FAIL;
}
}
- PollInfo->HandleCount = j;
+ PollInfo->HandleCount = HandleCount;
PollBufferSize = FIELD_OFFSET(AFD_POLL_INFO, Handles) + PollInfo->HandleCount * sizeof(AFD_HANDLE);
/* Send IOCTL */
PollInfo,
PollBufferSize);
- AFD_DbgPrint(MID_TRACE,("DeviceIoControlFile => %x\n", Status));
+ TRACE("DeviceIoControlFile => %x\n", Status);
- /* Wait for Completition */
+ /* Wait for Completion */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
/* Return in FDSET Format */
for (i = 0; i < HandleCount; i++)
{
- HandleCounted = FALSE;
+ Events = PollInfo->Handles[i].Events;
+ Handle = PollInfo->Handles[i].Handle;
for(x = 1; x; x<<=1)
{
- switch (PollInfo->Handles[i].Events & x)
+ Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ TRACE("Invalid socket handle found %d\n", Handle);
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ HeapFree(GlobalHeap, 0, PollBuffer);
+ NtClose(SockEvent);
+ return SOCKET_ERROR;
+ }
+ switch (Events & x)
{
case AFD_EVENT_RECEIVE:
case AFD_EVENT_DISCONNECT:
case AFD_EVENT_ABORT:
case AFD_EVENT_ACCEPT:
case AFD_EVENT_CLOSE:
- AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
- PollInfo->Handles[i].Events,
- PollInfo->Handles[i].Handle));
- if (! HandleCounted)
- {
- OutCount++;
- HandleCounted = TRUE;
- }
+ TRACE("Event %x on handle %x\n",
+ Events,
+ Handle);
+ if ((Events & x) == AFD_EVENT_DISCONNECT || (Events & x) == AFD_EVENT_CLOSE)
+ Socket->SharedData->SocketLastError = WSAECONNRESET;
+ if ((Events & x) == AFD_EVENT_ABORT)
+ Socket->SharedData->SocketLastError = WSAECONNABORTED;
if( readfds )
- FD_SET(PollInfo->Handles[i].Handle, readfds);
+ FD_SET(Handle, readfds);
break;
case AFD_EVENT_SEND:
+ TRACE("Event %x on handle %x\n",
+ Events,
+ Handle);
+ if (writefds)
+ FD_SET(Handle, writefds);
+ break;
case AFD_EVENT_CONNECT:
- AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
- PollInfo->Handles[i].Events,
- PollInfo->Handles[i].Handle));
- if (! HandleCounted)
- {
- OutCount++;
- HandleCounted = TRUE;
- }
- if( writefds )
- FD_SET(PollInfo->Handles[i].Handle, writefds);
+ TRACE("Event %x on handle %x\n",
+ Events,
+ Handle);
+ if( writefds && Socket->SharedData->NonBlocking != 0 )
+ FD_SET(Handle, writefds);
break;
case AFD_EVENT_OOB_RECEIVE:
+ TRACE("Event %x on handle %x\n",
+ Events,
+ Handle);
+ if( readfds && Socket->SharedData->OobInline != 0 )
+ FD_SET(Handle, readfds);
+ if( exceptfds && Socket->SharedData->OobInline == 0 )
+ FD_SET(Handle, exceptfds);
+ break;
case AFD_EVENT_CONNECT_FAIL:
- AFD_DbgPrint(MID_TRACE,("Event %x on handle %x\n",
- PollInfo->Handles[i].Events,
- PollInfo->Handles[i].Handle));
- if (! HandleCounted)
- {
- OutCount++;
- HandleCounted = TRUE;
- }
- if( exceptfds )
- FD_SET(PollInfo->Handles[i].Handle, exceptfds);
+ TRACE("Event %x on handle %x\n",
+ Events,
+ Handle);
+ if( exceptfds && Socket->SharedData->NonBlocking != 0 )
+ FD_SET(Handle, exceptfds);
break;
}
}
*lpErrno = WSAEINVAL;
break;
}
- AFD_DbgPrint(MID_TRACE,("*lpErrno = %x\n", *lpErrno));
+ TRACE("*lpErrno = %x\n", *lpErrno);
}
- AFD_DbgPrint(MID_TRACE,("%d events\n", OutCount));
+ HandleCount = (readfds ? readfds->fd_count : 0) +
+ (writefds && writefds != readfds ? writefds->fd_count : 0) +
+ (exceptfds && exceptfds != readfds && exceptfds != writefds ? exceptfds->fd_count : 0);
+
+ TRACE("%d events\n", HandleCount);
+
+ return HandleCount;
+}
+
+DWORD
+GetCurrentTimeInSeconds(VOID)
+{
+ SYSTEMTIME st1970 = { 1970, 1, 0, 1, 0, 0, 0, 0 };
+ union
+ {
+ FILETIME ft;
+ ULONGLONG ll;
+ } u1970, Time;
- return OutCount;
+ GetSystemTimeAsFileTime(&Time.ft);
+ SystemTimeToFileTime(&st1970, &u1970.ft);
+ return (DWORD)((Time.ll - u1970.ll) / 10000000ULL);
}
SOCKET
PSOCKADDR RemoteAddress = NULL;
GROUP GroupID = 0;
ULONG CallBack;
- WSAPROTOCOL_INFOW ProtocolInfo;
SOCKET AcceptSocket;
PSOCKET_INFORMATION AcceptSocketInfo;
UCHAR ReceiveBuffer[0x1A];
HANDLE SockEvent;
+ /* Get the Socket Structure associate to this Socket*/
+ Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+ if (!Socket->SharedData->Listening)
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+ if ((SocketAddress && !SocketAddressLength) ||
+ (SocketAddressLength && !SocketAddress) ||
+ (SocketAddressLength && *SocketAddressLength < sizeof(SOCKADDR)))
+ {
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ return INVALID_SOCKET;
+ }
+
Status = NtCreateEvent(&SockEvent,
EVENT_ALL_ACCESS,
NULL,
if( !NT_SUCCESS(Status) )
{
- MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return SOCKET_ERROR;
}
/* Dynamic Structure...ugh */
ListenReceiveData = (PAFD_RECEIVED_ACCEPT_DATA)ReceiveBuffer;
- /* Get the Socket Structure associate to this Socket*/
- Socket = GetSocketStructure(Handle);
- if (!Socket)
- {
- NtClose(SockEvent);
- *lpErrno = WSAENOTSOCK;
- return INVALID_SOCKET;
- }
-
/* If this is non-blocking, make sure there's something for us to accept */
FD_ZERO(&ReadSet);
FD_SET(Socket->Handle, &ReadSet);
if (WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, lpErrno) == SOCKET_ERROR)
{
NtClose(SockEvent);
- return INVALID_SOCKET;
+ return SOCKET_ERROR;
}
if (ReadSet.fd_array[0] != Socket->Handle)
{
NtClose(SockEvent);
- *lpErrno = WSAEWOULDBLOCK;
- return INVALID_SOCKET;
+ if (lpErrno) *lpErrno = WSAEWOULDBLOCK;
+ return SOCKET_ERROR;
}
/* Send IOCTL */
if (!NT_SUCCESS(Status))
{
NtClose( SockEvent );
- MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
}
if (lpfnCondition != NULL)
{
- if ((Socket->SharedData.ServiceFlags1 & XP1_CONNECT_DATA) != 0)
+ if ((Socket->SharedData->ServiceFlags1 & XP1_CONNECT_DATA) != 0)
{
/* Find out how much data is pending */
PendingAcceptData.SequenceNumber = ListenReceiveData->SequenceNumber;
if (!NT_SUCCESS(Status))
{
NtClose( SockEvent );
- MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
}
/* How much data to allocate */
PendingData = HeapAlloc(GlobalHeap, 0, PendingDataLength);
if (!PendingData)
{
- MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
}
/* We want the data now */
if (!NT_SUCCESS(Status))
{
NtClose( SockEvent );
- MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
}
}
}
- if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
+ if ((Socket->SharedData->ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
{
/* I don't support this yet */
}
/* Build Callee ID */
CalleeID.buf = (PVOID)Socket->LocalAddress;
- CalleeID.len = Socket->SharedData.SizeOfLocalAddress;
+ CalleeID.len = Socket->SharedData->SizeOfLocalAddress;
RemoteAddress = HeapAlloc(GlobalHeap, 0, sizeof(*RemoteAddress));
if (!RemoteAddress)
{
- MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno(STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL);
}
/* Set up Address in SOCKADDR Format */
CallerData.len = PendingDataLength;
/* Check if socket supports Conditional Accept */
- if (Socket->SharedData.UseDelayedAcceptance != 0)
+ if (Socket->SharedData->UseDelayedAcceptance != 0)
{
/* Allocate Buffer for Callee Data */
CalleeDataBuffer = HeapAlloc(GlobalHeap, 0, 4096);
if (!CalleeDataBuffer) {
- MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( STATUS_INSUFFICIENT_RESOURCES, lpErrno, 0, NULL );
}
CalleeData.buf = CalleeDataBuffer;
CalleeData.len = 4096;
if (CallBack == CF_ACCEPT)
{
- if ((Socket->SharedData.ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
+ if ((Socket->SharedData->ServiceFlags1 & XP1_QOS_SUPPORTED) != 0)
{
/* I don't support this yet */
}
if (!NT_SUCCESS(Status))
{
- MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
}
if (CallBack == CF_REJECT )
{
- *lpErrno = WSAECONNREFUSED;
- return INVALID_SOCKET;
+ if (lpErrno) *lpErrno = WSAECONNREFUSED;
+ return SOCKET_ERROR;
}
else
{
- *lpErrno = WSAECONNREFUSED;
- return INVALID_SOCKET;
+ if (lpErrno) *lpErrno = WSAECONNREFUSED;
+ return SOCKET_ERROR;
}
}
}
/* Create a new Socket */
- ProtocolInfo.dwCatalogEntryId = Socket->SharedData.CatalogEntryId;
- ProtocolInfo.dwServiceFlags1 = Socket->SharedData.ServiceFlags1;
- ProtocolInfo.dwProviderFlags = Socket->SharedData.ProviderFlags;
-
- AcceptSocket = WSPSocket (Socket->SharedData.AddressFamily,
- Socket->SharedData.SocketType,
- Socket->SharedData.Protocol,
- &ProtocolInfo,
+ AcceptSocket = WSPSocket (Socket->SharedData->AddressFamily,
+ Socket->SharedData->SocketType,
+ Socket->SharedData->Protocol,
+ &Socket->ProtocolInfo,
GroupID,
- Socket->SharedData.CreateFlags,
+ Socket->SharedData->CreateFlags,
lpErrno);
if (AcceptSocket == INVALID_SOCKET)
- return INVALID_SOCKET;
+ return SOCKET_ERROR;
/* Set up the Accept Structure */
AcceptData.ListenHandle = (HANDLE)AcceptSocket;
Status = IOSB.Status;
}
+ Socket->SharedData->SocketLastError = TranslateNtStatusError(Status);
if (!NT_SUCCESS(Status))
{
NtClose(SockEvent);
WSPCloseSocket( AcceptSocket, lpErrno );
- MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
}
AcceptSocketInfo = GetSocketStructure(AcceptSocket);
{
NtClose(SockEvent);
WSPCloseSocket( AcceptSocket, lpErrno );
- MsafdReturnWithErrno( STATUS_INVALID_CONNECTION, lpErrno, 0, NULL );
- return INVALID_SOCKET;
+ return MsafdReturnWithErrno( STATUS_PROTOCOL_NOT_SUPPORTED, lpErrno, 0, NULL );
}
- AcceptSocketInfo->SharedData.State = SocketConnected;
+ AcceptSocketInfo->SharedData->State = SocketConnected;
+ AcceptSocketInfo->SharedData->ConnectTime = GetCurrentTimeInSeconds();
/* Return Address in SOCKADDR FORMAT */
if( SocketAddress )
&ListenReceiveData->Address.Address[0].AddressType,
sizeof(*RemoteAddress));
if( SocketAddressLength )
- *SocketAddressLength = ListenReceiveData->Address.Address[0].AddressLength;
+ *SocketAddressLength = sizeof(*RemoteAddress);
}
NtClose( SockEvent );
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Socket, FD_ACCEPT);
- AFD_DbgPrint(MID_TRACE,("Socket %x\n", AcceptSocket));
+ TRACE("Socket %x\n", AcceptSocket);
if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_ACCEPT))
{
if (Status)
{
if (lpErrno) *lpErrno = Status;
- return INVALID_SOCKET;
+ return SOCKET_ERROR;
}
}
- *lpErrno = 0;
+ if (lpErrno) *lpErrno = NO_ERROR;
/* Return Socket */
return AcceptSocket;
HANDLE SockEvent;
int SocketDataLength;
- Status = NtCreateEvent(&SockEvent,
- EVENT_ALL_ACCESS,
- NULL,
- 1,
- FALSE);
-
- if (!NT_SUCCESS(Status))
- return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
-
- AFD_DbgPrint(MID_TRACE,("Called\n"));
+ TRACE("Called (%lx) %lx:%d\n", Handle, ((const struct sockaddr_in *)SocketAddress)->sin_addr, ((const struct sockaddr_in *)SocketAddress)->sin_port);
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
if (!Socket)
{
- NtClose(SockEvent);
if (lpErrno) *lpErrno = WSAENOTSOCK;
return SOCKET_ERROR;
}
+ Status = NtCreateEvent(&SockEvent,
+ EVENT_ALL_ACCESS,
+ NULL,
+ 1,
+ FALSE);
+
+ if (!NT_SUCCESS(Status))
+ return SOCKET_ERROR;
+
/* Bind us First */
- if (Socket->SharedData.State == SocketOpen)
+ if (Socket->SharedData->State == SocketOpen)
{
/* Get the Wildcard Address */
BindAddressLength = Socket->HelperData->MaxWSAddressLength;
* at the end of this function right after the Async Thread disables it.
* This should only happen at the *next* WSPConnect
*/
- if (Socket->SharedData.AsyncEvents & FD_CONNECT)
+ if (Socket->SharedData->AsyncEvents & FD_CONNECT)
{
- Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
+ Socket->SharedData->AsyncDisabledEvents |= FD_CONNECT | FD_WRITE;
}
/* Tell AFD that we want Connection Data back, have it allocate a buffer */
ConnectInfo->Unknown = 0;
/* FIXME: Handle Async Connect */
- if (Socket->SharedData.NonBlocking)
+ if (Socket->SharedData->NonBlocking)
{
- AFD_DbgPrint(MIN_TRACE, ("Async Connect UNIMPLEMENTED!\n"));
+ ERR("Async Connect UNIMPLEMENTED!\n");
}
/* Send IOCTL */
Status = IOSB.Status;
}
+ Socket->SharedData->SocketLastError = TranslateNtStatusError(Status);
if (Status != STATUS_SUCCESS)
goto notify;
- Socket->SharedData.State = SocketConnected;
+ Socket->SharedData->State = SocketConnected;
Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
+ Socket->SharedData->ConnectTime = GetCurrentTimeInSeconds();
/* Get any pending connect data */
if (lpCalleeData != NULL)
}
}
- AFD_DbgPrint(MID_TRACE,("Ending\n"));
+ TRACE("Ending %lx\n", IOSB.Status);
notify:
if (ConnectInfo) HeapFree(GetProcessHeap(), 0, ConnectInfo);
NTSTATUS Status;
HANDLE SockEvent;
- Status = NtCreateEvent(&SockEvent,
- EVENT_ALL_ACCESS,
- NULL,
- 1,
- FALSE);
-
- if( !NT_SUCCESS(Status) )
- return -1;
-
- AFD_DbgPrint(MID_TRACE,("Called\n"));
+ TRACE("Called\n");
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
if (!Socket)
{
- NtClose(SockEvent);
- *lpErrno = WSAENOTSOCK;
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
return SOCKET_ERROR;
}
+ Status = NtCreateEvent(&SockEvent,
+ EVENT_ALL_ACCESS,
+ NULL,
+ 1,
+ FALSE);
+
+ if( !NT_SUCCESS(Status) )
+ return SOCKET_ERROR;
+
/* Set AFD Disconnect Type */
switch (HowTo)
{
case SD_RECEIVE:
DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV;
- Socket->SharedData.ReceiveShutdown = TRUE;
+ Socket->SharedData->ReceiveShutdown = TRUE;
break;
case SD_SEND:
DisconnectInfo.DisconnectType= AFD_DISCONNECT_SEND;
- Socket->SharedData.SendShutdown = TRUE;
+ Socket->SharedData->SendShutdown = TRUE;
break;
case SD_BOTH:
DisconnectInfo.DisconnectType = AFD_DISCONNECT_RECV | AFD_DISCONNECT_SEND;
- Socket->SharedData.ReceiveShutdown = TRUE;
- Socket->SharedData.SendShutdown = TRUE;
+ Socket->SharedData->ReceiveShutdown = TRUE;
+ Socket->SharedData->SendShutdown = TRUE;
break;
}
Status = IOSB.Status;
}
- AFD_DbgPrint(MID_TRACE,("Ending\n"));
+ TRACE("Ending\n");
NtClose( SockEvent );
+ Socket->SharedData->SocketLastError = TranslateNtStatusError(Status);
return MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
}
{
IO_STATUS_BLOCK IOSB;
ULONG TdiAddressSize;
- PTDI_ADDRESS_INFO TdiAddress;
+ PTDI_ADDRESS_INFO TdiAddress;
PTRANSPORT_ADDRESS SocketAddress;
PSOCKET_INFORMATION Socket = NULL;
NTSTATUS Status;
HANDLE SockEvent;
- Status = NtCreateEvent(&SockEvent,
- EVENT_ALL_ACCESS,
- NULL,
- 1,
- FALSE);
-
- if( !NT_SUCCESS(Status) )
- return SOCKET_ERROR;
-
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
if (!Socket)
{
- NtClose(SockEvent);
- *lpErrno = WSAENOTSOCK;
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
return SOCKET_ERROR;
}
if (!Name || !NameLength)
{
- NtClose(SockEvent);
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
+ Status = NtCreateEvent(&SockEvent,
+ EVENT_ALL_ACCESS,
+ NULL,
+ 1,
+ FALSE);
+
+ if( !NT_SUCCESS(Status) )
+ return SOCKET_ERROR;
+
/* Allocate a buffer for the address */
TdiAddressSize =
- sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfLocalAddress;
+ sizeof(TRANSPORT_ADDRESS) + Socket->SharedData->SizeOfLocalAddress;
TdiAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
if ( TdiAddress == NULL )
{
NtClose( SockEvent );
- *lpErrno = WSAENOBUFS;
+ if (lpErrno) *lpErrno = WSAENOBUFS;
return SOCKET_ERROR;
}
if (NT_SUCCESS(Status))
{
- if (*NameLength >= Socket->SharedData.SizeOfLocalAddress)
+ if (*NameLength >= Socket->SharedData->SizeOfLocalAddress)
{
Name->sa_family = SocketAddress->Address[0].AddressType;
RtlCopyMemory (Name->sa_data,
SocketAddress->Address[0].Address,
SocketAddress->Address[0].AddressLength);
- *NameLength = Socket->SharedData.SizeOfLocalAddress;
- AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
+ *NameLength = Socket->SharedData->SizeOfLocalAddress;
+ TRACE("NameLength %d Address: %x Port %x\n",
*NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
- ((struct sockaddr_in *)Name)->sin_port));
+ ((struct sockaddr_in *)Name)->sin_port);
HeapFree(GlobalHeap, 0, TdiAddress);
return 0;
}
else
{
HeapFree(GlobalHeap, 0, TdiAddress);
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
}
NTSTATUS Status;
HANDLE SockEvent;
- Status = NtCreateEvent(&SockEvent,
- EVENT_ALL_ACCESS,
- NULL,
- 1,
- FALSE);
-
- if( !NT_SUCCESS(Status) )
- return SOCKET_ERROR;
-
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(s);
if (!Socket)
{
- NtClose(SockEvent);
- *lpErrno = WSAENOTSOCK;
- return SOCKET_ERROR;
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
}
- if (Socket->SharedData.State != SocketConnected)
+ if (Socket->SharedData->State != SocketConnected)
{
- NtClose(SockEvent);
- *lpErrno = WSAENOTCONN;
+ if (lpErrno) *lpErrno = WSAENOTCONN;
return SOCKET_ERROR;
}
if (!Name || !NameLength)
{
- NtClose(SockEvent);
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
+
+ Status = NtCreateEvent(&SockEvent,
+ EVENT_ALL_ACCESS,
+ NULL,
+ 1,
+ FALSE);
+
+ if( !NT_SUCCESS(Status) )
return SOCKET_ERROR;
- }
/* Allocate a buffer for the address */
- TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfRemoteAddress;
+ TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + Socket->SharedData->SizeOfRemoteAddress;
SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
if ( SocketAddress == NULL )
{
NtClose( SockEvent );
- *lpErrno = WSAENOBUFS;
+ if (lpErrno) *lpErrno = WSAENOBUFS;
return SOCKET_ERROR;
}
if (NT_SUCCESS(Status))
{
- if (*NameLength >= Socket->SharedData.SizeOfRemoteAddress)
+ if (*NameLength >= Socket->SharedData->SizeOfRemoteAddress)
{
Name->sa_family = SocketAddress->Address[0].AddressType;
RtlCopyMemory (Name->sa_data,
SocketAddress->Address[0].Address,
SocketAddress->Address[0].AddressLength);
- *NameLength = Socket->SharedData.SizeOfRemoteAddress;
- AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
+ *NameLength = Socket->SharedData->SizeOfRemoteAddress;
+ TRACE("NameLength %d Address: %x Port %x\n",
*NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
- ((struct sockaddr_in *)Name)->sin_port));
+ ((struct sockaddr_in *)Name)->sin_port);
HeapFree(GlobalHeap, 0, SocketAddress);
return 0;
}
else
{
HeapFree(GlobalHeap, 0, SocketAddress);
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
}
OUT LPINT lpErrno)
{
PSOCKET_INFORMATION Socket = NULL;
- BOOLEAN NeedsCompletion;
+ BOOL NeedsCompletion = lpOverlapped != NULL;
BOOLEAN NonBlocking;
-
- if (!lpcbBytesReturned)
- {
- *lpErrno = WSAEFAULT;
- return SOCKET_ERROR;
- }
+ INT Errno = NO_ERROR, Ret = SOCKET_ERROR;
+ DWORD cbRet = 0;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
if (!Socket)
{
- *lpErrno = WSAENOTSOCK;
- return SOCKET_ERROR;
+ if(lpErrno)
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
}
- *lpcbBytesReturned = 0;
+ if (!lpcbBytesReturned && !lpOverlapped)
+ {
+ if(lpErrno)
+ *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
switch( dwIoControlCode )
{
case FIONBIO:
if( cbInBuffer < sizeof(INT) || IS_INTRESOURCE(lpvInBuffer) )
{
- *lpErrno = WSAEFAULT;
- return SOCKET_ERROR;
+ Errno = WSAEFAULT;
+ break;
}
NonBlocking = *((PULONG)lpvInBuffer) ? TRUE : FALSE;
- Socket->SharedData.NonBlocking = NonBlocking ? 1 : 0;
- *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, &NonBlocking, NULL, NULL);
- if (*lpErrno != NO_ERROR)
- return SOCKET_ERROR;
- else
- return NO_ERROR;
+ /* Don't allow to go in blocking mode if WSPAsyncSelect or WSPEventSelect is pending */
+ if (!NonBlocking)
+ {
+ /* If there is an WSPAsyncSelect pending, fail with WSAEINVAL */
+ if (Socket->SharedData->AsyncEvents & (~Socket->SharedData->AsyncDisabledEvents))
+ {
+ Errno = WSAEINVAL;
+ break;
+ }
+ /* If there is an WSPEventSelect pending, fail with WSAEINVAL */
+ if (Socket->NetworkEvents)
+ {
+ Errno = WSAEINVAL;
+ break;
+ }
+ }
+ Socket->SharedData->NonBlocking = NonBlocking ? 1 : 0;
+ NeedsCompletion = FALSE;
+ Errno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, &NonBlocking, NULL, NULL, lpOverlapped, lpCompletionRoutine);
+ if (Errno == NO_ERROR)
+ Ret = NO_ERROR;
+ break;
case FIONREAD:
- if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
+ if (IS_INTRESOURCE(lpvOutBuffer) || cbOutBuffer == 0)
{
- *lpErrno = WSAEFAULT;
- return SOCKET_ERROR;
+ cbRet = sizeof(ULONG);
+ Errno = WSAEFAULT;
+ break;
+ }
+ if (cbOutBuffer < sizeof(ULONG))
+ {
+ Errno = WSAEINVAL;
+ break;
+ }
+ NeedsCompletion = FALSE;
+ Errno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, NULL, (PULONG)lpvOutBuffer, NULL, lpOverlapped, lpCompletionRoutine);
+ if (Errno == NO_ERROR)
+ {
+ cbRet = sizeof(ULONG);
+ Ret = NO_ERROR;
}
- *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, NULL, (PULONG)lpvOutBuffer, NULL);
- if (*lpErrno != NO_ERROR)
- return SOCKET_ERROR;
- else
- {
- *lpcbBytesReturned = sizeof(ULONG);
- return NO_ERROR;
- }
+ break;
case SIOCATMARK:
- if (cbOutBuffer < sizeof(BOOL) || IS_INTRESOURCE(lpvOutBuffer))
+ if (IS_INTRESOURCE(lpvOutBuffer) || cbOutBuffer == 0)
{
- *lpErrno = WSAEFAULT;
- return SOCKET_ERROR;
+ cbRet = sizeof(BOOL);
+ Errno = WSAEFAULT;
+ break;
+ }
+ if (cbOutBuffer < sizeof(BOOL))
+ {
+ Errno = WSAEINVAL;
+ break;
+ }
+ if (Socket->SharedData->SocketType != SOCK_STREAM)
+ {
+ Errno = WSAEINVAL;
+ break;
}
- /* FIXME: Return false for now */
- *(BOOL*)lpvOutBuffer = FALSE;
+ /* FIXME: Return false if OOBINLINE is true for now
+ We should MSG_PEEK|MSG_OOB check with driver
+ */
+ *(BOOL*)lpvOutBuffer = !Socket->SharedData->OobInline;
- *lpcbBytesReturned = sizeof(BOOL);
- *lpErrno = NO_ERROR;
- return NO_ERROR;
+ cbRet = sizeof(BOOL);
+ Errno = NO_ERROR;
+ Ret = NO_ERROR;
+ break;
case SIO_GET_EXTENSION_FUNCTION_POINTER:
- *lpErrno = WSAEINVAL;
- return SOCKET_ERROR;
+ Errno = WSAEINVAL;
+ break;
+ case SIO_ADDRESS_LIST_QUERY:
+ if (IS_INTRESOURCE(lpvOutBuffer) || cbOutBuffer == 0)
+ {
+ cbRet = sizeof(SOCKET_ADDRESS_LIST) + sizeof(Socket->SharedData->WSLocalAddress);
+ Errno = WSAEFAULT;
+ break;
+ }
+ if (cbOutBuffer < sizeof(INT))
+ {
+ Errno = WSAEINVAL;
+ break;
+ }
+
+ cbRet = sizeof(SOCKET_ADDRESS_LIST) + sizeof(Socket->SharedData->WSLocalAddress);
+
+ ((SOCKET_ADDRESS_LIST*)lpvOutBuffer)->iAddressCount = 1;
+
+ if (cbOutBuffer < (sizeof(SOCKET_ADDRESS_LIST) + sizeof(Socket->SharedData->WSLocalAddress)))
+ {
+ Errno = WSAEFAULT;
+ break;
+ }
+
+ ((SOCKET_ADDRESS_LIST*)lpvOutBuffer)->Address[0].iSockaddrLength = sizeof(Socket->SharedData->WSLocalAddress);
+ ((SOCKET_ADDRESS_LIST*)lpvOutBuffer)->Address[0].lpSockaddr = &Socket->SharedData->WSLocalAddress;
+
+ Errno = NO_ERROR;
+ Ret = NO_ERROR;
+ break;
default:
- *lpErrno = Socket->HelperData->WSHIoctl(Socket->HelperContext,
- Handle,
- Socket->TdiAddressHandle,
- Socket->TdiConnectionHandle,
- dwIoControlCode,
- lpvInBuffer,
- cbInBuffer,
- lpvOutBuffer,
- cbOutBuffer,
- lpcbBytesReturned,
- lpOverlapped,
- lpCompletionRoutine,
- (LPBOOL)&NeedsCompletion);
-
- if (*lpErrno != NO_ERROR)
- return SOCKET_ERROR;
- else
- return NO_ERROR;
+ Errno = Socket->HelperData->WSHIoctl(Socket->HelperContext,
+ Handle,
+ Socket->TdiAddressHandle,
+ Socket->TdiConnectionHandle,
+ dwIoControlCode,
+ lpvInBuffer,
+ cbInBuffer,
+ lpvOutBuffer,
+ cbOutBuffer,
+ &cbRet,
+ lpOverlapped,
+ lpCompletionRoutine,
+ &NeedsCompletion);
+
+ if (Errno == NO_ERROR)
+ Ret = NO_ERROR;
+ break;
+ }
+ if (lpOverlapped && NeedsCompletion)
+ {
+ lpOverlapped->Internal = Errno;
+ lpOverlapped->InternalHigh = cbRet;
+ if (lpCompletionRoutine != NULL)
+ {
+ lpCompletionRoutine(Errno, cbRet, lpOverlapped, 0);
+ }
+ if (lpOverlapped->hEvent)
+ SetEvent(lpOverlapped->hEvent);
+ if (!PostQueuedCompletionStatus((HANDLE)Handle, cbRet, 0, lpOverlapped))
+ {
+ ERR("PostQueuedCompletionStatus failed %d\n", GetLastError());
+ }
+ return NO_ERROR;
}
+ if (lpErrno)
+ *lpErrno = Errno;
+ if (lpcbBytesReturned)
+ *lpcbBytesReturned = cbRet;
+ return Ret;
}
PVOID Buffer;
INT BufferSize;
BOOL BoolBuffer;
- INT IntBuffer;
+ DWORD DwordBuffer;
+ INT Errno;
+
+ TRACE("Called\n");
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
if (Socket == NULL)
{
- *lpErrno = WSAENOTSOCK;
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+ if (!OptionLength || !OptionValue)
+ {
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- AFD_DbgPrint(MID_TRACE, ("Called\n"));
-
switch (Level)
{
case SOL_SOCKET:
switch (OptionName)
{
case SO_TYPE:
- Buffer = &Socket->SharedData.SocketType;
+ Buffer = &Socket->SharedData->SocketType;
BufferSize = sizeof(INT);
break;
case SO_RCVBUF:
- Buffer = &Socket->SharedData.SizeOfRecvBuffer;
+ Buffer = &Socket->SharedData->SizeOfRecvBuffer;
BufferSize = sizeof(INT);
break;
case SO_SNDBUF:
- Buffer = &Socket->SharedData.SizeOfSendBuffer;
+ Buffer = &Socket->SharedData->SizeOfSendBuffer;
BufferSize = sizeof(INT);
break;
case SO_ACCEPTCONN:
- BoolBuffer = Socket->SharedData.Listening;
+ BoolBuffer = Socket->SharedData->Listening;
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
case SO_BROADCAST:
- BoolBuffer = Socket->SharedData.Broadcast;
+ BoolBuffer = Socket->SharedData->Broadcast;
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
case SO_DEBUG:
- BoolBuffer = Socket->SharedData.Debug;
+ BoolBuffer = Socket->SharedData->Debug;
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
case SO_DONTLINGER:
- BoolBuffer = (Socket->SharedData.LingerData.l_onoff == 0);
+ BoolBuffer = (Socket->SharedData->LingerData.l_onoff == 0);
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
case SO_LINGER:
- Buffer = &Socket->SharedData.LingerData;
+ if (Socket->SharedData->SocketType == SOCK_DGRAM)
+ {
+ if (lpErrno) *lpErrno = WSAENOPROTOOPT;
+ return SOCKET_ERROR;
+ }
+ Buffer = &Socket->SharedData->LingerData;
BufferSize = sizeof(struct linger);
break;
case SO_OOBINLINE:
- BoolBuffer = (Socket->SharedData.OobInline != 0);
+ BoolBuffer = (Socket->SharedData->OobInline != 0);
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
goto SendToHelper;
case SO_CONDITIONAL_ACCEPT:
- BoolBuffer = (Socket->SharedData.UseDelayedAcceptance != 0);
+ BoolBuffer = (Socket->SharedData->UseDelayedAcceptance != 0);
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
case SO_REUSEADDR:
- BoolBuffer = (Socket->SharedData.ReuseAddresses != 0);
+ BoolBuffer = (Socket->SharedData->ReuseAddresses != 0);
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
case SO_EXCLUSIVEADDRUSE:
- BoolBuffer = (Socket->SharedData.ExclusiveAddressUse != 0);
+ BoolBuffer = (Socket->SharedData->ExclusiveAddressUse != 0);
Buffer = &BoolBuffer;
BufferSize = sizeof(BOOL);
break;
case SO_ERROR:
- /* HACK: This needs to be properly tracked */
- IntBuffer = 0;
- DbgPrint("MSAFD: Hacked SO_ERROR returning error %d\n", IntBuffer);
-
- Buffer = &IntBuffer;
+ Buffer = &Socket->SharedData->SocketLastError;
BufferSize = sizeof(INT);
break;
+ case SO_CONNECT_TIME:
+ DwordBuffer = GetCurrentTimeInSeconds() - Socket->SharedData->ConnectTime;
+ Buffer = &DwordBuffer;
+ BufferSize = sizeof(DWORD);
+ break;
+
+ case SO_SNDTIMEO:
+ Buffer = &Socket->SharedData->SendTimeout;
+ BufferSize = sizeof(DWORD);
+ break;
+ case SO_RCVTIMEO:
+ Buffer = &Socket->SharedData->RecvTimeout;
+ BufferSize = sizeof(DWORD);
+ break;
+ case SO_PROTOCOL_INFOW:
+ Buffer = &Socket->ProtocolInfo;
+ BufferSize = sizeof(Socket->ProtocolInfo);
+ break;
+
case SO_GROUP_ID:
case SO_GROUP_PRIORITY:
case SO_MAX_MSG_SIZE:
- case SO_PROTOCOL_INFO:
default:
DbgPrint("MSAFD: Get unknown optname %x\n", OptionName);
- *lpErrno = WSAENOPROTOOPT;
+ if (lpErrno) *lpErrno = WSAENOPROTOOPT;
return SOCKET_ERROR;
}
if (*OptionLength < BufferSize)
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
*OptionLength = BufferSize;
return SOCKET_ERROR;
}
return 0;
default:
- break;
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
}
SendToHelper:
- *lpErrno = Socket->HelperData->WSHGetSocketInformation(Socket->HelperContext,
- Handle,
- Socket->TdiAddressHandle,
- Socket->TdiConnectionHandle,
- Level,
- OptionName,
- OptionValue,
- (LPINT)OptionLength);
- return (*lpErrno == 0) ? 0 : SOCKET_ERROR;
+ Errno = Socket->HelperData->WSHGetSocketInformation(Socket->HelperContext,
+ Handle,
+ Socket->TdiAddressHandle,
+ Socket->TdiConnectionHandle,
+ Level,
+ OptionName,
+ OptionValue,
+ (LPINT)OptionLength);
+ if (lpErrno) *lpErrno = Errno;
+ return (Errno == NO_ERROR) ? NO_ERROR : SOCKET_ERROR;
}
INT
OUT LPINT lpErrno)
{
PSOCKET_INFORMATION Socket;
+ INT Errno;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(s);
if (Socket == NULL)
{
- *lpErrno = WSAENOTSOCK;
+ if (lpErrno) *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+ if (!optval)
+ {
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
case SO_BROADCAST:
if (optlen < sizeof(BOOL))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- Socket->SharedData.Broadcast = (*optval != 0) ? 1 : 0;
- return 0;
+ Socket->SharedData->Broadcast = (*optval != 0) ? 1 : 0;
+ return NO_ERROR;
+
+ case SO_OOBINLINE:
+ if (optlen < sizeof(BOOL))
+ {
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
+ Socket->SharedData->OobInline = (*optval != 0) ? 1 : 0;
+ return NO_ERROR;
case SO_DONTLINGER:
if (optlen < sizeof(BOOL))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- Socket->SharedData.LingerData.l_onoff = (*optval != 0) ? 0 : 1;
- return 0;
+ Socket->SharedData->LingerData.l_onoff = (*optval != 0) ? 0 : 1;
+ return NO_ERROR;
case SO_REUSEADDR:
if (optlen < sizeof(BOOL))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- Socket->SharedData.ReuseAddresses = (*optval != 0) ? 1 : 0;
- return 0;
+ Socket->SharedData->ReuseAddresses = (*optval != 0) ? 1 : 0;
+ return NO_ERROR;
case SO_EXCLUSIVEADDRUSE:
if (optlen < sizeof(BOOL))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- Socket->SharedData.ExclusiveAddressUse = (*optval != 0) ? 1 : 0;
- return 0;
+ Socket->SharedData->ExclusiveAddressUse = (*optval != 0) ? 1 : 0;
+ return NO_ERROR;
case SO_LINGER:
if (optlen < sizeof(struct linger))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- RtlCopyMemory(&Socket->SharedData.LingerData,
+ RtlCopyMemory(&Socket->SharedData->LingerData,
optval,
sizeof(struct linger));
- return 0;
+ return NO_ERROR;
case SO_SNDBUF:
if (optlen < sizeof(DWORD))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
/* TODO: The total per-socket buffer space reserved for sends */
- AFD_DbgPrint(MIN_TRACE,("Setting send buf to %x is not implemented yet\n", optval));
- return 0;
+ ERR("Setting send buf to %x is not implemented yet\n", optval);
+ return NO_ERROR;
+
+ case SO_ERROR:
+ if (optlen < sizeof(INT))
+ {
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
+
+ RtlCopyMemory(&Socket->SharedData->SocketLastError,
+ optval,
+ sizeof(INT));
+ return NO_ERROR;
case SO_SNDTIMEO:
if (optlen < sizeof(DWORD))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- RtlCopyMemory(&Socket->SharedData.SendTimeout,
+ RtlCopyMemory(&Socket->SharedData->SendTimeout,
optval,
sizeof(DWORD));
- return 0;
+ return NO_ERROR;
case SO_RCVTIMEO:
if (optlen < sizeof(DWORD))
{
- *lpErrno = WSAEFAULT;
+ if (lpErrno) *lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- RtlCopyMemory(&Socket->SharedData.RecvTimeout,
+ RtlCopyMemory(&Socket->SharedData->RecvTimeout,
optval,
sizeof(DWORD));
- return 0;
+ return NO_ERROR;
case SO_KEEPALIVE:
case SO_DONTROUTE:
default:
/* Obviously this is a hack */
- DbgPrint("MSAFD: Set unknown optname %x\n", optname);
- return 0;
+ ERR("MSAFD: Set unknown optname %x\n", optname);
+ return NO_ERROR;
}
}
SendToHelper:
- *lpErrno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
- s,
- Socket->TdiAddressHandle,
- Socket->TdiConnectionHandle,
- level,
- optname,
- (PCHAR)optval,
- optlen);
- return (*lpErrno == 0) ? 0 : SOCKET_ERROR;
+ Errno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
+ s,
+ Socket->TdiAddressHandle,
+ Socket->TdiConnectionHandle,
+ level,
+ optname,
+ (PCHAR)optval,
+ optlen);
+ if (lpErrno) *lpErrno = Errno;
+ return (Errno == NO_ERROR) ? NO_ERROR : SOCKET_ERROR;
}
/*
{
NTSTATUS Status;
- AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
- Status = NO_ERROR;
+ if (((LOBYTE(wVersionRequested) == 2) && (HIBYTE(wVersionRequested) < 2)) ||
+ (LOBYTE(wVersionRequested) < 2))
+ {
+ ERR("WSPStartup NOT SUPPORTED for version 0x%X\n", wVersionRequested);
+ return WSAVERNOTSUPPORTED;
+ }
+ else
+ Status = NO_ERROR;
+ /* FIXME: Enable all cases of WSPStartup status */
Upcalls = UpcallTable;
if (Status == NO_ERROR)
lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
lpWSPData->wVersion = MAKEWORD(2, 2);
lpWSPData->wHighVersion = MAKEWORD(2, 2);
+ /* Save CatalogEntryId for all upcalls */
+ CatalogEntryId = lpProtocolInfo->dwCatalogEntryId;
}
- AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
-
+ TRACE("Status (%d).\n", Status);
return Status;
}
+INT
+WSPAPI
+WSPAddressToString(IN LPSOCKADDR lpsaAddress,
+ IN DWORD dwAddressLength,
+ IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
+ OUT LPWSTR lpszAddressString,
+ IN OUT LPDWORD lpdwAddressStringLength,
+ OUT LPINT lpErrno)
+{
+ DWORD size;
+ WCHAR buffer[54]; /* 32 digits + 7':' + '[' + '%" + 5 digits + ']:' + 5 digits + '\0' */
+ WCHAR *p;
+
+ if (!lpsaAddress || !lpszAddressString || !lpdwAddressStringLength)
+ {
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
+
+ switch (lpsaAddress->sa_family)
+ {
+ case AF_INET:
+ if (dwAddressLength < sizeof(SOCKADDR_IN))
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+ swprintf(buffer,
+ L"%u.%u.%u.%u:%u",
+ (unsigned int)(ntohl(((SOCKADDR_IN *)lpsaAddress)->sin_addr.s_addr) >> 24 & 0xff),
+ (unsigned int)(ntohl(((SOCKADDR_IN *)lpsaAddress)->sin_addr.s_addr) >> 16 & 0xff),
+ (unsigned int)(ntohl(((SOCKADDR_IN *)lpsaAddress)->sin_addr.s_addr) >> 8 & 0xff),
+ (unsigned int)(ntohl(((SOCKADDR_IN *)lpsaAddress)->sin_addr.s_addr) & 0xff),
+ ntohs(((SOCKADDR_IN *)lpsaAddress)->sin_port));
+
+ p = wcschr(buffer, L':');
+ if (!((SOCKADDR_IN *)lpsaAddress)->sin_port)
+ {
+ *p = 0;
+ }
+ break;
+ default:
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+
+ size = wcslen(buffer) + 1;
+
+ if (*lpdwAddressStringLength < size)
+ {
+ *lpdwAddressStringLength = size;
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
+
+ *lpdwAddressStringLength = size;
+ wcscpy(lpszAddressString, buffer);
+ return 0;
+}
+
+INT
+WSPAPI
+WSPStringToAddress(IN LPWSTR AddressString,
+ IN INT AddressFamily,
+ IN LPWSAPROTOCOL_INFOW lpProtocolInfo,
+ OUT LPSOCKADDR lpAddress,
+ IN OUT LPINT lpAddressLength,
+ OUT LPINT lpErrno)
+{
+ int numdots = 0;
+ USHORT port;
+ LONG inetaddr = 0, ip_part;
+ LPWSTR *bp = NULL;
+ SOCKADDR_IN *sockaddr;
+
+ if (!lpAddressLength || !lpAddress || !AddressString)
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+
+ sockaddr = (SOCKADDR_IN *)lpAddress;
+
+ /* Set right address family */
+ if (lpProtocolInfo != NULL)
+ {
+ sockaddr->sin_family = lpProtocolInfo->iAddressFamily;
+ }
+ else
+ {
+ sockaddr->sin_family = AddressFamily;
+ }
+
+ /* Report size */
+ if (AddressFamily == AF_INET)
+ {
+ if (*lpAddressLength < (INT)sizeof(SOCKADDR_IN))
+ {
+ if (lpErrno) *lpErrno = WSAEFAULT;
+ }
+ else
+ {
+ // translate ip string to ip
+
+ /* Get ip number */
+ bp = &AddressString;
+ inetaddr = 0;
+
+ while (*bp < &AddressString[wcslen(AddressString)])
+ {
+ ip_part = wcstol(*bp, bp, 10);
+ /* ip part number should be in range 0-255 */
+ if (ip_part < 0 || ip_part > 255)
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+ inetaddr = (inetaddr << 8) + ip_part;
+ /* we end on string end or port separator */
+ if ((*bp)[0] == 0 || (*bp)[0] == L':')
+ break;
+ /* ip parts are dot separated. verify it */
+ if ((*bp)[0] != L'.')
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+ /* count the dots */
+ numdots++;
+ /* move over the dot to next ip part */
+ (*bp)++;
+ }
+
+ /* check dots count */
+ if (numdots != 3)
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+
+ /* Get port number */
+ if ((*bp)[0] == L':')
+ {
+ /* move over the column to port part */
+ (*bp)++;
+ /* next char should be numeric */
+ if ((*bp)[0] < L'0' || (*bp)[0] > L'9')
+ {
+ if (lpErrno) *lpErrno = WSAEINVAL;
+ return SOCKET_ERROR;
+ }
+ port = wcstol(*bp, bp, 10);
+ }
+ else
+ {
+ port = 0;
+ }
+
+ if (lpErrno) *lpErrno = NO_ERROR;
+ /* rest sockaddr.sin_addr.s_addr
+ for we need to be sure it is zero when we come to while */
+ *lpAddressLength = sizeof(*sockaddr);
+ memset(lpAddress, 0, sizeof(*sockaddr));
+ sockaddr->sin_family = AF_INET;
+ sockaddr->sin_addr.s_addr = inetaddr;
+ sockaddr->sin_port = port;
+ }
+ }
+
+ if (lpErrno && !*lpErrno)
+ {
+ return 0;
+ }
+
+ return SOCKET_ERROR;
+}
+
/*
* FUNCTION: Cleans up service provider for a client
* ARGUMENTS:
WSPCleanup(OUT LPINT lpErrno)
{
- AFD_DbgPrint(MAX_TRACE, ("\n"));
- AFD_DbgPrint(MAX_TRACE, ("Leaving.\n"));
- *lpErrno = NO_ERROR;
+ TRACE("Leaving.\n");
+
+ if (lpErrno) *lpErrno = NO_ERROR;
return 0;
}
+VOID
+NTAPI
+AfdInfoAPC(PVOID ApcContext,
+ PIO_STATUS_BLOCK IoStatusBlock,
+ ULONG Reserved)
+{
+ PAFDAPCCONTEXT Context = ApcContext;
+ Context->lpCompletionRoutine(IoStatusBlock->Status, IoStatusBlock->Information, Context->lpOverlapped, 0);
+ HeapFree(GlobalHeap, 0, ApcContext);
+}
int
GetSocketInformation(PSOCKET_INFORMATION Socket,
ULONG AfdInformationClass,
PBOOLEAN Boolean OPTIONAL,
PULONG Ulong OPTIONAL,
- PLARGE_INTEGER LargeInteger OPTIONAL)
+ PLARGE_INTEGER LargeInteger OPTIONAL,
+ LPWSAOVERLAPPED Overlapped OPTIONAL,
+ LPWSAOVERLAPPED_COMPLETION_ROUTINE CompletionRoutine OPTIONAL)
{
- IO_STATUS_BLOCK IOSB;
+ PIO_STATUS_BLOCK IOSB;
+ IO_STATUS_BLOCK DummyIOSB;
AFD_INFO InfoData;
NTSTATUS Status;
+ PAFDAPCCONTEXT APCContext;
+ PIO_APC_ROUTINE APCFunction;
+ HANDLE Event = NULL;
HANDLE SockEvent;
Status = NtCreateEvent(&SockEvent,
FALSE);
if( !NT_SUCCESS(Status) )
- return -1;
+ return SOCKET_ERROR;
/* Set Info Class */
InfoData.InformationClass = AfdInformationClass;
+ /* Verify if we should use APC */
+ if (Overlapped == NULL)
+ {
+ /* Not using Overlapped structure, so use normal blocking on event */
+ APCContext = NULL;
+ APCFunction = NULL;
+ Event = SockEvent;
+ IOSB = &DummyIOSB;
+ }
+ else
+ {
+ /* Overlapped request for non overlapped opened socket */
+ if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
+ {
+ TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
+ return 0;
+ }
+ if (CompletionRoutine == NULL)
+ {
+ /* Using Overlapped Structure, but no Completition Routine, so no need for APC */
+ APCContext = (PAFDAPCCONTEXT)Overlapped;
+ APCFunction = NULL;
+ Event = Overlapped->hEvent;
+ }
+ else
+ {
+ /* Using Overlapped Structure and a Completition Routine, so use an APC */
+ APCFunction = &AfdInfoAPC; // should be a private io completition function inside us
+ APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
+ if (!APCContext)
+ {
+ ERR("Not enough memory for APC Context\n");
+ return WSAEFAULT;
+ }
+ APCContext->lpCompletionRoutine = CompletionRoutine;
+ APCContext->lpOverlapped = Overlapped;
+ APCContext->lpSocket = Socket;
+ }
+
+ IOSB = (PIO_STATUS_BLOCK)&Overlapped->Internal;
+ }
+
+ IOSB->Status = STATUS_PENDING;
+
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
- SockEvent,
- NULL,
- NULL,
- &IOSB,
+ Event,
+ APCFunction,
+ APCContext,
+ IOSB,
IOCTL_AFD_GET_INFO,
&InfoData,
sizeof(InfoData),
sizeof(InfoData));
/* Wait for return */
- if (Status == STATUS_PENDING)
+ if (Status == STATUS_PENDING && Overlapped == NULL)
{
WaitForSingleObject(SockEvent, INFINITE);
- Status = IOSB.Status;
+ Status = IOSB->Status;
+ }
+
+ TRACE("Status %x Information %d\n", Status, IOSB->Information);
+
+ if (Status == STATUS_PENDING)
+ {
+ TRACE("Leaving (Pending)\n");
+ return WSA_IO_PENDING;
}
if (Status != STATUS_SUCCESS)
- return -1;
+ return SOCKET_ERROR;
/* Return Information */
if (Ulong != NULL)
NtClose( SockEvent );
- return 0;
+ return NO_ERROR;
}
ULONG AfdInformationClass,
PBOOLEAN Boolean OPTIONAL,
PULONG Ulong OPTIONAL,
- PLARGE_INTEGER LargeInteger OPTIONAL)
+ PLARGE_INTEGER LargeInteger OPTIONAL,
+ LPWSAOVERLAPPED Overlapped OPTIONAL,
+ LPWSAOVERLAPPED_COMPLETION_ROUTINE CompletionRoutine OPTIONAL)
{
- IO_STATUS_BLOCK IOSB;
+ PIO_STATUS_BLOCK IOSB;
+ IO_STATUS_BLOCK DummyIOSB;
AFD_INFO InfoData;
NTSTATUS Status;
+ PAFDAPCCONTEXT APCContext;
+ PIO_APC_ROUTINE APCFunction;
+ HANDLE Event = NULL;
HANDLE SockEvent;
Status = NtCreateEvent(&SockEvent,
FALSE);
if( !NT_SUCCESS(Status) )
- return -1;
+ return SOCKET_ERROR;
/* Set Info Class */
InfoData.InformationClass = AfdInformationClass;
InfoData.Information.Boolean = *Boolean;
}
+ /* Verify if we should use APC */
+ if (Overlapped == NULL)
+ {
+ /* Not using Overlapped structure, so use normal blocking on event */
+ APCContext = NULL;
+ APCFunction = NULL;
+ Event = SockEvent;
+ IOSB = &DummyIOSB;
+ }
+ else
+ {
+ /* Overlapped request for non overlapped opened socket */
+ if ((Socket->SharedData->CreateFlags & SO_SYNCHRONOUS_NONALERT) != 0)
+ {
+ TRACE("Opened without flag WSA_FLAG_OVERLAPPED. Do nothing.\n");
+ return 0;
+ }
+ if (CompletionRoutine == NULL)
+ {
+ /* Using Overlapped Structure, but no Completition Routine, so no need for APC */
+ APCContext = (PAFDAPCCONTEXT)Overlapped;
+ APCFunction = NULL;
+ Event = Overlapped->hEvent;
+ }
+ else
+ {
+ /* Using Overlapped Structure and a Completition Routine, so use an APC */
+ APCFunction = &AfdInfoAPC; // should be a private io completition function inside us
+ APCContext = HeapAlloc(GlobalHeap, 0, sizeof(AFDAPCCONTEXT));
+ if (!APCContext)
+ {
+ ERR("Not enough memory for APC Context\n");
+ return WSAEFAULT;
+ }
+ APCContext->lpCompletionRoutine = CompletionRoutine;
+ APCContext->lpOverlapped = Overlapped;
+ APCContext->lpSocket = Socket;
+ }
+
+ IOSB = (PIO_STATUS_BLOCK)&Overlapped->Internal;
+ }
+
+ IOSB->Status = STATUS_PENDING;
+
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
- SockEvent,
- NULL,
- NULL,
- &IOSB,
+ Event,
+ APCFunction,
+ APCContext,
+ IOSB,
IOCTL_AFD_SET_INFO,
&InfoData,
sizeof(InfoData),
0);
/* Wait for return */
- if (Status == STATUS_PENDING)
+ if (Status == STATUS_PENDING && Overlapped == NULL)
{
WaitForSingleObject(SockEvent, INFINITE);
- Status = IOSB.Status;
+ Status = IOSB->Status;
}
NtClose( SockEvent );
- return Status == STATUS_SUCCESS ? 0 : -1;
+ TRACE("Status %x Information %d\n", Status, IOSB->Information);
+
+ if (Status == STATUS_PENDING)
+ {
+ TRACE("Leaving (Pending)\n");
+ return WSA_IO_PENDING;
+ }
+
+ return Status == STATUS_SUCCESS ? NO_ERROR : SOCKET_ERROR;
}
FALSE);
if( !NT_SUCCESS(Status) )
- return -1;
+ return SOCKET_ERROR;
/* Create Context */
- ContextData.SharedData = Socket->SharedData;
+ ContextData.SharedData = *Socket->SharedData;
ContextData.SizeOfHelperData = 0;
RtlCopyMemory (&ContextData.LocalAddress,
Socket->LocalAddress,
- Socket->SharedData.SizeOfLocalAddress);
+ Socket->SharedData->SizeOfLocalAddress);
RtlCopyMemory (&ContextData.RemoteAddress,
Socket->RemoteAddress,
- Socket->SharedData.SizeOfRemoteAddress);
+ Socket->SharedData->SizeOfRemoteAddress);
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Socket->Handle,
NULL,
0);
- /* Wait for Completition */
+ /* Wait for Completion */
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
NtClose( SockEvent );
- return Status == STATUS_SUCCESS ? 0 : -1;
+ return Status == STATUS_SUCCESS ? NO_ERROR : SOCKET_ERROR;
}
BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
2); // Allow 2 threads only
if (!NT_SUCCESS(Status))
{
- AFD_DbgPrint(MID_TRACE,("Failed to create completion port\n"));
+ ERR("Failed to create completion port: 0x%08x\n", Status);
return FALSE;
}
/* Protect Handle */
/* Create the Async Thread */
hAsyncThread = CreateThread(NULL,
0,
- (LPTHREAD_START_ROUTINE)SockAsyncThread,
+ SockAsyncThread,
NULL,
0,
&AsyncThreadId);
return TRUE;
}
-int SockAsyncThread(PVOID ThreadParam)
+ULONG
+NTAPI
+SockAsyncThread(PVOID ThreadParam)
{
PVOID AsyncContext;
PASYNC_COMPLETION_ROUTINE AsyncCompletionRoutine;
FILE_COMPLETION_INFORMATION CompletionInfo;
OBJECT_HANDLE_ATTRIBUTE_INFORMATION HandleFlags;
- /* First, make sure we're not already intialized */
+ /* First, make sure we're not already initialized */
if (SockAsyncHelperAfdHandle)
{
return TRUE;
Socket = AsyncData->ParentSocket;
/* Check if the Sequence Number Changed behind our back */
- if (AsyncData->SequenceNumber != Socket->SharedData.SequenceNumber )
+ if (AsyncData->SequenceNumber != Socket->SharedData->SequenceNumber )
{
return;
}
switch (AsyncData->AsyncSelectInfo.Handles[0].Events & x)
{
case AFD_EVENT_RECEIVE:
- if (0 != (Socket->SharedData.AsyncEvents & FD_READ) &&
- 0 == (Socket->SharedData.AsyncDisabledEvents & FD_READ))
+ if (0 != (Socket->SharedData->AsyncEvents & FD_READ) &&
+ 0 == (Socket->SharedData->AsyncDisabledEvents & FD_READ))
{
- /* Make the Notifcation */
- (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
- Socket->SharedData.wMsg,
+ /* Make the Notification */
+ (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
+ Socket->SharedData->wMsg,
Socket->Handle,
WSAMAKESELECTREPLY(FD_READ, 0));
/* Disable this event until the next read(); */
- Socket->SharedData.AsyncDisabledEvents |= FD_READ;
+ Socket->SharedData->AsyncDisabledEvents |= FD_READ;
}
break;
case AFD_EVENT_OOB_RECEIVE:
- if (0 != (Socket->SharedData.AsyncEvents & FD_OOB) &&
- 0 == (Socket->SharedData.AsyncDisabledEvents & FD_OOB))
+ if (0 != (Socket->SharedData->AsyncEvents & FD_OOB) &&
+ 0 == (Socket->SharedData->AsyncDisabledEvents & FD_OOB))
{
- /* Make the Notifcation */
- (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
- Socket->SharedData.wMsg,
+ /* Make the Notification */
+ (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
+ Socket->SharedData->wMsg,
Socket->Handle,
WSAMAKESELECTREPLY(FD_OOB, 0));
/* Disable this event until the next read(); */
- Socket->SharedData.AsyncDisabledEvents |= FD_OOB;
+ Socket->SharedData->AsyncDisabledEvents |= FD_OOB;
}
break;
case AFD_EVENT_SEND:
- if (0 != (Socket->SharedData.AsyncEvents & FD_WRITE) &&
- 0 == (Socket->SharedData.AsyncDisabledEvents & FD_WRITE))
+ if (0 != (Socket->SharedData->AsyncEvents & FD_WRITE) &&
+ 0 == (Socket->SharedData->AsyncDisabledEvents & FD_WRITE))
{
- /* Make the Notifcation */
- (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
- Socket->SharedData.wMsg,
+ /* Make the Notification */
+ (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
+ Socket->SharedData->wMsg,
Socket->Handle,
WSAMAKESELECTREPLY(FD_WRITE, 0));
/* Disable this event until the next write(); */
- Socket->SharedData.AsyncDisabledEvents |= FD_WRITE;
+ Socket->SharedData->AsyncDisabledEvents |= FD_WRITE;
}
break;
/* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
case AFD_EVENT_CONNECT:
case AFD_EVENT_CONNECT_FAIL:
- if (0 != (Socket->SharedData.AsyncEvents & FD_CONNECT) &&
- 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CONNECT))
+ if (0 != (Socket->SharedData->AsyncEvents & FD_CONNECT) &&
+ 0 == (Socket->SharedData->AsyncDisabledEvents & FD_CONNECT))
{
- /* Make the Notifcation */
- (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
- Socket->SharedData.wMsg,
+ /* Make the Notification */
+ (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
+ Socket->SharedData->wMsg,
Socket->Handle,
WSAMAKESELECTREPLY(FD_CONNECT, 0));
/* Disable this event forever; */
- Socket->SharedData.AsyncDisabledEvents |= FD_CONNECT;
+ Socket->SharedData->AsyncDisabledEvents |= FD_CONNECT;
}
break;
case AFD_EVENT_ACCEPT:
- if (0 != (Socket->SharedData.AsyncEvents & FD_ACCEPT) &&
- 0 == (Socket->SharedData.AsyncDisabledEvents & FD_ACCEPT))
+ if (0 != (Socket->SharedData->AsyncEvents & FD_ACCEPT) &&
+ 0 == (Socket->SharedData->AsyncDisabledEvents & FD_ACCEPT))
{
- /* Make the Notifcation */
- (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
- Socket->SharedData.wMsg,
+ /* Make the Notification */
+ (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
+ Socket->SharedData->wMsg,
Socket->Handle,
WSAMAKESELECTREPLY(FD_ACCEPT, 0));
/* Disable this event until the next accept(); */
- Socket->SharedData.AsyncDisabledEvents |= FD_ACCEPT;
+ Socket->SharedData->AsyncDisabledEvents |= FD_ACCEPT;
}
break;
case AFD_EVENT_DISCONNECT:
case AFD_EVENT_ABORT:
case AFD_EVENT_CLOSE:
- if (0 != (Socket->SharedData.AsyncEvents & FD_CLOSE) &&
- 0 == (Socket->SharedData.AsyncDisabledEvents & FD_CLOSE))
+ if (0 != (Socket->SharedData->AsyncEvents & FD_CLOSE) &&
+ 0 == (Socket->SharedData->AsyncDisabledEvents & FD_CLOSE))
{
- /* Make the Notifcation */
- (Upcalls.lpWPUPostMessage)(Socket->SharedData.hWnd,
- Socket->SharedData.wMsg,
+ /* Make the Notification */
+ (Upcalls.lpWPUPostMessage)(Socket->SharedData->hWnd,
+ Socket->SharedData->wMsg,
Socket->Handle,
WSAMAKESELECTREPLY(FD_CLOSE, 0));
/* Disable this event forever; */
- Socket->SharedData.AsyncDisabledEvents |= FD_CLOSE;
+ Socket->SharedData->AsyncDisabledEvents |= FD_CLOSE;
}
break;
/* FIXME: Support QOS */
}
/* Check if there are any events left for us to check */
- if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
+ if ((Socket->SharedData->AsyncEvents & (~Socket->SharedData->AsyncDisabledEvents)) == 0 )
{
return;
}
AsyncData->AsyncSelectInfo.Handles[0].Events = 0;
/* Remove unwanted events */
- lNetworkEvents = Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents);
+ lNetworkEvents = Socket->SharedData->AsyncEvents & (~Socket->SharedData->AsyncDisabledEvents);
/* Set Events to wait for */
if (lNetworkEvents & FD_READ)
Socket = AsyncData->ParentSocket;
/* If someone closed it, stop the function */
- if (Socket->SharedData.State != SocketClosed)
+ if (Socket->SharedData->State != SocketClosed)
{
/* Check if the Sequence Number changed by now, in which case quit */
- if (AsyncData->SequenceNumber == Socket->SharedData.SequenceNumber)
+ if (AsyncData->SequenceNumber == Socket->SharedData->SequenceNumber)
{
- /* Do the actuall select, if needed */
- if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)))
+ /* Do the actual select, if needed */
+ if ((Socket->SharedData->AsyncEvents & (~Socket->SharedData->AsyncDisabledEvents)))
{
SockProcessAsyncSelect(Socket, AsyncData);
FreeContext = FALSE;
PASYNC_DATA AsyncData;
/* Make sure the event is actually disabled */
- if (!(Socket->SharedData.AsyncDisabledEvents & Event))
+ if (!(Socket->SharedData->AsyncDisabledEvents & Event))
{
return;
}
/* Re-enable it */
- Socket->SharedData.AsyncDisabledEvents &= ~Event;
+ Socket->SharedData->AsyncDisabledEvents &= ~Event;
/* Return if no more events are being polled */
- if ((Socket->SharedData.AsyncEvents & (~Socket->SharedData.AsyncDisabledEvents)) == 0 )
+ if ((Socket->SharedData->AsyncEvents & (~Socket->SharedData->AsyncDisabledEvents)) == 0 )
{
return;
}
SockCreateOrReferenceAsyncThread();
/* Increase the sequence number to stop anything else */
- Socket->SharedData.SequenceNumber++;
+ Socket->SharedData->SequenceNumber++;
/* Set up the Async Data */
AsyncData->ParentSocket = Socket;
- AsyncData->SequenceNumber = Socket->SharedData.SequenceNumber;
+ AsyncData->SequenceNumber = Socket->SharedData->SequenceNumber;
/* Begin Async Select by using I/O Completion */
NtSetIoCompletion(SockAsyncCompletionPort,
{
case DLL_PROCESS_ATTACH:
- AFD_DbgPrint(MAX_TRACE, ("Loading MSAFD.DLL \n"));
+ TRACE("Loading MSAFD.DLL \n");
/* Don't need thread attach notifications
so disable them to improve performance */
/* Initialize the lock that protects our socket list */
InitializeCriticalSection(&SocketListLock);
- AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
+ TRACE("MSAFD.DLL has been loaded\n");
break;
break;
}
- AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
+ TRACE("DllMain of msafd.dll (leaving)\n");
return TRUE;
}
/* EOF */
-
-