- Merge some small changes from aicom-network-branch to fix potential memory corrupt...
[reactos.git] / reactos / dll / win32 / wshtcpip / wshtcpip.c
index 91c7fba..65d7cb0 100644 (file)
@@ -156,9 +156,39 @@ WSHGetSockaddrType(
     return NO_ERROR;
 }
 
-
-
-
+UINT
+GetAddressOption(INT Level, INT OptionName)
+{
+    switch (Level)
+    {
+       case IPPROTO_IP:
+          switch (OptionName)
+          {
+             case IP_TTL:
+                return AO_OPTION_TTL;
+
+             case IP_DONTFRAGMENT:
+                return AO_OPTION_IP_DONTFRAGMENT;
+
+#if 0
+             case IP_RECEIVE_BROADCAST:
+                return AO_OPTION_BROADCAST;
+#endif
+
+             case IP_HDRINCL:
+                return AO_OPTION_IP_HDRINCL;
+
+             default:
+                DPRINT1("Unknown option name for IPPROTO_IP: %d\n", OptionName);
+                return 0;
+          }
+          break;
+
+       default:
+          DPRINT1("Unknown level: %d\n", Level);
+          return 0;
+    }
+}
 
 INT
 EXPORT
@@ -170,7 +200,7 @@ WSHGetSocketInformation(
     IN  INT Level,
     IN  INT OptionName,
     OUT PCHAR OptionValue,
-    OUT INT OptionLength)
+    OUT LPINT OptionLength)
 {
     UNIMPLEMENTED
 
@@ -309,6 +339,37 @@ WSHJoinLeaf(
     return NO_ERROR;
 }
 
+INT
+SendRequest(
+    IN PVOID Request,
+    IN DWORD RequestSize,
+    IN DWORD IOCTL)
+{
+    BOOLEAN Status;
+    HANDLE TcpCC;
+    DWORD BytesReturned;
+
+    if (openTcpFile(&TcpCC) != STATUS_SUCCESS)
+        return WSAEINVAL;
+
+    Status = DeviceIoControl(TcpCC,
+                             IOCTL,
+                             Request,
+                             RequestSize,
+                             NULL,
+                             0,
+                             &BytesReturned,
+                             NULL);
+
+    closeTcpFile(TcpCC);
+
+    DPRINT("DeviceIoControl: %d\n", ((Status == TRUE) ? 0 : GetLastError()));
+
+    if (!Status)
+        return WSAEINVAL;
+
+    return NO_ERROR;
+}
 
 INT
 EXPORT
@@ -319,15 +380,86 @@ WSHNotify(
     IN  HANDLE TdiConnectionObjectHandle,
     IN  DWORD NotifyEvent)
 {
+    PSOCKET_CONTEXT Context = HelperDllSocketContext;
+    NTSTATUS Status;
+    HANDLE TcpCC;
+    TDIEntityID *EntityIDs;
+    DWORD EntityCount, i;
+    PQUEUED_REQUEST QueuedRequest, NextQueuedRequest;
+
     switch (NotifyEvent)
     {
         case WSH_NOTIFY_CLOSE:
-        HeapFree(GetProcessHeap(), 0, HelperDllSocketContext);
-        break;
+            DPRINT("WSHNotify: WSH_NOTIFY_CLOSE\n");
+            QueuedRequest = Context->RequestQueue;
+            while (QueuedRequest)
+            {
+                NextQueuedRequest = QueuedRequest->Next;
+
+                HeapFree(GetProcessHeap(), 0, QueuedRequest->Info);
+                HeapFree(GetProcessHeap(), 0, QueuedRequest);
+
+                QueuedRequest = NextQueuedRequest;
+            }
+            HeapFree(GetProcessHeap(), 0, HelperDllSocketContext);
+            break;
+
+
+        case WSH_NOTIFY_BIND:
+            DPRINT("WSHNotify: WSH_NOTIFY_BIND\n");
+            Status = openTcpFile(&TcpCC);
+            if (Status != STATUS_SUCCESS)
+                return WSAEINVAL;
+
+            Status = tdiGetEntityIDSet(TcpCC,
+                                       &EntityIDs,
+                                       &EntityCount);
 
+            closeTcpFile(TcpCC);
+
+            if (Status != STATUS_SUCCESS)
+                return WSAEINVAL;
+
+            for (i = 0; i < EntityCount; i++)
+            {
+                if (EntityIDs[i].tei_entity == CO_TL_ENTITY ||
+                    EntityIDs[i].tei_entity == CL_TL_ENTITY ||
+                    EntityIDs[i].tei_entity == ER_ENTITY)
+                {
+                    Context->AddrFileInstance = EntityIDs[i].tei_instance;
+                    Context->AddrFileEntityType = EntityIDs[i].tei_entity;
+                }
+            }
+
+            DPRINT("Instance: %x Type: %x\n", Context->AddrFileInstance, Context->AddrFileEntityType);
+
+            tdiFreeThingSet(EntityIDs);
+
+            Context->SocketState = SocketStateBound;
+
+            QueuedRequest = Context->RequestQueue;
+            while (QueuedRequest)
+            {
+                QueuedRequest->Info->ID.toi_entity.tei_entity = Context->AddrFileEntityType;
+                QueuedRequest->Info->ID.toi_entity.tei_instance = Context->AddrFileInstance;
+
+                SendRequest(QueuedRequest->Info,
+                            sizeof(*QueuedRequest->Info) + QueuedRequest->Info->BufferSize,
+                            IOCTL_TCP_SET_INFORMATION_EX);
+
+                NextQueuedRequest = QueuedRequest->Next;
+
+                HeapFree(GetProcessHeap(), 0, QueuedRequest->Info);
+                HeapFree(GetProcessHeap(), 0, QueuedRequest);
+
+                QueuedRequest = NextQueuedRequest;
+            }
+            Context->RequestQueue = NULL;
+            break;
+            
         default:
-        DPRINT1("Unwanted notification received! (%d)\n", NotifyEvent);
-        break;
+            DPRINT1("Unwanted notification received! (%d)\n", NotifyEvent);
+            break;
     }
 
     return NO_ERROR;
@@ -450,7 +582,7 @@ WSHOpenSocket2(
 
     /* Setup a socket context area */
 
-    Context = HeapAlloc(GetProcessHeap(), 0, sizeof(SOCKET_CONTEXT));
+    Context = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(SOCKET_CONTEXT));
     if (!Context) {
         RtlFreeUnicodeString(TransportDeviceName);
         return WSAENOBUFS;
@@ -460,14 +592,14 @@ WSHOpenSocket2(
     Context->SocketType    = *SocketType;
     Context->Protocol      = *Protocol;
     Context->Flags         = Flags;
+    Context->SocketState   = SocketStateCreated;
 
     *HelperDllSocketContext = Context;
-    *NotificationEvents = WSH_NOTIFY_CLOSE;
+    *NotificationEvents = WSH_NOTIFY_CLOSE | WSH_NOTIFY_BIND;
 
     return NO_ERROR;
 }
 
-
 INT
 EXPORT
 WSHSetSocketInformation(
@@ -480,9 +612,76 @@ WSHSetSocketInformation(
     IN  PCHAR OptionValue,
     IN  INT OptionLength)
 {
-    UNIMPLEMENTED
+    PSOCKET_CONTEXT Context = HelperDllSocketContext;
+    UINT RealOptionName;
+    INT Status;
+    PTCP_REQUEST_SET_INFORMATION_EX Info;
+    PQUEUED_REQUEST Queued, NextQueued;
 
-    return NO_ERROR;
+    DPRINT("WSHSetSocketInformation\n");
+
+    /* FIXME: We only handle address file object here */
+
+    RealOptionName = GetAddressOption(Level, OptionName);
+
+    /* FIXME: Support all options */
+    if (!RealOptionName)
+        return 0; /* return WSAEINVAL; */
+
+    Info = HeapAlloc(GetProcessHeap(), 0, sizeof(*Info) + OptionLength);
+    if (!Info)
+        return WSAENOBUFS;
+
+    Info->ID.toi_entity.tei_entity = Context->AddrFileEntityType;
+    Info->ID.toi_entity.tei_instance = Context->AddrFileInstance;
+    Info->ID.toi_class = INFO_CLASS_PROTOCOL;
+    Info->ID.toi_type = INFO_TYPE_ADDRESS_OBJECT;
+    Info->ID.toi_id = RealOptionName;
+    Info->BufferSize = OptionLength;
+    memcpy(Info->Buffer, OptionValue, OptionLength);
+
+    if (Context->SocketState == SocketStateCreated)
+    {
+        if (Context->RequestQueue)
+        {
+            Queued = Context->RequestQueue;
+            while ((NextQueued = Queued->Next))
+            {
+               Queued = NextQueued;
+            }
+
+            Queued->Next = HeapAlloc(GetProcessHeap(), 0, sizeof(QUEUED_REQUEST));
+            if (!Queued->Next)
+            {
+                HeapFree(GetProcessHeap(), 0, Info);
+                return WSAENOBUFS;
+            }
+
+            NextQueued = Queued->Next;
+            NextQueued->Next = NULL;
+            NextQueued->Info = Info;
+        }
+        else
+        {
+            Context->RequestQueue = HeapAlloc(GetProcessHeap(), 0, sizeof(QUEUED_REQUEST));
+            if (!Context->RequestQueue)
+            {
+                HeapFree(GetProcessHeap(), 0, Info);
+                return WSAENOBUFS;
+            }
+
+            Context->RequestQueue->Next = NULL;
+            Context->RequestQueue->Info = Info;
+        }
+
+        return 0;
+    }
+
+    Status = SendRequest(Info, sizeof(*Info) + Info->BufferSize, IOCTL_TCP_SET_INFORMATION_EX);
+
+    HeapFree(GetProcessHeap(), 0, Info);
+
+    return Status;
 }