[MSAFD]
[reactos.git] / reactos / dll / win32 / msafd / misc / dllmain.c
index a8ab0dc..9142a39 100644 (file)
@@ -23,7 +23,7 @@ HANDLE GlobalHeap;
 WSPUPCALLTABLE Upcalls;
 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
 ULONG SocketCount = 0;
-PSOCKET_INFORMATION *Sockets = NULL;
+PSOCKET_INFORMATION SocketListHead = NULL;
 LIST_ENTRY SockHelpersListHead = { NULL, NULL };
 ULONG SockAsyncThreadRefCount;
 HANDLE SockAsyncHelperAfdHandle;
@@ -292,8 +292,8 @@ WSPSocket(int AddressFamily,
                           NULL);
 
     /* Save in Process Sockets List */
-    Sockets[SocketCount] = Socket;
-    SocketCount ++;
+    Socket->NextSocket = SocketListHead;
+    SocketListHead = Socket;
 
     /* Create the Socket Context */
     CreateContext(Socket);
@@ -423,7 +423,7 @@ WSPCloseSocket(IN SOCKET Handle,
                OUT LPINT lpErrno)
 {
     IO_STATUS_BLOCK IoStatusBlock;
-    PSOCKET_INFORMATION Socket = NULL;
+    PSOCKET_INFORMATION Socket = NULL, CurrentSocket;
     NTSTATUS Status;
     HANDLE SockEvent;
     AFD_DISCONNECT_INFO DisconnectInfo;
@@ -562,6 +562,27 @@ WSPCloseSocket(IN SOCKET Handle,
     NtClose(Socket->TdiConnectionHandle);
     Socket->TdiConnectionHandle = NULL;
 
+    if (SocketListHead == Socket)
+    {
+        SocketListHead = SocketListHead->NextSocket;
+    }
+    else
+    {
+        CurrentSocket = SocketListHead;
+        while (CurrentSocket->NextSocket)
+        {
+            if (CurrentSocket->NextSocket == Socket)
+            {
+                CurrentSocket->NextSocket = CurrentSocket->NextSocket->NextSocket;
+                break;
+            }
+
+            CurrentSocket = CurrentSocket->NextSocket;
+        }
+    }
+
+    HeapFree(GlobalHeap, 0, Socket);
+
     /* Close the handle */
     NtClose((HANDLE)Handle);
     NtClose(SockEvent);
@@ -2244,16 +2265,22 @@ SetSocketInformation(PSOCKET_INFORMATION Socket,
 PSOCKET_INFORMATION
 GetSocketStructure(SOCKET Handle)
 {
-    ULONG i;
+    PSOCKET_INFORMATION CurrentSocket;
 
-    for (i=0; i<SocketCount; i++) 
+    /* This is a special case */
+    if (SocketListHead->Handle == Handle)
+        return SocketListHead;
+
+    CurrentSocket = SocketListHead;
+    while (CurrentSocket->NextSocket)
     {
-        if (Sockets[i]->Handle == Handle)
-        {
-            return Sockets[i];
-        }
+        if (CurrentSocket->Handle == Handle)
+            return CurrentSocket;
+
+        CurrentSocket = CurrentSocket->NextSocket;
     }
-    return 0;
+
+    return NULL;
 }
 
 int CreateContext(PSOCKET_INFORMATION Socket)
@@ -2771,10 +2798,6 @@ DllMain(HANDLE hInstDll,
         /* Heap to use when allocating */
         GlobalHeap = GetProcessHeap();
 
-        /* Allocate Heap for 1024 Sockets, can be expanded later */
-        Sockets = HeapAlloc(GetProcessHeap(), 0, sizeof(PSOCKET_INFORMATION) * 1024);
-        if (!Sockets) return FALSE;
-
         AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
 
         break;