[IPHLPAPI] Add support for service tags in GetOwnerModuleFromTcp/UdpEntry
[reactos.git] / dll / win32 / iphlpapi / iphlpapi_main.c
index 162731c..83d53d2 100644 (file)
@@ -2347,6 +2347,84 @@ static DWORD GetOwnerModuleFromPidEntry(DWORD OwningPid, TCPIP_OWNER_MODULE_INFO
     return NO_ERROR;
 }
 
+static DWORD GetOwnerModuleFromTagEntry(DWORD OwningPid, DWORD OwningTag, TCPIP_OWNER_MODULE_INFO_CLASS Class, PVOID Buffer, PDWORD pdwSize)
+{
+    UINT Size;
+    HRESULT Res;
+    HANDLE hAdvapi32;
+    WCHAR SysDir[MAX_PATH];
+    PTCPIP_OWNER_MODULE_BASIC_INFO BasicInfo;
+    ULONG (NTAPI *_I_QueryTagInformation)(PVOID, DWORD, PVOID);
+    struct
+    {
+        DWORD ProcessId;
+        DWORD ServiceTag;
+        DWORD TagType;
+        PWSTR Buffer;
+    } ServiceQuery;
+
+    /* First, secure (avoid injections) load advapi32.dll */
+    Size = GetSystemDirectoryW(SysDir, MAX_PATH);
+    if (Size == 0)
+    {
+        return GetLastError();
+    }
+
+    Res = StringCchCatW(&SysDir[Size], MAX_PATH - Size, L"\\advapi32.dll");
+    if (FAILED(Res))
+    {
+        return Res;
+    }
+
+    hAdvapi32 = GetModuleHandleW(SysDir);
+    if (hAdvapi32 == NULL)
+    {
+        return GetLastError();
+    }
+
+    /* Now, we'll query the service associated with the tag */
+    _I_QueryTagInformation = (PVOID)GetProcAddress(hAdvapi32, "I_QueryTagInformation");
+    if (_I_QueryTagInformation == NULL)
+    {
+        return GetLastError();
+    }
+
+    /* Set tag and PID for the query */
+    ServiceQuery.ProcessId = OwningPid;
+    ServiceQuery.ServiceTag = OwningTag;
+    ServiceQuery.TagType = 0;
+    ServiceQuery.Buffer = NULL;
+
+    /* And query */
+    Res = _I_QueryTagInformation(NULL, 1, &ServiceQuery);
+    if (Res != ERROR_SUCCESS)
+    {
+        return Res;
+    }
+
+    /* Compute service name length */
+    Size = wcslen(ServiceQuery.Buffer) * sizeof(WCHAR) + sizeof(UNICODE_NULL);
+
+    /* We'll copy it twice, so make sure we have enough room */
+    if (*pdwSize < sizeof(TCPIP_OWNER_MODULE_BASIC_INFO) + 2 * Size)
+    {
+        *pdwSize = sizeof(TCPIP_OWNER_MODULE_BASIC_INFO) + 2 * Size;
+        LocalFree(ServiceQuery.Buffer);
+        return ERROR_INSUFFICIENT_BUFFER;
+    }
+
+    /* Copy back data */
+    BasicInfo = Buffer;
+    BasicInfo->pModuleName = (PVOID)((ULONG_PTR)BasicInfo + sizeof(TCPIP_OWNER_MODULE_BASIC_INFO));
+    BasicInfo->pModulePath = (PVOID)((ULONG_PTR)BasicInfo->pModuleName + Size);
+    wcscpy(BasicInfo->pModuleName, ServiceQuery.Buffer);
+    wcscpy(BasicInfo->pModulePath, ServiceQuery.Buffer);
+    *pdwSize = sizeof(TCPIP_OWNER_MODULE_BASIC_INFO) + 2 * Size;
+    LocalFree(ServiceQuery.Buffer);
+
+    return NO_ERROR;
+}
+
 /******************************************************************
  *    GetOwnerModuleFromTcpEntry (IPHLPAPI.@)
  *
@@ -2368,7 +2446,15 @@ static DWORD GetOwnerModuleFromPidEntry(DWORD OwningPid, TCPIP_OWNER_MODULE_INFO
  */
 DWORD WINAPI GetOwnerModuleFromTcpEntry( PMIB_TCPROW_OWNER_MODULE pTcpEntry, TCPIP_OWNER_MODULE_INFO_CLASS Class, PVOID Buffer, PDWORD pdwSize)
 {
-    return GetOwnerModuleFromPidEntry(pTcpEntry->dwOwningPid, Class, Buffer, pdwSize);
+    /* If we have a service tag, that's a service connection */
+    if (pTcpEntry->OwningModuleInfo[0] != 0)
+    {
+        return GetOwnerModuleFromTagEntry(pTcpEntry->dwOwningPid, (DWORD)(pTcpEntry->OwningModuleInfo[0]), Class, Buffer, pdwSize);
+    }
+    else
+    {
+        return GetOwnerModuleFromPidEntry(pTcpEntry->dwOwningPid, Class, Buffer, pdwSize);
+    }
 }
 
 /******************************************************************
@@ -2392,7 +2478,15 @@ DWORD WINAPI GetOwnerModuleFromTcpEntry( PMIB_TCPROW_OWNER_MODULE pTcpEntry, TCP
  */
 DWORD WINAPI GetOwnerModuleFromUdpEntry( PMIB_UDPROW_OWNER_MODULE pUdpEntry, TCPIP_OWNER_MODULE_INFO_CLASS Class, PVOID Buffer, PDWORD pdwSize)
 {
-    return GetOwnerModuleFromPidEntry(pUdpEntry->dwOwningPid, Class, Buffer, pdwSize);
+    /* If we have a service tag, that's a service connection */
+    if (pUdpEntry->OwningModuleInfo[0] != 0)
+    {
+        return GetOwnerModuleFromTagEntry(pUdpEntry->dwOwningPid, (DWORD)(pUdpEntry->OwningModuleInfo[0]), Class, Buffer, pdwSize);
+    }
+    else
+    {
+        return GetOwnerModuleFromPidEntry(pUdpEntry->dwOwningPid, Class, Buffer, pdwSize);
+    }
 }
 
 static void CreateNameServerListEnumNamesFunc( PWCHAR Interface, PWCHAR Server, PVOID Data)