[IPHLPAPI] Don't leak memory
[reactos.git] / dll / win32 / iphlpapi / iphlpapi_main.c
index af8cbd1..404c1b7 100644 (file)
@@ -767,18 +767,19 @@ DWORD WINAPI GetBestRoute(DWORD dwDestAddr, DWORD dwSourceAddr, PMIB_IPFORWARDRO
 
   AllocateAndGetIpForwardTableFromStack(&table, FALSE, GetProcessHeap(), 0);
   if (table) {
-    DWORD ndx, matchedBits, matchedNdx = 0;
+    DWORD ndx, minMaskSize, matchedNdx = 0;
 
-    for (ndx = 0, matchedBits = 0; ndx < table->dwNumEntries; ndx++) {
+    for (ndx = 0, minMaskSize = 255; ndx < table->dwNumEntries; ndx++) {
       if ((dwDestAddr & table->table[ndx].dwForwardMask) ==
        (table->table[ndx].dwForwardDest & table->table[ndx].dwForwardMask)) {
-        DWORD numShifts, mask;
+        DWORD hostMaskSize;
 
-        for (numShifts = 0, mask = table->table[ndx].dwForwardMask;
-         mask && !(mask & 1); mask >>= 1, numShifts++)
-          ;
-        if (numShifts > matchedBits) {
-          matchedBits = numShifts;
+        if (!_BitScanForward(&hostMaskSize, ntohl(table->table[ndx].dwForwardMask)))
+        {
+            hostMaskSize = 32;
+        }
+        if (hostMaskSize < minMaskSize) {
+          minMaskSize = hostMaskSize;
           matchedNdx = ndx;
         }
       }
@@ -793,6 +794,8 @@ DWORD WINAPI GetBestRoute(DWORD dwDestAddr, DWORD dwSourceAddr, PMIB_IPFORWARDRO
   return ret;
 }
 
+static int TcpTableSorter(const void *a, const void *b);
+
 /******************************************************************
  *    GetExtendedTcpTable (IPHLPAPI.@)
  *
@@ -815,20 +818,237 @@ DWORD WINAPI GetBestRoute(DWORD dwDestAddr, DWORD dwSourceAddr, PMIB_IPFORWARDRO
 
 DWORD WINAPI GetExtendedTcpTable(PVOID pTcpTable, PDWORD pdwSize, BOOL bOrder, ULONG ulAf, TCP_TABLE_CLASS TableClass, ULONG Reserved)
 {
+    DWORD i, count;
        DWORD ret = NO_ERROR;
 
-  if (TableClass == TCP_TABLE_OWNER_PID_ALL) {
-    if (*pdwSize == 0) {
-      *pdwSize = sizeof(MIB_TCPTABLE_OWNER_PID);
-      return ERROR_INSUFFICIENT_BUFFER; 
-    } else {
-      ZeroMemory(pTcpTable, sizeof(MIB_TCPTABLE_OWNER_PID));
-      return NO_ERROR;
+    if (!pdwSize)
+    {
+        return ERROR_INVALID_PARAMETER;
     }
-  }
 
+    if (ulAf != AF_INET)
+    {
+        UNIMPLEMENTED;
+        return ERROR_INVALID_PARAMETER;
+    }
+
+    switch (TableClass)
+    {
+        case TCP_TABLE_BASIC_ALL:
+            ret = GetTcpTable(pTcpTable, pdwSize, bOrder);
+            break;
+
+        case TCP_TABLE_BASIC_CONNECTIONS:
+        {
+            PMIB_TCPTABLE pOurTcpTable = getTcpTable();
+            PMIB_TCPTABLE pTheirTcpTable = pTcpTable;
+
+            if (pOurTcpTable)
+            {
+                for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                {
+                    if (pOurTcpTable->table[i].State != MIB_TCP_STATE_LISTEN)
+                    {
+                        ++count;
+                    }
+                }
+
+                if (sizeof(MIB_TCPTABLE) + count * sizeof(MIB_TCPROW) > *pdwSize || !pTheirTcpTable)
+                {
+                    *pdwSize = sizeof(MIB_TCPTABLE) + count * sizeof(MIB_TCPROW);
+                    ret = ERROR_INSUFFICIENT_BUFFER;
+                }
+                else
+                {
+                    pTheirTcpTable->dwNumEntries = count;
+
+                    for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                    {
+                        if (pOurTcpTable->table[i].State != MIB_TCP_STATE_LISTEN)
+                        {
+                            memcpy(&pTheirTcpTable->table[count], &pOurTcpTable->table[i], sizeof(MIB_TCPROW));
+                            ++count;
+                        }
+                    }
+                    ASSERT(count == pTheirTcpTable->dwNumEntries);
+
+                    if (bOrder)
+                        qsort(pTheirTcpTable->table, pTheirTcpTable->dwNumEntries,
+                              sizeof(MIB_TCPROW), TcpTableSorter);
+                }
+
+                free(pOurTcpTable);
+            }
+        }
+        break;
+
+        case TCP_TABLE_BASIC_LISTENER:
+        {
+            PMIB_TCPTABLE pOurTcpTable = getTcpTable();
+            PMIB_TCPTABLE pTheirTcpTable = pTcpTable;
+
+            if (pOurTcpTable)
+            {
+                for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                {
+                    if (pOurTcpTable->table[i].State == MIB_TCP_STATE_LISTEN)
+                    {
+                        ++count;
+                    }
+                }
+
+                if (sizeof(MIB_TCPTABLE) + count * sizeof(MIB_TCPROW) > *pdwSize || !pTheirTcpTable)
+                {
+                    *pdwSize = sizeof(MIB_TCPTABLE) + count * sizeof(MIB_TCPROW);
+                    ret = ERROR_INSUFFICIENT_BUFFER;
+                }
+                else
+                {
+                    pTheirTcpTable->dwNumEntries = count;
+
+                    for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                    {
+                        if (pOurTcpTable->table[i].State == MIB_TCP_STATE_LISTEN)
+                        {
+                            memcpy(&pTheirTcpTable->table[count], &pOurTcpTable->table[i], sizeof(MIB_TCPROW));
+                            ++count;
+                        }
+                    }
+                    ASSERT(count == pTheirTcpTable->dwNumEntries);
+
+                    if (bOrder)
+                        qsort(pTheirTcpTable->table, pTheirTcpTable->dwNumEntries,
+                              sizeof(MIB_TCPROW), TcpTableSorter);
+                }
+
+                free(pOurTcpTable);
+            }
+        }
+        break;
+
+        case TCP_TABLE_OWNER_PID_ALL:
+        {
+            PMIB_TCPTABLE_OWNER_PID pOurTcpTable = getOwnerTcpTable();
+            PMIB_TCPTABLE_OWNER_PID pTheirTcpTable = pTcpTable;
+
+            if (pOurTcpTable)
+            {
+                if (sizeof(DWORD) + pOurTcpTable->dwNumEntries * sizeof(MIB_TCPROW_OWNER_PID) > *pdwSize || !pTheirTcpTable)
+                {
+                    *pdwSize = sizeof(DWORD) + pOurTcpTable->dwNumEntries * sizeof(MIB_TCPROW_OWNER_PID);
+                    ret = ERROR_INSUFFICIENT_BUFFER;
+                }
+                else
+                {
+                    memcpy(pTheirTcpTable, pOurTcpTable, sizeof(DWORD) + pOurTcpTable->dwNumEntries * sizeof(MIB_TCPROW_OWNER_PID));
+
+                    /* Don't sort on PID, so use basic helper */
+                    if (bOrder)
+                        qsort(pTheirTcpTable->table, pTheirTcpTable->dwNumEntries,
+                              sizeof(MIB_TCPROW_OWNER_PID), TcpTableSorter);
+                }
+
+                free(pOurTcpTable);
+            }
+        }
+        break;
+
+        case TCP_TABLE_OWNER_PID_CONNECTIONS:
+        {
+            PMIB_TCPTABLE_OWNER_PID pOurTcpTable = getOwnerTcpTable();
+            PMIB_TCPTABLE_OWNER_PID pTheirTcpTable = pTcpTable;
+
+            if (pOurTcpTable)
+            {
+                for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                {
+                    if (pOurTcpTable->table[i].dwState != MIB_TCP_STATE_LISTEN)
+                    {
+                        ++count;
+                    }
+                }
+
+                if (sizeof(DWORD) + count * sizeof(MIB_TCPROW_OWNER_PID) > *pdwSize || !pTheirTcpTable)
+                {
+                    *pdwSize = sizeof(DWORD) + count * sizeof(MIB_TCPROW_OWNER_PID);
+                    ret = ERROR_INSUFFICIENT_BUFFER;
+                }
+                else
+                {
+                    pTheirTcpTable->dwNumEntries = count;
+
+                    for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                    {
+                        if (pOurTcpTable->table[i].dwState != MIB_TCP_STATE_LISTEN)
+                        {
+                            memcpy(&pTheirTcpTable->table[count], &pOurTcpTable->table[i], sizeof(MIB_TCPROW_OWNER_PID));
+                            ++count;
+                        }
+                    }
+                    ASSERT(count == pTheirTcpTable->dwNumEntries);
+
+                    /* Don't sort on PID, so use basic helper */
+                    if (bOrder)
+                        qsort(pTheirTcpTable->table, pTheirTcpTable->dwNumEntries,
+                              sizeof(MIB_TCPROW_OWNER_PID), TcpTableSorter);
+                }
+
+                free(pOurTcpTable);
+            }
+        }
+        break;
+
+        case TCP_TABLE_OWNER_PID_LISTENER:
+        {
+            PMIB_TCPTABLE_OWNER_PID pOurTcpTable = getOwnerTcpTable();
+            PMIB_TCPTABLE_OWNER_PID pTheirTcpTable = pTcpTable;
+
+            if (pOurTcpTable)
+            {
+                for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                {
+                    if (pOurTcpTable->table[i].dwState == MIB_TCP_STATE_LISTEN)
+                    {
+                        ++count;
+                    }
+                }
+
+                if (sizeof(DWORD) + count * sizeof(MIB_TCPROW_OWNER_PID) > *pdwSize || !pTheirTcpTable)
+                {
+                    *pdwSize = sizeof(DWORD) + count * sizeof(MIB_TCPROW_OWNER_PID);
+                    ret = ERROR_INSUFFICIENT_BUFFER;
+                }
+                else
+                {
+                    pTheirTcpTable->dwNumEntries = count;
+
+                    for (i = 0, count = 0; i < pOurTcpTable->dwNumEntries; ++i)
+                    {
+                        if (pOurTcpTable->table[i].dwState == MIB_TCP_STATE_LISTEN)
+                        {
+                            memcpy(&pTheirTcpTable->table[count], &pOurTcpTable->table[i], sizeof(MIB_TCPROW_OWNER_PID));
+                            ++count;
+                        }
+                    }
+                    ASSERT(count == pTheirTcpTable->dwNumEntries);
+
+                    /* Don't sort on PID, so use basic helper */
+                    if (bOrder)
+                        qsort(pTheirTcpTable->table, pTheirTcpTable->dwNumEntries,
+                              sizeof(MIB_TCPROW_OWNER_PID), TcpTableSorter);
+                }
+
+                free(pOurTcpTable);
+            }
+        }
+        break;
+
+        default:
+            UNIMPLEMENTED;
+            ret = ERROR_INVALID_PARAMETER;
+            break;
+    }
 
-    UNIMPLEMENTED;
     return ret;        
 }