HANDLE GlobalHeap;
WSPUPCALLTABLE Upcalls;
LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
-ULONG SocketCount = 0;
-PSOCKET_INFORMATION *Sockets = NULL;
+PSOCKET_INFORMATION SocketListHead = NULL;
+CRITICAL_SECTION SocketListLock;
LIST_ENTRY SockHelpersListHead = { NULL, NULL };
ULONG SockAsyncThreadRefCount;
HANDLE SockAsyncHelperAfdHandle;
-HANDLE SockAsyncCompletionPort;
+HANDLE SockAsyncCompletionPort = NULL;
BOOLEAN SockAsyncSelectCalled;
ULONG SizeOfEA;
PAFD_CREATE_PACKET AfdPacket;
HANDLE Sock;
- PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
+ PSOCKET_INFORMATION Socket = NULL;
PFILE_FULL_EA_INFORMATION EABuffer = NULL;
PHELPER_DATA HelperData;
PVOID HelperDLLContext;
/* Save Handle */
Socket->Handle = (SOCKET)Sock;
- /* XXX See if there's a structure we can reuse -- We need to do this
- * more properly. */
- PrevSocket = GetSocketStructure( (SOCKET)Sock );
-
- if( PrevSocket )
- {
- RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
- RtlFreeHeap( GlobalHeap, 0, Socket );
- Socket = PrevSocket;
- }
-
/* Save Group Info */
if (g != 0)
{
- GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, 0, &GroupData);
+ GetSocketInformation(Socket, AFD_INFO_GROUP_ID_TYPE, NULL, NULL, &GroupData);
Socket->SharedData.GroupID = GroupData.u.LowPart;
Socket->SharedData.GroupType = GroupData.u.HighPart;
}
/* Get Window Sizes and Save them */
GetSocketInformation (Socket,
AFD_INFO_SEND_WINDOW_SIZE,
+ NULL,
&Socket->SharedData.SizeOfSendBuffer,
NULL);
GetSocketInformation (Socket,
AFD_INFO_RECEIVE_WINDOW_SIZE,
+ NULL,
&Socket->SharedData.SizeOfRecvBuffer,
NULL);
/* Save in Process Sockets List */
- Sockets[SocketCount] = Socket;
- SocketCount ++;
+ EnterCriticalSection(&SocketListLock);
+ Socket->NextSocket = SocketListHead;
+ SocketListHead = Socket;
+ LeaveCriticalSection(&SocketListLock);
/* Create the Socket Context */
CreateContext(Socket);
return INVALID_SOCKET;
}
-
-DWORD MsafdReturnWithErrno(NTSTATUS Status,
- LPINT Errno,
- DWORD Received,
- LPDWORD ReturnedBytes)
+INT
+TranslateNtStatusError(NTSTATUS Status)
{
- if( ReturnedBytes )
- *ReturnedBytes = 0;
- if( Errno )
+ switch (Status)
{
- switch (Status)
- {
- case STATUS_CANT_WAIT:
- *Errno = WSAEWOULDBLOCK;
- break;
- case STATUS_TIMEOUT:
- *Errno = WSAETIMEDOUT;
- break;
- case STATUS_SUCCESS:
- /* Return Number of bytes Read */
- if( ReturnedBytes )
- *ReturnedBytes = Received;
- break;
- case STATUS_FILE_CLOSED:
- case STATUS_END_OF_FILE:
- *Errno = WSAESHUTDOWN;
- break;
- case STATUS_PENDING:
- *Errno = WSA_IO_PENDING;
- break;
- case STATUS_BUFFER_TOO_SMALL:
- case STATUS_BUFFER_OVERFLOW:
- DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
- *Errno = WSAEMSGSIZE;
- break;
- case STATUS_NO_MEMORY: /* Fall through to STATUS_INSUFFICIENT_RESOURCES */
- case STATUS_INSUFFICIENT_RESOURCES:
- DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
- *Errno = WSAENOBUFS;
- break;
- case STATUS_INVALID_CONNECTION:
- DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
- *Errno = WSAEAFNOSUPPORT;
- break;
- case STATUS_INVALID_ADDRESS:
- DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
- *Errno = WSAEADDRNOTAVAIL;
- break;
- case STATUS_REMOTE_NOT_LISTENING:
- DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
- *Errno = WSAECONNREFUSED;
- break;
- case STATUS_NETWORK_UNREACHABLE:
- DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
- *Errno = WSAENETUNREACH;
- break;
- case STATUS_INVALID_PARAMETER:
- DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
- *Errno = WSAEINVAL;
- break;
- case STATUS_CANCELLED:
- DbgPrint("MSAFD: STATUS_CANCELLED\n");
- *Errno = WSA_OPERATION_ABORTED;
- break;
- default:
- DbgPrint("MSAFD: Error %x is unknown\n", Status);
- *Errno = WSAEINVAL;
- break;
- }
- }
+ case STATUS_CANT_WAIT:
+ return WSAEWOULDBLOCK;
+
+ case STATUS_TIMEOUT:
+ return WSAETIMEDOUT;
+
+ case STATUS_SUCCESS:
+ return NO_ERROR;
+
+ case STATUS_FILE_CLOSED:
+ case STATUS_END_OF_FILE:
+ return WSAESHUTDOWN;
+
+ case STATUS_PENDING:
+ return WSA_IO_PENDING;
+
+ case STATUS_BUFFER_TOO_SMALL:
+ case STATUS_BUFFER_OVERFLOW:
+ DbgPrint("MSAFD: STATUS_BUFFER_TOO_SMALL/STATUS_BUFFER_OVERFLOW\n");
+ return WSAEMSGSIZE;
+
+ case STATUS_NO_MEMORY:
+ case STATUS_INSUFFICIENT_RESOURCES:
+ DbgPrint("MSAFD: STATUS_NO_MEMORY/STATUS_INSUFFICIENT_RESOURCES\n");
+ return WSAENOBUFS;
+
+ case STATUS_INVALID_CONNECTION:
+ DbgPrint("MSAFD: STATUS_INVALID_CONNECTION\n");
+ return WSAEAFNOSUPPORT;
+
+ case STATUS_INVALID_ADDRESS:
+ DbgPrint("MSAFD: STATUS_INVALID_ADDRESS\n");
+ return WSAEADDRNOTAVAIL;
+
+ case STATUS_REMOTE_NOT_LISTENING:
+ DbgPrint("MSAFD: STATUS_REMOTE_NOT_LISTENING\n");
+ return WSAECONNREFUSED;
- /* Success */
- return Status == STATUS_SUCCESS ? 0 : SOCKET_ERROR;
+ case STATUS_NETWORK_UNREACHABLE:
+ DbgPrint("MSAFD: STATUS_NETWORK_UNREACHABLE\n");
+ return WSAENETUNREACH;
+
+ case STATUS_INVALID_PARAMETER:
+ DbgPrint("MSAFD: STATUS_INVALID_PARAMETER\n");
+ return WSAEINVAL;
+
+ case STATUS_CANCELLED:
+ DbgPrint("MSAFD: STATUS_CANCELLED\n");
+ return WSA_OPERATION_ABORTED;
+
+ case STATUS_ADDRESS_ALREADY_EXISTS:
+ DbgPrint("MSAFD: STATUS_ADDRESS_ALREADY_EXISTS\n");
+ return WSAEADDRINUSE;
+
+ case STATUS_LOCAL_DISCONNECT:
+ DbgPrint("MSAFD: STATUS_LOCAL_DISCONNECT\n");
+ return WSAECONNABORTED;
+
+ case STATUS_REMOTE_DISCONNECT:
+ DbgPrint("MSAFD: STATUS_REMOTE_DISCONNECT\n");
+ return WSAECONNRESET;
+
+ default:
+ DbgPrint("MSAFD: Unhandled NTSTATUS value: 0x%x\n", Status);
+ return WSAENETDOWN;
+ }
}
/*
OUT LPINT lpErrno)
{
IO_STATUS_BLOCK IoStatusBlock;
- PSOCKET_INFORMATION Socket = NULL;
+ PSOCKET_INFORMATION Socket = NULL, CurrentSocket;
NTSTATUS Status;
HANDLE SockEvent;
AFD_DISCONNECT_INFO DisconnectInfo;
SOCKET_STATE OldState;
+ LONG LingerWait = -1;
/* Create the Wait Event */
Status = NtCreateEvent(&SockEvent,
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
{
/* FIXME: Should we do this on Datagram Sockets too? */
if ((OldState == SocketConnected) && (Socket->SharedData.LingerData.l_onoff))
{
- ULONG LingerWait;
ULONG SendsInProgress;
ULONG SleepWait;
/* Find out how many Sends are in Progress */
if (GetSocketInformation(Socket,
AFD_INFO_SENDS_IN_PROGRESS,
+ NULL,
&SendsInProgress,
NULL))
{
/* Bail out if no more sends are pending */
if (!SendsInProgress)
+ {
+ LingerWait = -1;
break;
+ }
+
/*
* We have to execute a sleep, so it's kind of like
* a block. If the socket is Nonblock, we cannot
Sleep(SleepWait);
LingerWait -= SleepWait;
}
+ }
- /*
- * We have reached the timeout or sends are over.
- * Disconnect if the timeout has been reached.
- */
+ if (OldState == SocketConnected)
+ {
if (LingerWait <= 0)
{
DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(0);
- DisconnectInfo.DisconnectType = AFD_DISCONNECT_ABORT;
-
- /* Send IOCTL */
- Status = NtDeviceIoControlFile((HANDLE)Handle,
- SockEvent,
- NULL,
- NULL,
- &IoStatusBlock,
- IOCTL_AFD_DISCONNECT,
- &DisconnectInfo,
- sizeof(DisconnectInfo),
- NULL,
- 0);
-
- /* Wait for return */
- if (Status == STATUS_PENDING)
+ DisconnectInfo.DisconnectType = LingerWait < 0 ? AFD_DISCONNECT_SEND : AFD_DISCONNECT_ABORT;
+
+ if (((DisconnectInfo.DisconnectType & AFD_DISCONNECT_SEND) && (!Socket->SharedData.SendShutdown)) ||
+ ((DisconnectInfo.DisconnectType & AFD_DISCONNECT_ABORT) && (!Socket->SharedData.ReceiveShutdown)))
{
- WaitForSingleObject(SockEvent, INFINITE);
- Status = IoStatusBlock.Status;
+ /* Send IOCTL */
+ Status = NtDeviceIoControlFile((HANDLE)Handle,
+ SockEvent,
+ NULL,
+ NULL,
+ &IoStatusBlock,
+ IOCTL_AFD_DISCONNECT,
+ &DisconnectInfo,
+ sizeof(DisconnectInfo),
+ NULL,
+ 0);
+
+ /* Wait for return */
+ if (Status == STATUS_PENDING)
+ {
+ WaitForSingleObject(SockEvent, INFINITE);
+ Status = IoStatusBlock.Status;
+ }
}
}
}
NtClose(Socket->TdiConnectionHandle);
Socket->TdiConnectionHandle = NULL;
+ EnterCriticalSection(&SocketListLock);
+ if (SocketListHead == Socket)
+ {
+ SocketListHead = SocketListHead->NextSocket;
+ }
+ else
+ {
+ CurrentSocket = SocketListHead;
+ while (CurrentSocket->NextSocket)
+ {
+ if (CurrentSocket->NextSocket == Socket)
+ {
+ CurrentSocket->NextSocket = CurrentSocket->NextSocket->NextSocket;
+ break;
+ }
+
+ CurrentSocket = CurrentSocket->NextSocket;
+ }
+ }
+ LeaveCriticalSection(&SocketListLock);
+
/* Close the handle */
NtClose((HANDLE)Handle);
NtClose(SockEvent);
- return NO_ERROR;
+ HeapFree(GlobalHeap, 0, Socket);
+ return MsafdReturnWithErrno(Status, lpErrno, 0, NULL);
}
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ HeapFree(GlobalHeap, 0, BindData);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Set up Address in TDI Format */
BindData->Address.TAAddressCount = 1;
Status = IOSB.Status;
}
+ NtClose( SockEvent );
+ HeapFree(GlobalHeap, 0, BindData);
+
+ if (Status != STATUS_SUCCESS)
+ return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
+
/* Set up Socket Data */
Socket->SharedData.State = SocketBound;
Socket->TdiAddressHandle = (HANDLE)IOSB.Information;
- NtClose( SockEvent );
- HeapFree(GlobalHeap, 0, BindData);
- if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_BIND))
+ if (Socket->HelperEvents & WSH_NOTIFY_BIND)
{
Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
Socket->Handle,
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
if (Socket->SharedData.Listening)
return 0;
{
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
- }
+ }
+
+ NtClose( SockEvent );
+
+ if (Status != STATUS_SUCCESS)
+ return MsafdReturnWithErrno ( Status, lpErrno, 0, NULL );
/* Set to Listening */
Socket->SharedData.Listening = TRUE;
- NtClose( SockEvent );
-
- if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_LISTEN))
+ if (Socket->HelperEvents & WSH_NOTIFY_LISTEN)
{
Status = Socket->HelperData->WSHNotify(Socket->HelperContext,
Socket->Handle,
int
WSPAPI
-WSPSelect(int nfds,
- fd_set *readfds,
- fd_set *writefds,
- fd_set *exceptfds,
- const LPTIMEVAL timeout,
- LPINT lpErrno)
+WSPSelect(IN int nfds,
+ IN OUT fd_set *readfds OPTIONAL,
+ IN OUT fd_set *writefds OPTIONAL,
+ IN OUT fd_set *exceptfds OPTIONAL,
+ IN const struct timeval *timeout OPTIONAL,
+ OUT LPINT lpErrno)
{
IO_STATUS_BLOCK IOSB;
PAFD_POLL_INFO PollInfo;
NTSTATUS Status;
- LONG HandleCount, OutCount = 0;
+ ULONG HandleCount;
+ LONG OutCount = 0;
ULONG PollBufferSize;
PVOID PollBuffer;
ULONG i, j = 0, x;
if ( HandleCount == 0 )
{
- AFD_DbgPrint(MAX_TRACE,("HandleCount: %d. Return SOCKET_ERROR\n",
+ AFD_DbgPrint(MAX_TRACE,("HandleCount: %u. Return SOCKET_ERROR\n",
HandleCount));
if (lpErrno) *lpErrno = WSAEINVAL;
return SOCKET_ERROR;
PollBufferSize = sizeof(*PollInfo) + ((HandleCount - 1) * sizeof(AFD_HANDLE));
- AFD_DbgPrint(MID_TRACE,("HandleCount: %d BufferSize: %d\n",
+ AFD_DbgPrint(MID_TRACE,("HandleCount: %u BufferSize: %u\n",
HandleCount, PollBufferSize));
/* Convert Timeout to NT Format */
}
PollInfo->HandleCount = j;
- PollBufferSize = ((PCHAR)&PollInfo->Handles[j+1]) - ((PCHAR)PollInfo);
+ PollBufferSize = FIELD_OFFSET(AFD_POLL_INFO, Handles) + PollInfo->HandleCount * sizeof(AFD_HANDLE);
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)PollInfo->Handles[0].Handle,
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
+ Status = IOSB.Status;
}
/* Clear the Structures */
ULONG CallBack;
WSAPROTOCOL_INFOW ProtocolInfo;
SOCKET AcceptSocket;
+ PSOCKET_INFORMATION AcceptSocketInfo;
UCHAR ReceiveBuffer[0x1A];
HANDLE SockEvent;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return INVALID_SOCKET;
+ }
/* If this is non-blocking, make sure there's something for us to accept */
FD_ZERO(&ReadSet);
Timeout.tv_sec=0;
Timeout.tv_usec=0;
- WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
+ if (WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, lpErrno) == SOCKET_ERROR)
+ {
+ NtClose(SockEvent);
+ return INVALID_SOCKET;
+ }
if (ReadSet.fd_array[0] != Socket->Handle)
{
NtClose(SockEvent);
- return 0;
+ *lpErrno = WSAEWOULDBLOCK;
+ return INVALID_SOCKET;
}
/* Send IOCTL */
&ProtocolInfo,
GroupID,
Socket->SharedData.CreateFlags,
- NULL);
+ lpErrno);
+ if (AcceptSocket == INVALID_SOCKET)
+ return INVALID_SOCKET;
/* Set up the Accept Structure */
AcceptData.ListenHandle = (HANDLE)AcceptSocket;
MsafdReturnWithErrno( Status, lpErrno, 0, NULL );
return INVALID_SOCKET;
}
+
+ AcceptSocketInfo = GetSocketStructure(AcceptSocket);
+ if (!AcceptSocketInfo)
+ {
+ NtClose(SockEvent);
+ WSPCloseSocket( AcceptSocket, lpErrno );
+ MsafdReturnWithErrno( STATUS_INVALID_CONNECTION, lpErrno, 0, NULL );
+ return INVALID_SOCKET;
+ }
+
+ AcceptSocketInfo->SharedData.State = SocketConnected;
/* Return Address in SOCKADDR FORMAT */
if( SocketAddress )
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Bind us First */
if (Socket->SharedData.State == SocketOpen)
BindAddress,
&BindAddressLength);
/* Bind it */
- WSPBind(Handle, BindAddress, BindAddressLength, NULL);
+ if (WSPBind(Handle, BindAddress, BindAddressLength, lpErrno) == SOCKET_ERROR)
+ return INVALID_SOCKET;
}
/* Set the Connect Data */
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
}
+
+ if (Status != STATUS_SUCCESS)
+ goto notify;
}
/* Dynamic Structure...ugh */
WaitForSingleObject(SockEvent, INFINITE);
Status = IOSB.Status;
}
+
+ if (Status != STATUS_SUCCESS)
+ goto notify;
}
/* AFD doesn't seem to care if these are invalid, but let's 0 them anyways */
Status = IOSB.Status;
}
+ if (Status != STATUS_SUCCESS)
+ goto notify;
+
+ Socket->SharedData.State = SocketConnected;
Socket->TdiConnectionHandle = (HANDLE)IOSB.Information;
/* Get any pending connect data */
}
}
+ AFD_DbgPrint(MID_TRACE,("Ending\n"));
+
+notify:
/* Re-enable Async Event */
SockReenableAsyncSelectEvent(Socket, FD_WRITE);
/* FIXME: THIS IS NOT RIGHT!!! HACK HACK HACK! */
SockReenableAsyncSelectEvent(Socket, FD_CONNECT);
- AFD_DbgPrint(MID_TRACE,("Ending\n"));
-
NtClose( SockEvent );
if (Status == STATUS_SUCCESS && (Socket->HelperEvents & WSH_NOTIFY_CONNECT))
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Set AFD Disconnect Type */
switch (HowTo)
break;
}
- DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1);
+ DisconnectInfo.Timeout = RtlConvertLongToLargeInteger(-1000000);
/* Send IOCTL */
Status = NtDeviceIoControlFile((HANDLE)Handle,
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+
+ if (!Name || !NameLength)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
/* Allocate a buffer for the address */
TdiAddressSize =
if (NT_SUCCESS(Status))
{
- if (*NameLength >= SocketAddress->Address[0].AddressLength)
+ if (*NameLength >= Socket->SharedData.SizeOfLocalAddress)
{
Name->sa_family = SocketAddress->Address[0].AddressType;
RtlCopyMemory (Name->sa_data,
SocketAddress->Address[0].Address,
SocketAddress->Address[0].AddressLength);
- *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+ *NameLength = Socket->SharedData.SizeOfLocalAddress;
AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %x Port %x\n",
*NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
((struct sockaddr_in *)Name)->sin_port));
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(s);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+
+ if (Socket->SharedData.State != SocketConnected)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTCONN;
+ return SOCKET_ERROR;
+ }
+
+ if (!Name || !NameLength)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAEFAULT;
+ return SOCKET_ERROR;
+ }
/* Allocate a buffer for the address */
- TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
+ TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + Socket->SharedData.SizeOfRemoteAddress;
SocketAddress = HeapAlloc(GlobalHeap, 0, TdiAddressSize);
if ( SocketAddress == NULL )
if (NT_SUCCESS(Status))
{
- if (*NameLength >= SocketAddress->Address[0].AddressLength)
+ if (*NameLength >= Socket->SharedData.SizeOfRemoteAddress)
{
Name->sa_family = SocketAddress->Address[0].AddressType;
RtlCopyMemory (Name->sa_data,
SocketAddress->Address[0].Address,
SocketAddress->Address[0].AddressLength);
- *NameLength = 2 + SocketAddress->Address[0].AddressLength;
+ *NameLength = Socket->SharedData.SizeOfRemoteAddress;
AFD_DbgPrint (MID_TRACE, ("NameLength %d Address: %s Port %x\n",
*NameLength, ((struct sockaddr_in *)Name)->sin_addr.s_addr,
((struct sockaddr_in *)Name)->sin_port));
OUT LPINT lpErrno)
{
PSOCKET_INFORMATION Socket = NULL;
+ BOOLEAN NeedsCompletion;
+ BOOLEAN NonBlocking;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+
+ *lpcbBytesReturned = 0;
switch( dwIoControlCode )
{
*lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
- return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+ NonBlocking = *((PULONG)lpvInBuffer) ? TRUE : FALSE;
+ Socket->SharedData.NonBlocking = NonBlocking ? 1 : 0;
+ *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, &NonBlocking, NULL, NULL);
+ if (*lpErrno != NO_ERROR)
+ return SOCKET_ERROR;
+ else
+ return NO_ERROR;
case FIONREAD:
if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
{
*lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
- default:
+ *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, NULL, (PULONG)lpvOutBuffer, NULL);
+ if (*lpErrno != NO_ERROR)
+ return SOCKET_ERROR;
+ else
+ {
+ *lpcbBytesReturned = sizeof(ULONG);
+ return NO_ERROR;
+ }
+ case SIO_GET_EXTENSION_FUNCTION_POINTER:
*lpErrno = WSAEINVAL;
return SOCKET_ERROR;
+ default:
+ *lpErrno = Socket->HelperData->WSHIoctl(Socket->HelperContext,
+ Handle,
+ Socket->TdiAddressHandle,
+ Socket->TdiConnectionHandle,
+ dwIoControlCode,
+ lpvInBuffer,
+ cbInBuffer,
+ lpvOutBuffer,
+ cbOutBuffer,
+ lpcbBytesReturned,
+ lpOverlapped,
+ lpCompletionRoutine,
+ (LPBOOL)&NeedsCompletion);
+
+ if (*lpErrno != NO_ERROR)
+ return SOCKET_ERROR;
+ else
+ return NO_ERROR;
}
}
}
- /* FIXME: We should handle some cases here */
+ /* FIXME: We should handle some more cases here */
+ if (level == SOL_SOCKET)
+ {
+ switch (optname)
+ {
+ case SO_BROADCAST:
+ Socket->SharedData.Broadcast = (*optval != 0) ? 1 : 0;
+ return 0;
+ }
+ }
*lpErrno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
int
GetSocketInformation(PSOCKET_INFORMATION Socket,
ULONG AfdInformationClass,
+ PBOOLEAN Boolean OPTIONAL,
PULONG Ulong OPTIONAL,
PLARGE_INTEGER LargeInteger OPTIONAL)
{
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
+ Status = IOSB.Status;
}
+ if (Status != STATUS_SUCCESS)
+ return -1;
+
/* Return Information */
if (Ulong != NULL)
{
{
*LargeInteger = InfoData.Information.LargeInteger;
}
+ if (Boolean != NULL)
+ {
+ *Boolean = InfoData.Information.Boolean;
+ }
NtClose( SockEvent );
int
SetSocketInformation(PSOCKET_INFORMATION Socket,
- ULONG AfdInformationClass,
+ ULONG AfdInformationClass,
+ PBOOLEAN Boolean OPTIONAL,
PULONG Ulong OPTIONAL,
PLARGE_INTEGER LargeInteger OPTIONAL)
{
{
InfoData.Information.LargeInteger = *LargeInteger;
}
+ if (Boolean != NULL)
+ {
+ InfoData.Information.Boolean = *Boolean;
+ }
AFD_DbgPrint(MID_TRACE,("XXX Info %x (Data %x)\n",
AfdInformationClass, *Ulong));
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
+ Status = IOSB.Status;
}
NtClose( SockEvent );
- return 0;
+ return Status == STATUS_SUCCESS ? 0 : -1;
}
PSOCKET_INFORMATION
GetSocketStructure(SOCKET Handle)
{
- ULONG i;
+ PSOCKET_INFORMATION CurrentSocket;
+
+ EnterCriticalSection(&SocketListLock);
- for (i=0; i<SocketCount; i++)
+ CurrentSocket = SocketListHead;
+ while (CurrentSocket)
{
- if (Sockets[i]->Handle == Handle)
+ if (CurrentSocket->Handle == Handle)
{
- return Sockets[i];
+ LeaveCriticalSection(&SocketListLock);
+ return CurrentSocket;
}
+
+ CurrentSocket = CurrentSocket->NextSocket;
}
- return 0;
+
+ LeaveCriticalSection(&SocketListLock);
+
+ return NULL;
}
int CreateContext(PSOCKET_INFORMATION Socket)
if (Status == STATUS_PENDING)
{
WaitForSingleObject(SockEvent, INFINITE);
+ Status = IOSB.Status;
}
NtClose( SockEvent );
- return 0;
+ return Status == STATUS_SUCCESS ? 0 : -1;
}
BOOLEAN SockCreateOrReferenceAsyncThread(VOID)
/* Check if the Thread Already Exists */
if (SockAsyncThreadRefCount)
{
+ ASSERT(SockAsyncCompletionPort);
return TRUE;
}
IO_COMPLETION_ALL_ACCESS,
NULL,
2); // Allow 2 threads only
-
+ if (!NT_SUCCESS(Status))
+ {
+ AFD_DbgPrint(MID_TRACE,("Failed to create completion port\n"));
+ return FALSE;
+ }
/* Protect Handle */
HandleFlags.ProtectFromClose = TRUE;
HandleFlags.Inherit = FALSE;
/* Heap to use when allocating */
GlobalHeap = GetProcessHeap();
- /* Allocate Heap for 1024 Sockets, can be expanded later */
- Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
- if (!Sockets) return FALSE;
+ /* Initialize the lock that protects our socket list */
+ InitializeCriticalSection(&SocketListLock);
AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
break;
case DLL_PROCESS_DETACH:
+
+ /* Delete the socket list lock */
+ DeleteCriticalSection(&SocketListLock);
+
break;
}