HANDLE GlobalHeap;
WSPUPCALLTABLE Upcalls;
LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
-ULONG SocketCount = 0;
-PSOCKET_INFORMATION *Sockets = NULL;
+PSOCKET_INFORMATION SocketListHead = NULL;
+CRITICAL_SECTION SocketListLock;
LIST_ENTRY SockHelpersListHead = { NULL, NULL };
ULONG SockAsyncThreadRefCount;
HANDLE SockAsyncHelperAfdHandle;
ULONG SizeOfEA;
PAFD_CREATE_PACKET AfdPacket;
HANDLE Sock;
- PSOCKET_INFORMATION Socket = NULL, PrevSocket = NULL;
+ PSOCKET_INFORMATION Socket = NULL;
PFILE_FULL_EA_INFORMATION EABuffer = NULL;
PHELPER_DATA HelperData;
PVOID HelperDLLContext;
/* Save Handle */
Socket->Handle = (SOCKET)Sock;
- /* XXX See if there's a structure we can reuse -- We need to do this
- * more properly. */
- PrevSocket = GetSocketStructure( (SOCKET)Sock );
-
- if( PrevSocket )
- {
- RtlCopyMemory( PrevSocket, Socket, sizeof(*Socket) );
- RtlFreeHeap( GlobalHeap, 0, Socket );
- Socket = PrevSocket;
- }
-
/* Save Group Info */
if (g != 0)
{
NULL);
/* Save in Process Sockets List */
- Sockets[SocketCount] = Socket;
- SocketCount ++;
+ EnterCriticalSection(&SocketListLock);
+ Socket->NextSocket = SocketListHead;
+ SocketListHead = Socket;
+ LeaveCriticalSection(&SocketListLock);
/* Create the Socket Context */
CreateContext(Socket);
}
}
-DWORD MsafdReturnWithErrno(NTSTATUS Status,
- LPINT Errno,
- DWORD Received,
- LPDWORD ReturnedBytes)
-{
- *Errno = TranslateNtStatusError(Status);
-
- if (ReturnedBytes)
- {
- if (!*Errno)
- *ReturnedBytes = Received;
- else
- *ReturnedBytes = 0;
- }
-
- return *Errno ? SOCKET_ERROR : 0;
-}
-
/*
* FUNCTION: Closes an open socket
* ARGUMENTS:
OUT LPINT lpErrno)
{
IO_STATUS_BLOCK IoStatusBlock;
- PSOCKET_INFORMATION Socket = NULL;
+ PSOCKET_INFORMATION Socket = NULL, CurrentSocket;
NTSTATUS Status;
HANDLE SockEvent;
AFD_DISCONNECT_INFO DisconnectInfo;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
if (Socket->HelperEvents & WSH_NOTIFY_CLOSE)
{
NtClose(Socket->TdiConnectionHandle);
Socket->TdiConnectionHandle = NULL;
+ EnterCriticalSection(&SocketListLock);
+ if (SocketListHead == Socket)
+ {
+ SocketListHead = SocketListHead->NextSocket;
+ }
+ else
+ {
+ CurrentSocket = SocketListHead;
+ while (CurrentSocket->NextSocket)
+ {
+ if (CurrentSocket->NextSocket == Socket)
+ {
+ CurrentSocket->NextSocket = CurrentSocket->NextSocket->NextSocket;
+ break;
+ }
+
+ CurrentSocket = CurrentSocket->NextSocket;
+ }
+ }
+ LeaveCriticalSection(&SocketListLock);
+
+ HeapFree(GlobalHeap, 0, Socket);
+
/* Close the handle */
NtClose((HANDLE)Handle);
NtClose(SockEvent);
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ HeapFree(GlobalHeap, 0, BindData);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Set up Address in TDI Format */
BindData->Address.TAAddressCount = 1;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
if (Socket->SharedData.Listening)
return 0;
int
WSPAPI
-WSPSelect(int nfds,
- fd_set *readfds,
- fd_set *writefds,
- fd_set *exceptfds,
- const LPTIMEVAL timeout,
- LPINT lpErrno)
+WSPSelect(IN int nfds,
+ IN OUT fd_set *readfds OPTIONAL,
+ IN OUT fd_set *writefds OPTIONAL,
+ IN OUT fd_set *exceptfds OPTIONAL,
+ IN const struct timeval *timeout OPTIONAL,
+ OUT LPINT lpErrno)
{
IO_STATUS_BLOCK IOSB;
PAFD_POLL_INFO PollInfo;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return INVALID_SOCKET;
+ }
/* If this is non-blocking, make sure there's something for us to accept */
FD_ZERO(&ReadSet);
Timeout.tv_sec=0;
Timeout.tv_usec=0;
- WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, NULL);
+ if (WSPSelect(0, &ReadSet, NULL, NULL, &Timeout, lpErrno) == SOCKET_ERROR)
+ {
+ NtClose(SockEvent);
+ return INVALID_SOCKET;
+ }
if (ReadSet.fd_array[0] != Socket->Handle)
{
NtClose(SockEvent);
- return 0;
+ *lpErrno = WSAEWOULDBLOCK;
+ return INVALID_SOCKET;
}
/* Send IOCTL */
&ProtocolInfo,
GroupID,
Socket->SharedData.CreateFlags,
- NULL);
+ lpErrno);
+ if (AcceptSocket == INVALID_SOCKET)
+ return INVALID_SOCKET;
/* Set up the Accept Structure */
AcceptData.ListenHandle = (HANDLE)AcceptSocket;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Bind us First */
if (Socket->SharedData.State == SocketOpen)
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Set AFD Disconnect Type */
switch (HowTo)
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Allocate a buffer for the address */
TdiAddressSize =
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(s);
+ if (!Socket)
+ {
+ NtClose(SockEvent);
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
/* Allocate a buffer for the address */
TdiAddressSize = sizeof(TRANSPORT_ADDRESS) + *NameLength;
OUT LPINT lpErrno)
{
PSOCKET_INFORMATION Socket = NULL;
+ BOOLEAN NeedsCompletion;
/* Get the Socket Structure associate to this Socket*/
Socket = GetSocketStructure(Handle);
+ if (!Socket)
+ {
+ *lpErrno = WSAENOTSOCK;
+ return SOCKET_ERROR;
+ }
+
+ *lpcbBytesReturned = 0;
switch( dwIoControlCode )
{
return SOCKET_ERROR;
}
Socket->SharedData.NonBlocking = *((PULONG)lpvInBuffer) ? 1 : 0;
- return SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+ *lpErrno = SetSocketInformation(Socket, AFD_INFO_BLOCKING_MODE, (PULONG)lpvInBuffer, NULL);
+ if (*lpErrno != NO_ERROR)
+ return SOCKET_ERROR;
+ else
+ return NO_ERROR;
case FIONREAD:
if( cbOutBuffer < sizeof(INT) || IS_INTRESOURCE(lpvOutBuffer) )
{
*lpErrno = WSAEFAULT;
return SOCKET_ERROR;
}
- return GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
+ *lpErrno = GetSocketInformation(Socket, AFD_INFO_RECEIVE_CONTENT_SIZE, (PULONG)lpvOutBuffer, NULL);
+ if (*lpErrno != NO_ERROR)
+ return SOCKET_ERROR;
+ else
+ {
+ *lpcbBytesReturned = sizeof(ULONG);
+ return NO_ERROR;
+ }
default:
- *lpErrno = WSAEINVAL;
- return SOCKET_ERROR;
+ *lpErrno = Socket->HelperData->WSHIoctl(Socket->HelperContext,
+ Handle,
+ Socket->TdiAddressHandle,
+ Socket->TdiConnectionHandle,
+ dwIoControlCode,
+ lpvInBuffer,
+ cbInBuffer,
+ lpvOutBuffer,
+ cbOutBuffer,
+ lpcbBytesReturned,
+ lpOverlapped,
+ lpCompletionRoutine,
+ (LPBOOL)&NeedsCompletion);
+
+ if (*lpErrno != NO_ERROR)
+ return SOCKET_ERROR;
+ else
+ return NO_ERROR;
}
}
}
- /* FIXME: We should handle some cases here */
+ /* FIXME: We should handle some more cases here */
+ if (level == SOL_SOCKET)
+ {
+ switch (optname)
+ {
+ case SO_BROADCAST:
+ Socket->SharedData.Broadcast = (*optval != 0) ? 1 : 0;
+ return 0;
+ }
+ }
*lpErrno = Socket->HelperData->WSHSetSocketInformation(Socket->HelperContext,
PSOCKET_INFORMATION
GetSocketStructure(SOCKET Handle)
{
- ULONG i;
+ PSOCKET_INFORMATION CurrentSocket;
+
+ EnterCriticalSection(&SocketListLock);
- for (i=0; i<SocketCount; i++)
+ CurrentSocket = SocketListHead;
+ while (CurrentSocket)
{
- if (Sockets[i]->Handle == Handle)
+ if (CurrentSocket->Handle == Handle)
{
- return Sockets[i];
+ LeaveCriticalSection(&SocketListLock);
+ return CurrentSocket;
}
+
+ CurrentSocket = CurrentSocket->NextSocket;
}
- return 0;
+
+ LeaveCriticalSection(&SocketListLock);
+
+ return NULL;
}
int CreateContext(PSOCKET_INFORMATION Socket)
/* Heap to use when allocating */
GlobalHeap = GetProcessHeap();
- /* Allocate Heap for 1024 Sockets, can be expanded later */
- Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
- if (!Sockets) return FALSE;
+ /* Initialize the lock that protects our socket list */
+ InitializeCriticalSection(&SocketListLock);
AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
break;
case DLL_PROCESS_DETACH:
+
+ /* Delete the socket list lock */
+ DeleteCriticalSection(&SocketListLock);
+
break;
}