From 79e0de2faf66a0c6d9fb998abae010975eaec49a Mon Sep 17 00:00:00 2001 From: Pierre Schweitzer Date: Mon, 31 Oct 2016 21:27:02 +0000 Subject: [PATCH] [MPR] Implement saved connections restoration. This is a bit hackish for now: it should only be attempted once, when the DLL is loaded on session login. Right now, it's attempt each time the DLL is loaded (ie, even when a program linking to it is started). This is to be fixed later on. So far, it brings the intended feature. Now, you can create a connection with net use, save it, and reboot: it will be restored on the session opening. CORE-11757 svn path=/trunk/; revision=73093 --- reactos/dll/win32/mpr/mpr_ros.diff | 159 ++++++++++++++++++++++++++++- reactos/dll/win32/mpr/wnet.c | 137 +++++++++++++++++++++++++ 2 files changed, 292 insertions(+), 4 deletions(-) diff --git a/reactos/dll/win32/mpr/mpr_ros.diff b/reactos/dll/win32/mpr/mpr_ros.diff index f7cdf9d1f5a..349f7acd803 100644 --- a/reactos/dll/win32/mpr/mpr_ros.diff +++ b/reactos/dll/win32/mpr/mpr_ros.diff @@ -69,7 +69,158 @@ Index: wnet.c TRACE("NPAddConnection %p\n", provider->addConnection); TRACE("NPAddConnection3 %p\n", provider->addConnection3); TRACE("NPCancelConnection %p\n", provider->cancelConnection); -@@ -1870,6 +1866,43 @@ +@@ -251,6 +247,85 @@ + debugstr_w(provider)); + } + ++#ifdef __REACTOS__ ++static void _restoreSavedConnection(HKEY connection, WCHAR * local) ++{ ++ NETRESOURCEW net; ++ DWORD type, prov, index, size; ++ ++ net.lpProvider = NULL; ++ net.lpRemoteName = NULL; ++ net.lpLocalName = NULL; ++ ++ TRACE("Restoring: %S\n", local); ++ ++ size = sizeof(DWORD); ++ if (RegQueryValueExW(connection, L"ConnectionType", NULL, &type, (BYTE *)&net.dwType, &size) != ERROR_SUCCESS) ++ return; ++ ++ if (type != REG_DWORD || size != sizeof(DWORD)) ++ return; ++ ++ if (RegQueryValueExW(connection, L"ProviderName", NULL, &type, NULL, &size) != ERROR_SUCCESS) ++ return; ++ ++ if (type != REG_SZ) ++ return; ++ ++ net.lpProvider = HeapAlloc(GetProcessHeap(), 0, size); ++ if (!net.lpProvider) ++ return; ++ ++ if (RegQueryValueExW(connection, L"ProviderName", NULL, NULL, (BYTE *)net.lpProvider, &size) != ERROR_SUCCESS) ++ goto cleanup; ++ ++ size = sizeof(DWORD); ++ if (RegQueryValueExW(connection, L"ProviderType", NULL, &type, (BYTE *)&prov, &size) != ERROR_SUCCESS) ++ goto cleanup; ++ ++ if (type != REG_DWORD || size != sizeof(DWORD)) ++ goto cleanup; ++ ++ index = _findProviderIndexW(net.lpProvider); ++ if (index == BAD_PROVIDER_INDEX) ++ goto cleanup; ++ ++ if (providerTable->table[index].dwNetType != prov) ++ goto cleanup; ++ ++ if (RegQueryValueExW(connection, L"RemotePath", NULL, &type, NULL, &size) != ERROR_SUCCESS) ++ goto cleanup; ++ ++ if (type != REG_SZ) ++ goto cleanup; ++ ++ net.lpRemoteName = HeapAlloc(GetProcessHeap(), 0, size); ++ if (!net.lpRemoteName) ++ goto cleanup; ++ ++ if (RegQueryValueExW(connection, L"RemotePath", NULL, NULL, (BYTE *)net.lpRemoteName, &size) != ERROR_SUCCESS) ++ goto cleanup; ++ ++ size = strlenW(local); ++ net.lpLocalName = HeapAlloc(GetProcessHeap(), 0, size * sizeof(WCHAR) + 2 * sizeof(WCHAR)); ++ if (!net.lpLocalName) ++ goto cleanup; ++ ++ strcpyW(net.lpLocalName, local); ++ net.lpLocalName[size] = ':'; ++ net.lpLocalName[size + 1] = 0; ++ ++ TRACE("Attempting connection\n"); ++ ++ WNetAddConnection2W(&net, NULL, NULL, 0); ++ ++cleanup: ++ HeapFree(GetProcessHeap(), 0, net.lpProvider); ++ HeapFree(GetProcessHeap(), 0, net.lpRemoteName); ++ HeapFree(GetProcessHeap(), 0, net.lpLocalName); ++} ++#endif ++ + void wnetInit(HINSTANCE hInstDll) + { + static const WCHAR providerOrderKey[] = { 'S','y','s','t','e','m','\\', +@@ -329,6 +404,64 @@ + } + RegCloseKey(hKey); + } ++ ++#ifdef __REACTOS__ ++ if (providerTable) ++ { ++ HKEY user_profile; ++ ++ if (RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS) ++ { ++ HKEY network; ++ WCHAR subkey[8] = {'N', 'e', 't', 'w', 'o', 'r', 'k', 0}; ++ ++ if (RegOpenKeyExW(user_profile, subkey, 0, KEY_READ, &network) == ERROR_SUCCESS) ++ { ++ DWORD size, max; ++ ++ TRACE("Enumerating remembered connections\n"); ++ ++ if (RegQueryInfoKey(network, NULL, NULL, NULL, &max, &size, NULL, NULL, NULL, NULL, NULL, NULL) == ERROR_SUCCESS) ++ { ++ WCHAR *local; ++ ++ TRACE("There are %lu connections\n", max); ++ ++ local = HeapAlloc(GetProcessHeap(), 0, (size + 1) * sizeof(WCHAR)); ++ if (local) ++ { ++ DWORD index; ++ ++ for (index = 0; index < max; ++index) ++ { ++ DWORD len = size + 1; ++ HKEY connection; ++ ++ TRACE("Trying connection %lu\n", index); ++ ++ if (RegEnumKeyExW(network, index, local, &len, NULL, NULL, NULL, NULL) != ERROR_SUCCESS) ++ continue; ++ ++ TRACE("It is %S\n", local); ++ ++ if (RegOpenKeyExW(network, local, 0, KEY_READ, &connection) != ERROR_SUCCESS) ++ continue; ++ ++ _restoreSavedConnection(connection, local); ++ RegCloseKey(connection); ++ } ++ ++ HeapFree(GetProcessHeap(), 0, local); ++ } ++ } ++ ++ RegCloseKey(network); ++ } ++ ++ RegCloseKey(user_profile); ++ } ++ } ++#endif + } + + void wnetFree(void) +@@ -1870,6 +2003,43 @@ } } @@ -113,7 +264,7 @@ Index: wnet.c return ret; } -@@ -2061,6 +2094,37 @@ +@@ -2061,6 +2231,37 @@ } } } @@ -151,7 +302,7 @@ Index: wnet.c return ret; } -@@ -2188,6 +2252,7 @@ +@@ -2188,6 +2389,7 @@ /* find the network connection for a given drive; helper for WNetGetConnection */ static DWORD get_drive_connection( WCHAR letter, LPWSTR remote, LPDWORD size ) { @@ -159,7 +310,7 @@ Index: wnet.c char buffer[1024]; struct mountmgr_unix_drive *data = (struct mountmgr_unix_drive *)buffer; HANDLE mgr; -@@ -2230,6 +2295,32 @@ +@@ -2230,6 +2432,32 @@ } CloseHandle( mgr ); return ret; diff --git a/reactos/dll/win32/mpr/wnet.c b/reactos/dll/win32/mpr/wnet.c index 07c0844886f..4c47ec10184 100644 --- a/reactos/dll/win32/mpr/wnet.c +++ b/reactos/dll/win32/mpr/wnet.c @@ -247,6 +247,85 @@ static void _tryLoadProvider(PCWSTR provider) debugstr_w(provider)); } +#ifdef __REACTOS__ +static void _restoreSavedConnection(HKEY connection, WCHAR * local) +{ + NETRESOURCEW net; + DWORD type, prov, index, size; + + net.lpProvider = NULL; + net.lpRemoteName = NULL; + net.lpLocalName = NULL; + + TRACE("Restoring: %S\n", local); + + size = sizeof(DWORD); + if (RegQueryValueExW(connection, L"ConnectionType", NULL, &type, (BYTE *)&net.dwType, &size) != ERROR_SUCCESS) + return; + + if (type != REG_DWORD || size != sizeof(DWORD)) + return; + + if (RegQueryValueExW(connection, L"ProviderName", NULL, &type, NULL, &size) != ERROR_SUCCESS) + return; + + if (type != REG_SZ) + return; + + net.lpProvider = HeapAlloc(GetProcessHeap(), 0, size); + if (!net.lpProvider) + return; + + if (RegQueryValueExW(connection, L"ProviderName", NULL, NULL, (BYTE *)net.lpProvider, &size) != ERROR_SUCCESS) + goto cleanup; + + size = sizeof(DWORD); + if (RegQueryValueExW(connection, L"ProviderType", NULL, &type, (BYTE *)&prov, &size) != ERROR_SUCCESS) + goto cleanup; + + if (type != REG_DWORD || size != sizeof(DWORD)) + goto cleanup; + + index = _findProviderIndexW(net.lpProvider); + if (index == BAD_PROVIDER_INDEX) + goto cleanup; + + if (providerTable->table[index].dwNetType != prov) + goto cleanup; + + if (RegQueryValueExW(connection, L"RemotePath", NULL, &type, NULL, &size) != ERROR_SUCCESS) + goto cleanup; + + if (type != REG_SZ) + goto cleanup; + + net.lpRemoteName = HeapAlloc(GetProcessHeap(), 0, size); + if (!net.lpRemoteName) + goto cleanup; + + if (RegQueryValueExW(connection, L"RemotePath", NULL, NULL, (BYTE *)net.lpRemoteName, &size) != ERROR_SUCCESS) + goto cleanup; + + size = strlenW(local); + net.lpLocalName = HeapAlloc(GetProcessHeap(), 0, size * sizeof(WCHAR) + 2 * sizeof(WCHAR)); + if (!net.lpLocalName) + goto cleanup; + + strcpyW(net.lpLocalName, local); + net.lpLocalName[size] = ':'; + net.lpLocalName[size + 1] = 0; + + TRACE("Attempting connection\n"); + + WNetAddConnection2W(&net, NULL, NULL, 0); + +cleanup: + HeapFree(GetProcessHeap(), 0, net.lpProvider); + HeapFree(GetProcessHeap(), 0, net.lpRemoteName); + HeapFree(GetProcessHeap(), 0, net.lpLocalName); +} +#endif + void wnetInit(HINSTANCE hInstDll) { static const WCHAR providerOrderKey[] = { 'S','y','s','t','e','m','\\', @@ -325,6 +404,64 @@ void wnetInit(HINSTANCE hInstDll) } RegCloseKey(hKey); } + +#ifdef __REACTOS__ + if (providerTable) + { + HKEY user_profile; + + if (RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS) + { + HKEY network; + WCHAR subkey[8] = {'N', 'e', 't', 'w', 'o', 'r', 'k', 0}; + + if (RegOpenKeyExW(user_profile, subkey, 0, KEY_READ, &network) == ERROR_SUCCESS) + { + DWORD size, max; + + TRACE("Enumerating remembered connections\n"); + + if (RegQueryInfoKey(network, NULL, NULL, NULL, &max, &size, NULL, NULL, NULL, NULL, NULL, NULL) == ERROR_SUCCESS) + { + WCHAR *local; + + TRACE("There are %lu connections\n", max); + + local = HeapAlloc(GetProcessHeap(), 0, (size + 1) * sizeof(WCHAR)); + if (local) + { + DWORD index; + + for (index = 0; index < max; ++index) + { + DWORD len = size + 1; + HKEY connection; + + TRACE("Trying connection %lu\n", index); + + if (RegEnumKeyExW(network, index, local, &len, NULL, NULL, NULL, NULL) != ERROR_SUCCESS) + continue; + + TRACE("It is %S\n", local); + + if (RegOpenKeyExW(network, local, 0, KEY_READ, &connection) != ERROR_SUCCESS) + continue; + + _restoreSavedConnection(connection, local); + RegCloseKey(connection); + } + + HeapFree(GetProcessHeap(), 0, local); + } + } + + RegCloseKey(network); + } + + RegCloseKey(user_profile); + } + } +#endif } void wnetFree(void) -- 2.17.1