[MPR]
[reactos.git] / reactos / dll / win32 / mpr / wnet.c
index 794bf2e..37ad58f 100644 (file)
@@ -171,6 +171,7 @@ static void _tryLoadProvider(PCWSTR provider)
                     TRACE("loaded lib %p\n", hLib);
                     if (getCaps)
                     {
+                        DWORD connectCap;
                         PWNetProvider provider =
                          &providerTable->table[providerTable->numProviders];
 
@@ -204,11 +205,16 @@ static void _tryLoadProvider(PCWSTR provider)
                                 WARN("Couldn't load enumeration functions\n");
                             }
                         }
-                        provider->addConnection = MPR_GETPROC(NPAddConnection);
-                        provider->addConnection3 = MPR_GETPROC(NPAddConnection3);
-                        provider->cancelConnection = MPR_GETPROC(NPCancelConnection);
+                        connectCap = getCaps(WNNC_CONNECTION);
+                        if (connectCap & WNNC_CON_ADDCONNECTION)
+                            provider->addConnection = MPR_GETPROC(NPAddConnection);
+                        if (connectCap & WNNC_CON_ADDCONNECTION3)
+                            provider->addConnection3 = MPR_GETPROC(NPAddConnection3);
+                        if (connectCap & WNNC_CON_CANCELCONNECTION)
+                            provider->cancelConnection = MPR_GETPROC(NPCancelConnection);
 #ifdef __REACTOS__
-                        provider->getConnection = MPR_GETPROC(NPGetConnection);
+                        if (connectCap & WNNC_CON_GETCONNECTIONS)
+                            provider->getConnection = MPR_GETPROC(NPGetConnection);
 #endif
                         TRACE("NPAddConnection %p\n", provider->addConnection);
                         TRACE("NPAddConnection3 %p\n", provider->addConnection3);
@@ -247,6 +253,85 @@ static void _tryLoadProvider(PCWSTR provider)
          debugstr_w(provider));
 }
 
+#ifdef __REACTOS__
+static void _restoreSavedConnection(HKEY connection, WCHAR * local)
+{
+    NETRESOURCEW net;
+    DWORD type, prov, index, size;
+
+    net.lpProvider = NULL;
+    net.lpRemoteName = NULL;
+    net.lpLocalName = NULL;
+
+    TRACE("Restoring: %S\n", local);
+
+    size = sizeof(DWORD);
+    if (RegQueryValueExW(connection, L"ConnectionType", NULL, &type, (BYTE *)&net.dwType, &size) != ERROR_SUCCESS)
+       return;
+
+    if (type != REG_DWORD || size != sizeof(DWORD))
+        return;
+
+    if (RegQueryValueExW(connection, L"ProviderName", NULL, &type, NULL, &size) != ERROR_SUCCESS)
+        return;
+
+    if (type != REG_SZ)
+        return;
+
+    net.lpProvider = HeapAlloc(GetProcessHeap(), 0, size);
+    if (!net.lpProvider)
+        return;
+
+    if (RegQueryValueExW(connection, L"ProviderName", NULL, NULL, (BYTE *)net.lpProvider, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    size = sizeof(DWORD);
+    if (RegQueryValueExW(connection, L"ProviderType", NULL, &type, (BYTE *)&prov, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    if (type != REG_DWORD || size != sizeof(DWORD))
+        goto cleanup;
+
+    index = _findProviderIndexW(net.lpProvider);
+    if (index == BAD_PROVIDER_INDEX)
+        goto cleanup;
+
+    if (providerTable->table[index].dwNetType != prov)
+        goto cleanup;
+
+    if (RegQueryValueExW(connection, L"RemotePath", NULL, &type, NULL, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    if (type != REG_SZ)
+        goto cleanup;
+
+    net.lpRemoteName = HeapAlloc(GetProcessHeap(), 0, size);
+    if (!net.lpRemoteName)
+        goto cleanup;
+
+    if (RegQueryValueExW(connection, L"RemotePath", NULL, NULL, (BYTE *)net.lpRemoteName, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    size = strlenW(local);
+    net.lpLocalName = HeapAlloc(GetProcessHeap(), 0, size * sizeof(WCHAR) + 2 * sizeof(WCHAR));
+    if (!net.lpLocalName)
+        goto cleanup;
+
+    strcpyW(net.lpLocalName, local);
+    net.lpLocalName[size] = ':';
+    net.lpLocalName[size + 1] = 0;
+
+    TRACE("Attempting connection\n");
+
+    WNetAddConnection2W(&net, NULL, NULL, 0);
+
+cleanup:
+    HeapFree(GetProcessHeap(), 0, net.lpProvider);
+    HeapFree(GetProcessHeap(), 0, net.lpRemoteName);
+    HeapFree(GetProcessHeap(), 0, net.lpLocalName);
+}
+#endif
+
 void wnetInit(HINSTANCE hInstDll)
 {
     static const WCHAR providerOrderKey[] = { 'S','y','s','t','e','m','\\',
@@ -325,6 +410,64 @@ void wnetInit(HINSTANCE hInstDll)
         }
         RegCloseKey(hKey);
     }
+
+#ifdef __REACTOS__
+    if (providerTable)
+    {
+        HKEY user_profile;
+
+        if (RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS)
+        {
+            HKEY network;
+            WCHAR subkey[8] = {'N', 'e', 't', 'w', 'o', 'r', 'k', 0};
+
+            if (RegOpenKeyExW(user_profile, subkey, 0, KEY_READ, &network) == ERROR_SUCCESS)
+            {
+                DWORD size, max;
+
+                TRACE("Enumerating remembered connections\n");
+
+                if (RegQueryInfoKey(network, NULL, NULL, NULL, &max, &size, NULL, NULL, NULL, NULL, NULL, NULL) == ERROR_SUCCESS)
+                {
+                    WCHAR *local;
+
+                    TRACE("There are %lu connections\n", max);
+
+                    local = HeapAlloc(GetProcessHeap(), 0, (size + 1) * sizeof(WCHAR));
+                    if (local)
+                    {
+                        DWORD index;
+
+                        for (index = 0; index < max; ++index)
+                        {
+                            DWORD len = size + 1;
+                            HKEY connection;
+
+                            TRACE("Trying connection %lu\n", index);
+
+                            if (RegEnumKeyExW(network, index, local, &len, NULL, NULL, NULL, NULL) != ERROR_SUCCESS)
+                                continue;
+
+                            TRACE("It is %S\n", local);
+
+                            if (RegOpenKeyExW(network, local, 0, KEY_READ, &connection) != ERROR_SUCCESS)
+                                continue;
+
+                            _restoreSavedConnection(connection, local);
+                            RegCloseKey(connection);
+                        }
+
+                        HeapFree(GetProcessHeap(), 0, local);
+                    }
+                }
+
+                RegCloseKey(network);
+            }
+
+            RegCloseKey(user_profile);
+        }
+    }
+#endif
 }
 
 void wnetFree(void)
@@ -1257,7 +1400,7 @@ static DWORD _copyStringToEnumW(const WCHAR *source, DWORD* left, void** end)
 static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, DWORD* user_count,
                                   void* user_buffer, DWORD* user_size)
 {
-    DWORD ret, index, count, size, i, left;
+    DWORD ret, index, count, total_count, size, i, left;
     void* end;
     NETRESOURCEW* curr, * buffer;
     HANDLE* handles;
@@ -1281,6 +1424,7 @@ static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, DWORD* user_count,
     curr = user_buffer;
     end = (char *)user_buffer + size;
     count = *user_count;
+    total_count = 0;
 
     ret = WN_NO_MORE_ENTRIES;
     for (index = 0; index < providerTable->numProviders; index++)
@@ -1300,6 +1444,7 @@ static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, DWORD* user_count,
             ret = providerTable->table[index].enumResource(handles[index],
                                                            &count, buffer,
                                                            &size);
+            total_count += count;
             if (ret == WN_MORE_DATA)
                 break;
 
@@ -1334,19 +1479,22 @@ static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, DWORD* user_count,
                     ++curr;
                 }
 
-                count = *user_count - count;
+                if (*user_count != -1)
+                    count = *user_count - total_count;
+                else
+                    count = *user_count;
                 size = left;
             }
 
-            if (ret != WN_SUCCESS || count == 0)
+            if (ret != WN_SUCCESS || total_count == 0)
                 break;
         }
     }
 
-    if (count == 0)
+    if (total_count == 0)
         ret = WN_NO_MORE_ENTRIES;
 
-    *user_count = *user_count - count;
+    *user_count = total_count;
     if (ret != WN_MORE_DATA && ret != WN_NO_MORE_ENTRIES)
         ret = WN_SUCCESS;
 
@@ -1866,6 +2014,43 @@ static DWORD wnet_use_connection( struct use_connection_context *ctxt )
         }
     }
 
+#ifdef __REACTOS__
+    if (ret == WN_SUCCESS && ctxt->flags & CONNECT_UPDATE_PROFILE)
+    {
+        HKEY user_profile;
+
+        if (netres.dwType == RESOURCETYPE_PRINT)
+        {
+            FIXME("Persistent connection are not supported for printers\n");
+            return ret;
+        }
+
+        if (RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS)
+        {
+            HKEY network;
+            WCHAR subkey[10] = {'N', 'e', 't', 'w', 'o', 'r', 'k', '\\', netres.lpLocalName[0], 0};
+
+            if (RegCreateKeyExW(user_profile, subkey, 0, NULL, REG_OPTION_NON_VOLATILE, KEY_ALL_ACCESS, NULL, &network, NULL) == ERROR_SUCCESS)
+            {
+                DWORD dword_arg = RESOURCETYPE_DISK;
+                DWORD len = (strlenW(provider->name) + 1) * sizeof(WCHAR);
+
+                RegSetValueExW(network, L"ConnectionType", 0, REG_DWORD, (const BYTE *)&dword_arg, sizeof(DWORD));
+                RegSetValueExW(network, L"ProviderName", 0, REG_SZ, (const BYTE *)provider->name, len);
+                dword_arg = provider->dwNetType;
+                RegSetValueExW(network, L"ProviderType", 0, REG_DWORD, (const BYTE *)&dword_arg, sizeof(DWORD));
+                len = (strlenW(netres.lpRemoteName) + 1) * sizeof(WCHAR);
+                RegSetValueExW(network, L"RemotePath", 0, REG_SZ, (const BYTE *)netres.lpRemoteName, len);
+                len = 0;
+                RegSetValueExW(network, L"UserName", 0, REG_SZ, (const BYTE *)netres.lpRemoteName, len);
+                RegCloseKey(network);
+            }
+
+            RegCloseKey(user_profile);
+        }
+    }
+#endif
+
     return ret;
 }
 
@@ -2057,6 +2242,37 @@ DWORD WINAPI WNetCancelConnection2W( LPCWSTR lpName, DWORD dwFlags, BOOL fForce
             }
         }
     }
+#ifdef __REACTOS__
+
+    if (dwFlags & CONNECT_UPDATE_PROFILE)
+    {
+        HKEY user_profile;
+        WCHAR *coma = strchrW(lpName, ':');
+
+        if (coma && RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS)
+        {
+            WCHAR  *subkey;
+            DWORD len;
+
+            len = (ULONG_PTR)coma - (ULONG_PTR)lpName + sizeof(L"Network\\");
+            subkey = HeapAlloc(GetProcessHeap(), 0, len);
+            if (subkey)
+            {
+                strcpyW(subkey, L"Network\\");
+                memcpy(subkey + (sizeof(L"Network\\") / sizeof(WCHAR)) - 1, lpName, (ULONG_PTR)coma - (ULONG_PTR)lpName);
+                subkey[len / sizeof(WCHAR) - 1] = 0;
+
+                TRACE("Removing: %S\n", subkey);
+
+                RegDeleteKeyW(user_profile, subkey);
+                HeapFree(GetProcessHeap(), 0, subkey);
+            }
+
+            RegCloseKey(user_profile);
+        }
+    }
+
+#endif
     return ret;
 }