More work on winsock stack (ping is now working)
[reactos.git] / reactos / lib / msafd / misc / dllmain.c
index c037a9f..5b206c2 100644 (file)
@@ -14,6 +14,7 @@
 
 /* See debug.h for debug/trace constants */
 DWORD DebugTraceLevel = MIN_TRACE;
+//DWORD DebugTraceLevel = DEBUG_ULTRA;
 
 #endif /* DBG */
 
@@ -26,15 +27,17 @@ WSPUPCALLTABLE Upcalls;
 LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
 CRITICAL_SECTION InitCriticalSection;
 DWORD StartupCount = 0;
+HANDLE CommandChannel;
+
 
 NTSTATUS OpenSocket(
-    SOCKET *Socket,
-    INT AddressFamily,
-    INT SocketType,
-    INT Protocol,
-    PVOID HelperContext,
-    DWORD NotificationEvents,
-    PUNICODE_STRING TdiDeviceName)
+  SOCKET *Socket,
+  INT AddressFamily,
+  INT SocketType,
+  INT Protocol,
+  PVOID HelperContext,
+  DWORD NotificationEvents,
+  PUNICODE_STRING TdiDeviceName)
 /*
  * FUNCTION: Opens a socket
  * ARGUMENTS:
@@ -49,97 +52,102 @@ NTSTATUS OpenSocket(
  *     Status of operation
  */
 {
-    OBJECT_ATTRIBUTES ObjectAttributes;
-    PAFD_SOCKET_INFORMATION SocketInfo;
-    PFILE_FULL_EA_INFORMATION EaInfo;
-    UNICODE_STRING DeviceName;
-    IO_STATUS_BLOCK Iosb;
-    HANDLE FileHandle;
-    NTSTATUS Status;
-    ULONG EaLength;
-    ULONG EaShort;
-
-    AFD_DbgPrint(MAX_TRACE, ("Called.\n"));
+  OBJECT_ATTRIBUTES ObjectAttributes;
+  PAFD_SOCKET_INFORMATION SocketInfo;
+  PFILE_FULL_EA_INFORMATION EaInfo;
+  UNICODE_STRING DeviceName;
+  IO_STATUS_BLOCK Iosb;
+  HANDLE FileHandle;
+  NTSTATUS Status;
+  ULONG EaLength;
+  ULONG EaShort;
 
-    EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
-        AFD_SOCKET_LENGTH +
-        sizeof(AFD_SOCKET_INFORMATION);
+  AFD_DbgPrint(MAX_TRACE, ("Socket (0x%X)  TdiDeviceName (%wZ)\n",
+    Socket, TdiDeviceName));
 
-    EaLength = EaShort + TdiDeviceName->Length + sizeof(WCHAR);
+  AFD_DbgPrint(MAX_TRACE, ("Socket2 (0x%X)  TdiDeviceName (%S)\n",
+    Socket, TdiDeviceName->Buffer));
 
-    EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
-    if (!EaInfo) {
-        AFD_DbgPrint(MIN_TRACE, ("Insufficient resources.\n"));
-        return STATUS_INSUFFICIENT_RESOURCES;
-    }
+  EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
+    AFD_SOCKET_LENGTH +
+    sizeof(AFD_SOCKET_INFORMATION);
 
-    RtlZeroMemory(EaInfo, EaLength);
-    EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
-    RtlCopyMemory(EaInfo->EaName,
-                  AfdSocket,
-                  AFD_SOCKET_LENGTH);
-    EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
-
-    SocketInfo = (PAFD_SOCKET_INFORMATION)(EaInfo->EaName + AFD_SOCKET_LENGTH);
-
-    SocketInfo->AddressFamily      = AddressFamily;
-    SocketInfo->SocketType         = SocketType;
-    SocketInfo->Protocol           = Protocol;
-    SocketInfo->HelperContext      = HelperContext;
-    SocketInfo->NotificationEvents = NotificationEvents;
-    /* Zeroed above so initialized to a wildcard address if a raw socket */
-    SocketInfo->Name.sa_family     = AddressFamily;
-
-    /* Store TDI device name last in buffer */
-    SocketInfo->TdiDeviceName.Buffer = (PWCHAR)(EaInfo + EaShort);
-    SocketInfo->TdiDeviceName.MaximumLength = TdiDeviceName->Length + sizeof(WCHAR);
-    RtlCopyUnicodeString(&SocketInfo->TdiDeviceName, TdiDeviceName);
-
-    AFD_DbgPrint(MAX_TRACE, ("EaInfo at (0x%X)  EaLength is (%d).\n", (UINT)EaInfo, (INT)EaLength));
-
-
-    RtlInitUnicodeString(&DeviceName, L"\\Device\\Afd");
-         InitializeObjectAttributes(&ObjectAttributes,
-        &DeviceName,
-        0,
-        NULL,
-        NULL);
-
-    Status = NtCreateFile(&FileHandle,
-        FILE_GENERIC_READ | FILE_GENERIC_WRITE,
-        &ObjectAttributes,
-        &Iosb,
-        NULL,
-                   0,
-                   0,
-                   FILE_OPEN,
-                   FILE_SYNCHRONOUS_IO_ALERT,
-        EaInfo,
-        EaLength);
-
-    HeapFree(GlobalHeap, 0, EaInfo);
-
-    if (!NT_SUCCESS(Status)) {
-                 AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n", (UINT)Status));
-      return STATUS_INSUFFICIENT_RESOURCES;
-    }
+  EaLength = EaShort + TdiDeviceName->Length + sizeof(WCHAR);
+
+  EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
+  if (!EaInfo) {
+    return STATUS_INSUFFICIENT_RESOURCES;
+  }
 
-    *Socket = (SOCKET)FileHandle;
+  RtlZeroMemory(EaInfo, EaLength);
+  EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
+  RtlCopyMemory(EaInfo->EaName,
+    AfdSocket,
+    AFD_SOCKET_LENGTH);
+  EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
+
+  SocketInfo = (PAFD_SOCKET_INFORMATION)(EaInfo->EaName + AFD_SOCKET_LENGTH);
+  SocketInfo->CommandChannel     = FALSE;
+  SocketInfo->AddressFamily      = AddressFamily;
+  SocketInfo->SocketType         = SocketType;
+  SocketInfo->Protocol           = Protocol;
+  SocketInfo->HelperContext      = HelperContext;
+  SocketInfo->NotificationEvents = NotificationEvents;
+  /* Zeroed above so initialized to a wildcard address if a raw socket */
+  SocketInfo->Name.sa_family     = AddressFamily;
+
+  /* Store TDI device name last in buffer */
+  SocketInfo->TdiDeviceName.Buffer = (PWCHAR)(EaInfo + EaShort);
+  SocketInfo->TdiDeviceName.MaximumLength = TdiDeviceName->Length + sizeof(WCHAR);
+  RtlCopyUnicodeString(&SocketInfo->TdiDeviceName, TdiDeviceName);
+
+  AFD_DbgPrint(MAX_TRACE, ("EaInfo at (0x%X)  EaLength is (%d).\n", (UINT)EaInfo, (INT)EaLength));
+
+  RtlInitUnicodeString(&DeviceName, L"\\Device\\Afd");
+       InitializeObjectAttributes(
+    &ObjectAttributes,
+    &DeviceName,
+    0,
+    NULL,
+    NULL);
 
-    return STATUS_SUCCESS;
+  Status = NtCreateFile(
+    &FileHandle,
+    FILE_GENERIC_READ | FILE_GENERIC_WRITE,
+    &ObjectAttributes,
+    &Iosb,
+    NULL,
+               0,
+               0,
+               FILE_OPEN,
+               FILE_SYNCHRONOUS_IO_ALERT,
+    EaInfo,
+    EaLength);
+
+  HeapFree(GlobalHeap, 0, EaInfo);
+
+  if (!NT_SUCCESS(Status)) {
+    AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
+      (UINT)Status));
+    return STATUS_INSUFFICIENT_RESOURCES;
+  }
+
+  *Socket = (SOCKET)FileHandle;
+
+  return STATUS_SUCCESS;
 }
 
 
 SOCKET
 WSPAPI
 WSPSocket(
-    IN  INT af,
-    IN  INT type,
-    IN  INT protocol,
-    IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
-    IN  GROUP g,
-    IN  DWORD dwFlags,
-    OUT LPINT lpErrno)
+  IN  INT af,
+  IN  INT type,
+  IN  INT protocol,
+  IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
+  IN  GROUP g,
+  IN  DWORD dwFlags,
+  OUT LPINT lpErrno)
 /*
  * FUNCTION: Creates a new socket
  * ARGUMENTS:
@@ -154,92 +162,91 @@ WSPSocket(
  *     Created socket, or INVALID_SOCKET if it could not be created
  */
 {
-    WSAPROTOCOL_INFOW ProtocolInfo;
-    UNICODE_STRING TdiDeviceName;
-    DWORD NotificationEvents;
-    PWSHELPER_DLL HelperDLL;
-    PVOID HelperContext;
-    INT AddressFamily;
-    NTSTATUS NtStatus;
-    INT SocketType;
-    SOCKET Socket2;
-    SOCKET Socket;
-    INT Protocol;
-    INT Status;
-
-    AFD_DbgPrint(MAX_TRACE, ("af (%d)  type (%d)  protocol (%d).\n",
-        af, type, protocol));
-
-    if (!lpProtocolInfo) {
-        CP
-        lpProtocolInfo = &ProtocolInfo;
-        ZeroMemory(&ProtocolInfo, sizeof(WSAPROTOCOL_INFOW));
-
-        ProtocolInfo.iAddressFamily = af;
-        ProtocolInfo.iSocketType    = type;
-        ProtocolInfo.iProtocol      = protocol;
-    }
+  WSAPROTOCOL_INFOW ProtocolInfo;
+  UNICODE_STRING TdiDeviceName;
+  DWORD NotificationEvents;
+  PWSHELPER_DLL HelperDLL;
+  PVOID HelperContext;
+  INT AddressFamily;
+  NTSTATUS NtStatus;
+  INT SocketType;
+  SOCKET Socket2;
+  SOCKET Socket;
+  INT Protocol;
+  INT Status;
+
+  AFD_DbgPrint(MAX_TRACE, ("af (%d)  type (%d)  protocol (%d).\n",
+    af, type, protocol));
+
+  if (!lpProtocolInfo) {
+    lpProtocolInfo = &ProtocolInfo;
+    ZeroMemory(&ProtocolInfo, sizeof(WSAPROTOCOL_INFOW));
+
+    ProtocolInfo.iAddressFamily = af;
+    ProtocolInfo.iSocketType    = type;
+    ProtocolInfo.iProtocol      = protocol;
+  }
 
-    HelperDLL = LocateHelperDLL(lpProtocolInfo);
-    if (!HelperDLL) {
-        *lpErrno = WSAEAFNOSUPPORT;
-        return INVALID_SOCKET;
-    }
+  HelperDLL = LocateHelperDLL(lpProtocolInfo);
+  if (!HelperDLL) {
+    *lpErrno = WSAEAFNOSUPPORT;
+    return INVALID_SOCKET;
+  }
 
-    AddressFamily = lpProtocolInfo->iAddressFamily;
-    SocketType    = lpProtocolInfo->iSocketType;
-    Protocol      = lpProtocolInfo->iProtocol;
-
-    Status = HelperDLL->EntryTable.lpWSHOpenSocket2(&AddressFamily,
-        &SocketType,
-        &Protocol,
-        0,
-        0,
-        &TdiDeviceName,
-        &HelperContext,
-        &NotificationEvents);
-    if (Status != NO_ERROR) {
-        *lpErrno = Status;
-        return INVALID_SOCKET;
-    }
+  AddressFamily = lpProtocolInfo->iAddressFamily;
+  SocketType    = lpProtocolInfo->iSocketType;
+  Protocol      = lpProtocolInfo->iProtocol;
+
+  Status = HelperDLL->EntryTable.lpWSHOpenSocket2(
+    &AddressFamily,
+    &SocketType,
+    &Protocol,
+    0,
+    0,
+    &TdiDeviceName,
+    &HelperContext,
+    &NotificationEvents);
+  if (Status != NO_ERROR) {
+    *lpErrno = Status;
+    return INVALID_SOCKET;
+  }
 
-    NtStatus = OpenSocket(&Socket,
-        AddressFamily,
-        SocketType,
-        Protocol,
-        HelperContext,
-        NotificationEvents,
-        &TdiDeviceName);
-
-    RtlFreeUnicodeString(&TdiDeviceName);
-    if (!NT_SUCCESS(NtStatus)) {
-        CP
-        *lpErrno = RtlNtStatusToDosError(Status);
-        return INVALID_SOCKET;
-    }
+  NtStatus = OpenSocket(&Socket,
+    AddressFamily,
+    SocketType,
+    Protocol,
+    HelperContext,
+    NotificationEvents,
+    &TdiDeviceName);
+
+  RtlFreeUnicodeString(&TdiDeviceName);
+  if (!NT_SUCCESS(NtStatus)) {
+    *lpErrno = RtlNtStatusToDosError(Status);
+    return INVALID_SOCKET;
+  }
 
-    /* FIXME: Assumes catalog entry id to be 1 */
-    Socket2 = Upcalls.lpWPUModifyIFSHandle(1, Socket, lpErrno);
+  /* FIXME: Assumes catalog entry id to be 1 */
+  Socket2 = Upcalls.lpWPUModifyIFSHandle(1, Socket, lpErrno);
 
-    if (Socket2 == INVALID_SOCKET) {
-        /* FIXME: Cleanup */
-        AFD_DbgPrint(MIN_TRACE, ("FIXME: Cleanup.\n"));
-        return INVALID_SOCKET;
-    }
+  if (Socket2 == INVALID_SOCKET) {
+    /* FIXME: Cleanup */
+    AFD_DbgPrint(MIN_TRACE, ("FIXME: Cleanup.\n"));
+    return INVALID_SOCKET;
+  }
 
-    *lpErrno = NO_ERROR;
+  *lpErrno = NO_ERROR;
 
-    AFD_DbgPrint(MID_TRACE, ("Returning socket descriptor (0x%X).\n", Socket2));
+  AFD_DbgPrint(MID_TRACE, ("Returning socket descriptor (0x%X).\n", Socket2));
 
-    return Socket2;
+  return Socket2;
 }
 
 
 INT
 WSPAPI
 WSPCloseSocket(
-    IN SOCKET s,
-    OUT        LPINT lpErrno)
+  IN  SOCKET s,
+  OUT  LPINT lpErrno)
 /*
  * FUNCTION: Closes an open socket
  * ARGUMENTS:
@@ -249,29 +256,29 @@ WSPCloseSocket(
  *     NO_ERROR, or SOCKET_ERROR if the socket could not be closed
  */
 {
-    NTSTATUS Status;
+  NTSTATUS Status;
 
-    AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
+  AFD_DbgPrint(MAX_TRACE, ("s (0x%X).\n", s));
 
-    Status = NtClose((HANDLE)s);
+  Status = NtClose((HANDLE)s);
 
-    if (NT_SUCCESS(Status)) {
-        *lpErrno = NO_ERROR;
-        return NO_ERROR;
-    }
+  if (NT_SUCCESS(Status)) {
+    *lpErrno = NO_ERROR;
+    return NO_ERROR;
+  }
 
-    *lpErrno = WSAENOTSOCK;
-    return SOCKET_ERROR;
+  *lpErrno = WSAENOTSOCK;
+  return SOCKET_ERROR;
 }
 
 
 INT
 WSPAPI
 WSPBind(
-    IN  SOCKET s,
-    IN  CONST LPSOCKADDR name, 
-    IN  INT namelen, 
-    OUT LPINT lpErrno)
+  IN  SOCKET s,
+  IN  CONST LPSOCKADDR name, 
+  IN  INT namelen, 
+  OUT LPINT lpErrno)
 /*
  * FUNCTION: Associates a local address with a socket
  * ARGUMENTS:
@@ -293,7 +300,8 @@ WSPBind(
 
   RtlCopyMemory(&Request.Name, name, sizeof(SOCKADDR));
 
-  Status = NtDeviceIoControlFile((HANDLE)s,
+  Status = NtDeviceIoControlFile(
+    (HANDLE)s,
     NULL,
                NULL,
                NULL,
@@ -324,12 +332,12 @@ WSPBind(
 INT
 WSPAPI
 WSPSelect(
-    IN      INT nfds,
-    IN OUT  LPFD_SET readfds,
-    IN OUT  LPFD_SET writefds,
-    IN OUT  LPFD_SET exceptfds,
-    IN      CONST LPTIMEVAL timeout,
-    OUT     LPINT lpErrno)
+  IN      INT nfds,
+  IN OUT  LPFD_SET readfds,
+  IN OUT  LPFD_SET writefds,
+  IN OUT  LPFD_SET exceptfds,
+  IN      CONST LPTIMEVAL timeout,
+  OUT     LPINT lpErrno)
 /*
  * FUNCTION: Returns status of one or more sockets
  * ARGUMENTS:
@@ -344,37 +352,221 @@ WSPSelect(
  *     Number of ready socket descriptors, or SOCKET_ERROR if an error ocurred
  */
 {
-    AFD_DbgPrint(MAX_TRACE, ("readfds (0x%X)  writefds (0x%X)  exceptfds (0x%X).\n",
+  PFILE_REQUEST_SELECT Request;
+  FILE_REPLY_SELECT Reply;
+  IO_STATUS_BLOCK Iosb;
+  NTSTATUS Status;
+  DWORD Size;
+  DWORD ReadSize;
+  DWORD WriteSize;
+  DWORD ExceptSize;
+  PVOID Current;
+
+  AFD_DbgPrint(MAX_TRACE, ("readfds (0x%X)  writefds (0x%X)  exceptfds (0x%X).\n",
         readfds, writefds, exceptfds));
+#if 0
+  /* FIXME: For now, all reads are timed out immediately */
+  if (readfds != NULL) {
+    AFD_DbgPrint(MID_TRACE, ("Timing out read query.\n"));
+    *lpErrno = WSAETIMEDOUT;
+    return SOCKET_ERROR;
+  }
 
-    /* FIXME: For now, all reads are timed out immediately */
-    if (readfds != NULL) {
-        AFD_DbgPrint(MIN_TRACE, ("Timing out read query.\n"));
-        *lpErrno = WSAETIMEDOUT;
-        return SOCKET_ERROR;
-    }
+  /* FIXME: For now, always allow write */
+  if (writefds != NULL) {
+    AFD_DbgPrint(MID_TRACE, ("Setting one socket writeable.\n"));
+    *lpErrno = NO_ERROR;
+    return 1;
+  }
+#endif
 
-    /* FIXME: For now, always allow write */
-    if (writefds != NULL) {
-        AFD_DbgPrint(MIN_TRACE, ("Setting one socket writeable.\n"));
-        *lpErrno = NO_ERROR;
-        return 1;
-    }
+  ReadSize = 0;
+  if ((readfds != NULL) && (readfds->fd_count > 0)) {
+    ReadSize = (readfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
+  }
 
-    *lpErrno = NO_ERROR;
+  WriteSize = 0;
+  if ((writefds != NULL) && (writefds->fd_count > 0)) {
+    WriteSize = (writefds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
+  }
 
-    return 0;
+  ExceptSize = 0;
+  if ((exceptfds != NULL) && (exceptfds->fd_count > 0)) {
+    ExceptSize = (exceptfds->fd_count * sizeof(SOCKET)) + sizeof(UINT);
+  }
+
+  Size = ReadSize + WriteSize + ExceptSize;
+
+  Request = (PFILE_REQUEST_SELECT)HeapAlloc(
+    GlobalHeap, 0, sizeof(FILE_REQUEST_SELECT) + Size);
+  if (!Request) {
+    *lpErrno = WSAENOBUFS;
+    return SOCKET_ERROR;
+  }
+
+  /* Put FD SETs after request structure */
+  Current = (Request + 1);
+
+  if (ReadSize > 0) {
+    Request->ReadFDSet = (LPFD_SET)Current;
+    Current += ReadSize;
+    RtlCopyMemory(Request->ReadFDSet, readfds, ReadSize);
+  } else {
+    Request->ReadFDSet = NULL;
+  }
+
+  if (WriteSize > 0) {
+    Request->WriteFDSet = (LPFD_SET)Current;
+    Current += WriteSize;
+    RtlCopyMemory(Request->WriteFDSet, writefds, WriteSize);
+  } else {
+    Request->WriteFDSet = NULL;
+  }
+
+  if (ExceptSize > 0) {
+    Request->ExceptFDSet = (LPFD_SET)Current;
+    RtlCopyMemory(Request->ExceptFDSet, exceptfds, ExceptSize);
+  } else {
+    Request->ExceptFDSet = NULL;
+  }
+
+  AFD_DbgPrint(MAX_TRACE, ("R1 (0x%X)  W1 (0x%X).\n", Request->ReadFDSet, Request->WriteFDSet));
+
+  Status = NtDeviceIoControlFile(
+    CommandChannel,
+    NULL,
+               NULL,
+               NULL,   
+               &Iosb,
+               IOCTL_AFD_SELECT,
+               Request,
+               sizeof(FILE_REQUEST_SELECT) + Size,
+               &Reply,
+               sizeof(FILE_REPLY_SELECT));
+
+  HeapFree(GlobalHeap, 0, Request);
+
+       if (Status == STATUS_PENDING) {
+    AFD_DbgPrint(MAX_TRACE, ("Waiting on transport.\n"));
+    /* FIXME: Wait only for blocking sockets */
+               Status = NtWaitForSingleObject(CommandChannel, FALSE, NULL);
+  }
+
+  if (!NT_SUCCESS(Status)) {
+    AFD_DbgPrint(MAX_TRACE, ("Status (0x%X).\n", Status));
+               *lpErrno = WSAENOBUFS;
+    return SOCKET_ERROR;
+       }
+
+  AFD_DbgPrint(MAX_TRACE, ("Select successful. Status (0x%X)  Count (0x%X).\n",
+    Reply.Status, Reply.SocketCount));
+
+  *lpErrno = Reply.Status;
+
+  return Reply.SocketCount;
+}
+
+
+NTSTATUS OpenCommandChannel(
+  VOID)
+/*
+ * FUNCTION: Opens a command channel to afd.sys
+ * ARGUMENTS:
+ *     None
+ * RETURNS:
+ *     Status of operation
+ */
+{
+  OBJECT_ATTRIBUTES ObjectAttributes;
+  PAFD_SOCKET_INFORMATION SocketInfo;
+  PFILE_FULL_EA_INFORMATION EaInfo;
+  UNICODE_STRING DeviceName;
+  IO_STATUS_BLOCK Iosb;
+  HANDLE FileHandle;
+  NTSTATUS Status;
+  ULONG EaLength;
+  ULONG EaShort;
+
+  AFD_DbgPrint(MAX_TRACE, ("Called\n"));
+
+    EaShort = sizeof(FILE_FULL_EA_INFORMATION) +
+    AFD_SOCKET_LENGTH +
+    sizeof(AFD_SOCKET_INFORMATION);
+
+  EaLength = EaShort;
+
+  EaInfo = (PFILE_FULL_EA_INFORMATION)HeapAlloc(GlobalHeap, 0, EaLength);
+  if (!EaInfo) {
+    return STATUS_INSUFFICIENT_RESOURCES;
+  }
+
+  RtlZeroMemory(EaInfo, EaLength);
+  EaInfo->EaNameLength = AFD_SOCKET_LENGTH;
+  RtlCopyMemory(EaInfo->EaName,
+    AfdSocket,
+    AFD_SOCKET_LENGTH);
+  EaInfo->EaValueLength = sizeof(AFD_SOCKET_INFORMATION);
+
+  SocketInfo = (PAFD_SOCKET_INFORMATION)(EaInfo->EaName + AFD_SOCKET_LENGTH);
+  SocketInfo->CommandChannel = TRUE;
+
+  RtlInitUnicodeString(&DeviceName, L"\\Device\\Afd");
+       InitializeObjectAttributes(
+    &ObjectAttributes,
+    &DeviceName,
+    0,
+    NULL,
+    NULL);
+
+  Status = NtCreateFile(
+    &FileHandle,
+    FILE_GENERIC_READ | FILE_GENERIC_WRITE,
+    &ObjectAttributes,
+    &Iosb,
+    NULL,
+               0,
+               0,
+               FILE_OPEN,
+               FILE_SYNCHRONOUS_IO_ALERT,
+    EaInfo,
+    EaLength);
+
+  if (!NT_SUCCESS(Status)) {
+    AFD_DbgPrint(MIN_TRACE, ("Error opening device (Status 0x%X).\n",
+      (UINT)Status));
+    return Status;
+  }
+
+  CommandChannel = FileHandle;
+
+  return STATUS_SUCCESS;
+}
+
+
+NTSTATUS CloseCommandChannel(
+  VOID)
+/*
+ * FUNCTION: Closes command channel to afd.sys
+ * ARGUMENTS:
+ *     None
+ * RETURNS:
+ *     Status of operation
+ */
+{
+  AFD_DbgPrint(MAX_TRACE, ("Called.\n"));
+
+  return NtClose(CommandChannel);
 }
 
 
 INT
 WSPAPI
 WSPStartup(
-    IN  WORD wVersionRequested,
-    OUT LPWSPDATA lpWSPData,
-    IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
-    IN  WSPUPCALLTABLE UpcallTable,
-    OUT LPWSPPROC_TABLE lpProcTable)
+  IN  WORD wVersionRequested,
+  OUT LPWSPDATA lpWSPData,
+  IN  LPWSAPROTOCOL_INFOW lpProtocolInfo,
+  IN  WSPUPCALLTABLE UpcallTable,
+  OUT LPWSPPROC_TABLE lpProcTable)
 /*
  * FUNCTION: Initialize service provider for a client
  * ARGUMENTS:
@@ -387,89 +579,89 @@ WSPStartup(
  *     Status of operation
  */
 {
-    HMODULE hWS2_32;
-    INT Status;
+  HMODULE hWS2_32;
+  INT Status;
 
-    AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
+  AFD_DbgPrint(MAX_TRACE, ("wVersionRequested (0x%X) \n", wVersionRequested));
 
-    //EnterCriticalSection(&InitCriticalSection);
+  EnterCriticalSection(&InitCriticalSection);
 
-    Upcalls = UpcallTable;
+  Upcalls = UpcallTable;
 
-    if (StartupCount == 0) {
-        /* First time called */
-
-        Status = WSAVERNOTSUPPORTED;
+  if (StartupCount == 0) {
+    /* First time called */
 
-        hWS2_32 = GetModuleHandle(L"ws2_32.dll");
+    Status = WSAVERNOTSUPPORTED;
 
-        if (hWS2_32 != NULL) {
-            lpWPUCompleteOverlappedRequest = (LPWPUCOMPLETEOVERLAPPEDREQUEST)
-                GetProcAddress(hWS2_32, "WPUCompleteOverlappedRequest");
-
-            if (lpWPUCompleteOverlappedRequest != NULL) {
-                Status = NO_ERROR;
-                StartupCount++;
-                CP
-            }
-        } else {
-            AFD_DbgPrint(MIN_TRACE, ("GetModuleHandle() failed for ws2_32.dll\n"));
+    Status = OpenCommandChannel();
+    if (NT_SUCCESS(Status)) {
+      hWS2_32 = GetModuleHandle(L"ws2_32.dll");
+      if (hWS2_32 != NULL) {
+        lpWPUCompleteOverlappedRequest = (LPWPUCOMPLETEOVERLAPPEDREQUEST)
+          GetProcAddress(hWS2_32, "WPUCompleteOverlappedRequest");
+        if (lpWPUCompleteOverlappedRequest != NULL) {
+          Status = NO_ERROR;
+          StartupCount++;
         }
+      } else {
+        AFD_DbgPrint(MIN_TRACE, ("GetModuleHandle() failed for ws2_32.dll\n"));
+      }
     } else {
-        Status = NO_ERROR;
-        StartupCount++;
+      AFD_DbgPrint(MIN_TRACE, ("Cannot open afd.sys\n"));
     }
+  } else {
+    Status = NO_ERROR;
+    StartupCount++;
+  }
 
-    //LeaveCriticalSection(&InitCriticalSection);
-
-    AFD_DbgPrint(MIN_TRACE, ("WSPSocket() is at 0x%X\n", WSPSocket));
-
-    if (Status == NO_ERROR) {
-        lpProcTable->lpWSPAccept = WSPAccept;
-        lpProcTable->lpWSPAddressToString = WSPAddressToString;
-        lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
-        lpProcTable->lpWSPBind = WSPBind;
-        lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
-        lpProcTable->lpWSPCleanup = WSPCleanup;
-        lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
-        lpProcTable->lpWSPConnect = WSPConnect;
-        lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
-        lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
-        lpProcTable->lpWSPEventSelect = WSPEventSelect;
-        lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
-        lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
-        lpProcTable->lpWSPGetSockName = WSPGetSockName;
-        lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
-        lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
-        lpProcTable->lpWSPIoctl = WSPIoctl;
-        lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
-        lpProcTable->lpWSPListen = WSPListen;
-        lpProcTable->lpWSPRecv = WSPRecv;
-        lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
-        lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
-        lpProcTable->lpWSPSelect = WSPSelect;
-        lpProcTable->lpWSPSend = WSPSend;
-        lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
-        lpProcTable->lpWSPSendTo = WSPSendTo;
-        lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
-        lpProcTable->lpWSPShutdown = WSPShutdown;
-        lpProcTable->lpWSPSocket = WSPSocket;
-        lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
-
-        lpWSPData->wVersion     = MAKEWORD(2, 2);
-        lpWSPData->wHighVersion = MAKEWORD(2, 2);
-    }
+  LeaveCriticalSection(&InitCriticalSection);
+
+  if (Status == NO_ERROR) {
+    lpProcTable->lpWSPAccept = WSPAccept;
+    lpProcTable->lpWSPAddressToString = WSPAddressToString;
+    lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
+    lpProcTable->lpWSPBind = WSPBind;
+    lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
+    lpProcTable->lpWSPCleanup = WSPCleanup;
+    lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
+    lpProcTable->lpWSPConnect = WSPConnect;
+    lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
+    lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
+    lpProcTable->lpWSPEventSelect = WSPEventSelect;
+    lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
+    lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
+    lpProcTable->lpWSPGetSockName = WSPGetSockName;
+    lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
+    lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
+    lpProcTable->lpWSPIoctl = WSPIoctl;
+    lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
+    lpProcTable->lpWSPListen = WSPListen;
+    lpProcTable->lpWSPRecv = WSPRecv;
+    lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
+    lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
+    lpProcTable->lpWSPSelect = WSPSelect;
+    lpProcTable->lpWSPSend = WSPSend;
+    lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
+    lpProcTable->lpWSPSendTo = WSPSendTo;
+    lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
+    lpProcTable->lpWSPShutdown = WSPShutdown;
+    lpProcTable->lpWSPSocket = WSPSocket;
+    lpProcTable->lpWSPStringToAddress = WSPStringToAddress;
+
+    lpWSPData->wVersion     = MAKEWORD(2, 2);
+    lpWSPData->wHighVersion = MAKEWORD(2, 2);
+  }
 
-    AFD_DbgPrint(MIN_TRACE, ("Status (%d).\n", Status));
+  AFD_DbgPrint(MAX_TRACE, ("Status (%d).\n", Status));
 
-    return Status;
+  return Status;
 }
 
 
 INT
 WSPAPI
 WSPCleanup(
-    OUT LPINT lpErrno)
+  OUT LPINT lpErrno)
 /*
  * FUNCTION: Cleans up service provider for a client
  * ARGUMENTS:
@@ -478,23 +670,25 @@ WSPCleanup(
  *     0 if successful, or SOCKET_ERROR if not
  */
 {
-    AFD_DbgPrint(MAX_TRACE, ("\n"));
+  AFD_DbgPrint(MAX_TRACE, ("\n"));
 
-    //EnterCriticalSection(&InitCriticalSection);
+  EnterCriticalSection(&InitCriticalSection);
 
-    if (StartupCount > 0) {
-        StartupCount--;
+  if (StartupCount > 0) {
+    StartupCount--;
 
-        if (StartupCount == 0) {
-            AFD_DbgPrint(MAX_TRACE, ("Cleaning up msafd.dll.\n"));
-        }
+    if (StartupCount == 0) {
+      AFD_DbgPrint(MAX_TRACE, ("Cleaning up msafd.dll.\n"));
+
+      CloseCommandChannel();
     }
+  }
 
-    //LeaveCriticalSection(&InitCriticalSection);
+  LeaveCriticalSection(&InitCriticalSection);
 
-    *lpErrno = NO_ERROR;
+  *lpErrno = NO_ERROR;
 
-    return 0;
+  return 0;
 }
 
 
@@ -504,7 +698,7 @@ DllMain(HANDLE hInstDll,
         ULONG dwReason,
         PVOID Reserved)
 {
-    AFD_DbgPrint(MIN_TRACE, ("DllMain of msafd.dll\n"));
+    AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll\n"));
 
     switch (dwReason) {
     case DLL_PROCESS_ATTACH:
@@ -512,14 +706,9 @@ DllMain(HANDLE hInstDll,
            so disable them to improve performance */
         DisableThreadLibraryCalls(hInstDll);
 
-        //InitializeCriticalSection(&InitCriticalSection);
+        InitializeCriticalSection(&InitCriticalSection);
 
         GlobalHeap = GetProcessHeap();
-        //GlobalHeap = HeapCreate(0, 0, 0);
-        if (!GlobalHeap) {
-            AFD_DbgPrint(MIN_TRACE, ("Insufficient memory.\n"));
-            return FALSE;
-        }
 
         CreateHelperDLLDatabase();
         break;
@@ -532,12 +721,13 @@ DllMain(HANDLE hInstDll,
 
     case DLL_PROCESS_DETACH:
         DestroyHelperDLLDatabase();
-        //HeapDestroy(GlobalHeap);
 
-        //DeleteCriticalSection(&InitCriticalSection);
+        DeleteCriticalSection(&InitCriticalSection);
         break;
     }
 
+    AFD_DbgPrint(MAX_TRACE, ("DllMain of msafd.dll (leaving)\n"));
+
     return TRUE;
 }