[MPR]
[reactos.git] / reactos / dll / win32 / mpr / wnet.c
index 7ffc427..37ad58f 100644 (file)
@@ -4,6 +4,7 @@
  * Copyright 1999 Ulrich Weigand
  * Copyright 2004 Juan Lang
  * Copyright 2007 Maarten Lankhorst
+ * Copyright 2016 Pierre Schweitzer
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
@@ -91,7 +92,11 @@ typedef struct _WNetEnumerator
     DWORD          dwScope;
     DWORD          dwType;
     DWORD          dwUsage;
-    LPVOID         lpBuffer;
+    union
+    {
+        NETRESOURCEW* net;
+        HANDLE* handles;
+    } specific;
 } WNetEnumerator, *PWNetEnumerator;
 
 #define BAD_PROVIDER_INDEX (DWORD)0xffffffff
@@ -166,6 +171,7 @@ static void _tryLoadProvider(PCWSTR provider)
                     TRACE("loaded lib %p\n", hLib);
                     if (getCaps)
                     {
+                        DWORD connectCap;
                         PWNetProvider provider =
                          &providerTable->table[providerTable->numProviders];
 
@@ -199,11 +205,16 @@ static void _tryLoadProvider(PCWSTR provider)
                                 WARN("Couldn't load enumeration functions\n");
                             }
                         }
-                        provider->addConnection = MPR_GETPROC(NPAddConnection);
-                        provider->addConnection3 = MPR_GETPROC(NPAddConnection3);
-                        provider->cancelConnection = MPR_GETPROC(NPCancelConnection);
+                        connectCap = getCaps(WNNC_CONNECTION);
+                        if (connectCap & WNNC_CON_ADDCONNECTION)
+                            provider->addConnection = MPR_GETPROC(NPAddConnection);
+                        if (connectCap & WNNC_CON_ADDCONNECTION3)
+                            provider->addConnection3 = MPR_GETPROC(NPAddConnection3);
+                        if (connectCap & WNNC_CON_CANCELCONNECTION)
+                            provider->cancelConnection = MPR_GETPROC(NPCancelConnection);
 #ifdef __REACTOS__
-                        provider->getConnection = MPR_GETPROC(NPGetConnection);
+                        if (connectCap & WNNC_CON_GETCONNECTIONS)
+                            provider->getConnection = MPR_GETPROC(NPGetConnection);
 #endif
                         TRACE("NPAddConnection %p\n", provider->addConnection);
                         TRACE("NPAddConnection3 %p\n", provider->addConnection3);
@@ -242,6 +253,85 @@ static void _tryLoadProvider(PCWSTR provider)
          debugstr_w(provider));
 }
 
+#ifdef __REACTOS__
+static void _restoreSavedConnection(HKEY connection, WCHAR * local)
+{
+    NETRESOURCEW net;
+    DWORD type, prov, index, size;
+
+    net.lpProvider = NULL;
+    net.lpRemoteName = NULL;
+    net.lpLocalName = NULL;
+
+    TRACE("Restoring: %S\n", local);
+
+    size = sizeof(DWORD);
+    if (RegQueryValueExW(connection, L"ConnectionType", NULL, &type, (BYTE *)&net.dwType, &size) != ERROR_SUCCESS)
+       return;
+
+    if (type != REG_DWORD || size != sizeof(DWORD))
+        return;
+
+    if (RegQueryValueExW(connection, L"ProviderName", NULL, &type, NULL, &size) != ERROR_SUCCESS)
+        return;
+
+    if (type != REG_SZ)
+        return;
+
+    net.lpProvider = HeapAlloc(GetProcessHeap(), 0, size);
+    if (!net.lpProvider)
+        return;
+
+    if (RegQueryValueExW(connection, L"ProviderName", NULL, NULL, (BYTE *)net.lpProvider, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    size = sizeof(DWORD);
+    if (RegQueryValueExW(connection, L"ProviderType", NULL, &type, (BYTE *)&prov, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    if (type != REG_DWORD || size != sizeof(DWORD))
+        goto cleanup;
+
+    index = _findProviderIndexW(net.lpProvider);
+    if (index == BAD_PROVIDER_INDEX)
+        goto cleanup;
+
+    if (providerTable->table[index].dwNetType != prov)
+        goto cleanup;
+
+    if (RegQueryValueExW(connection, L"RemotePath", NULL, &type, NULL, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    if (type != REG_SZ)
+        goto cleanup;
+
+    net.lpRemoteName = HeapAlloc(GetProcessHeap(), 0, size);
+    if (!net.lpRemoteName)
+        goto cleanup;
+
+    if (RegQueryValueExW(connection, L"RemotePath", NULL, NULL, (BYTE *)net.lpRemoteName, &size) != ERROR_SUCCESS)
+        goto cleanup;
+
+    size = strlenW(local);
+    net.lpLocalName = HeapAlloc(GetProcessHeap(), 0, size * sizeof(WCHAR) + 2 * sizeof(WCHAR));
+    if (!net.lpLocalName)
+        goto cleanup;
+
+    strcpyW(net.lpLocalName, local);
+    net.lpLocalName[size] = ':';
+    net.lpLocalName[size + 1] = 0;
+
+    TRACE("Attempting connection\n");
+
+    WNetAddConnection2W(&net, NULL, NULL, 0);
+
+cleanup:
+    HeapFree(GetProcessHeap(), 0, net.lpProvider);
+    HeapFree(GetProcessHeap(), 0, net.lpRemoteName);
+    HeapFree(GetProcessHeap(), 0, net.lpLocalName);
+}
+#endif
+
 void wnetInit(HINSTANCE hInstDll)
 {
     static const WCHAR providerOrderKey[] = { 'S','y','s','t','e','m','\\',
@@ -320,6 +410,64 @@ void wnetInit(HINSTANCE hInstDll)
         }
         RegCloseKey(hKey);
     }
+
+#ifdef __REACTOS__
+    if (providerTable)
+    {
+        HKEY user_profile;
+
+        if (RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS)
+        {
+            HKEY network;
+            WCHAR subkey[8] = {'N', 'e', 't', 'w', 'o', 'r', 'k', 0};
+
+            if (RegOpenKeyExW(user_profile, subkey, 0, KEY_READ, &network) == ERROR_SUCCESS)
+            {
+                DWORD size, max;
+
+                TRACE("Enumerating remembered connections\n");
+
+                if (RegQueryInfoKey(network, NULL, NULL, NULL, &max, &size, NULL, NULL, NULL, NULL, NULL, NULL) == ERROR_SUCCESS)
+                {
+                    WCHAR *local;
+
+                    TRACE("There are %lu connections\n", max);
+
+                    local = HeapAlloc(GetProcessHeap(), 0, (size + 1) * sizeof(WCHAR));
+                    if (local)
+                    {
+                        DWORD index;
+
+                        for (index = 0; index < max; ++index)
+                        {
+                            DWORD len = size + 1;
+                            HKEY connection;
+
+                            TRACE("Trying connection %lu\n", index);
+
+                            if (RegEnumKeyExW(network, index, local, &len, NULL, NULL, NULL, NULL) != ERROR_SUCCESS)
+                                continue;
+
+                            TRACE("It is %S\n", local);
+
+                            if (RegOpenKeyExW(network, local, 0, KEY_READ, &connection) != ERROR_SUCCESS)
+                                continue;
+
+                            _restoreSavedConnection(connection, local);
+                            RegCloseKey(connection);
+                        }
+
+                        HeapFree(GetProcessHeap(), 0, local);
+                    }
+                }
+
+                RegCloseKey(network);
+            }
+
+            RegCloseKey(user_profile);
+        }
+    }
+#endif
 }
 
 void wnetFree(void)
@@ -417,7 +565,7 @@ static PWNetEnumerator _createGlobalEnumeratorW(DWORD dwScope, DWORD dwType,
         ret->dwScope = dwScope;
         ret->dwType  = dwType;
         ret->dwUsage = dwUsage;
-        ret->lpBuffer = _copyNetResourceForEnumW(lpNet);
+        ret->specific.net = _copyNetResourceForEnumW(lpNet);
     }
     return ret;
 }
@@ -464,17 +612,15 @@ static PWNetEnumerator _createContextEnumerator(DWORD dwScope, DWORD dwType,
 static PWNetEnumerator _createConnectedEnumerator(DWORD dwScope, DWORD dwType,
  DWORD dwUsage)
 {
-    PWNetEnumerator ret = HeapAlloc(GetProcessHeap(),
-     HEAP_ZERO_MEMORY, sizeof(WNetEnumerator));
-
+    PWNetEnumerator ret = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(WNetEnumerator));
     if (ret)
     {
         ret->enumType = WNET_ENUMERATOR_TYPE_CONNECTED;
         ret->dwScope = dwScope;
         ret->dwType  = dwType;
         ret->dwUsage = dwUsage;
-        ret->lpBuffer = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(HANDLE) * providerTable->numProviders);
-        if (!ret->lpBuffer)
+        ret->specific.handles = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(HANDLE) * providerTable->numProviders);
+        if (!ret->specific.handles)
         {
             HeapFree(GetProcessHeap(), 0, ret);
             ret = NULL;
@@ -843,14 +989,14 @@ DWORD WINAPI WNetOpenEnumW( DWORD dwScope, DWORD dwType, DWORD dwUsage,
                 *lphEnum = _createContextEnumerator(dwScope, dwType, dwUsage);
                 ret = *lphEnum ? WN_SUCCESS : WN_OUT_OF_MEMORY;
                 break;
-            case RESOURCE_REMEMBERED:
-                *lphEnum = _createNullEnumerator();
-                ret = *lphEnum ? WN_SUCCESS : WN_OUT_OF_MEMORY;
-                break;
             case RESOURCE_CONNECTED:
                 *lphEnum = _createConnectedEnumerator(dwScope, dwType, dwUsage);
                 ret = *lphEnum ? WN_SUCCESS : WN_OUT_OF_MEMORY;
                 break;
+            case RESOURCE_REMEMBERED:
+                *lphEnum = _createNullEnumerator();
+                ret = *lphEnum ? WN_SUCCESS : WN_OUT_OF_MEMORY;
+                break;
             default:
                 WARN("unknown scope 0x%08x\n", dwScope);
                 ret = WN_BAD_VALUE;
@@ -1074,7 +1220,7 @@ static DWORD _enumerateGlobalPassthroughW(PWNetEnumerator enumerator,
     {
         ret = providerTable->table[enumerator->providerIndex].
          openEnum(enumerator->dwScope, enumerator->dwType,
-         enumerator->dwUsage, enumerator->lpBuffer,
+         enumerator->dwUsage, enumerator->specific.net,
          &enumerator->handle);
         if (ret == WN_SUCCESS)
         {
@@ -1112,7 +1258,7 @@ static DWORD _enumerateGlobalW(PWNetEnumerator enumerator, LPDWORD lpcCount,
     switch (enumerator->dwScope)
     {
         case RESOURCE_GLOBALNET:
-            if (enumerator->lpBuffer)
+            if (enumerator->specific.net)
                 ret = _enumerateGlobalPassthroughW(enumerator, lpcCount,
                  lpBuffer, lpBufferSize);
             else
@@ -1233,39 +1379,54 @@ static DWORD _enumerateContextW(PWNetEnumerator enumerator, LPDWORD lpcCount,
     return ret;
 }
 
-static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, LPDWORD lpcCount,
- LPVOID lpBuffer, LPDWORD lpBufferSize)
+static DWORD _copyStringToEnumW(const WCHAR *source, DWORD* left, void** end)
 {
-    DWORD ret, index, count, size, i, len, left;
-    PVOID end;
-    LPNETRESOURCEW curr, buffer;
-    PHANDLE handles;
+    DWORD len;
+    WCHAR* local = *end;
+
+    len = strlenW(source) + 1;
+    len *= sizeof(WCHAR);
+    if (*left < len)
+        return WN_MORE_DATA;
+
+    local -= (len / sizeof(WCHAR));
+    memcpy(local, source, len);
+    *left -= len;
+    *end = local;
+
+    return WN_SUCCESS;
+}
+
+static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, DWORD* user_count,
+                                  void* user_buffer, DWORD* user_size)
+{
+    DWORD ret, index, count, total_count, size, i, left;
+    void* end;
+    NETRESOURCEW* curr, * buffer;
+    HANDLE* handles;
 
     if (!enumerator)
         return WN_BAD_POINTER;
     if (enumerator->enumType != WNET_ENUMERATOR_TYPE_CONNECTED)
         return WN_BAD_VALUE;
-    if (!lpcCount)
-        return WN_BAD_POINTER;
-    if (!lpBuffer)
-        return WN_BAD_POINTER;
-    if (!lpBufferSize)
+    if (!user_count || !user_buffer || !user_size)
         return WN_BAD_POINTER;
     if (!providerTable)
         return WN_NO_NETWORK;
 
-    handles = enumerator->lpBuffer;
-    left = *lpBufferSize;
-    size = *lpBufferSize;
-    buffer = HeapAlloc(GetProcessHeap(), 0, *lpBufferSize);
+    handles = enumerator->specific.handles;
+    left = *user_size;
+    size = *user_size;
+    buffer = HeapAlloc(GetProcessHeap(), 0, *user_size);
     if (!buffer)
         return WN_NO_NETWORK;
 
-    curr = lpBuffer;
-    end = (PVOID)((ULONG_PTR)lpBuffer + size);
-    count = *lpcCount;
-
+    curr = user_buffer;
+    end = (char *)user_buffer + size;
+    count = *user_count;
+    total_count = 0;
 
+    ret = WN_NO_MORE_ENTRIES;
     for (index = 0; index < providerTable->numProviders; index++)
     {
         if (providerTable->table[index].dwEnumScopes)
@@ -1281,15 +1442,11 @@ static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, LPDWORD lpcCount,
             }
 
             ret = providerTable->table[index].enumResource(handles[index],
-                                                           &count,
-                                                           buffer,
+                                                           &count, buffer,
                                                            &size);
-
+            total_count += count;
             if (ret == WN_MORE_DATA)
-            {
-
                 break;
-            }
 
             if (ret == WN_SUCCESS)
             {
@@ -1304,63 +1461,40 @@ static DWORD _enumerateConnectedW(PWNetEnumerator enumerator, LPDWORD lpcCount,
                     memcpy(curr, &buffer[i], sizeof(NETRESOURCEW));
                     left -= sizeof(NETRESOURCEW);
 
-                    len = WideCharToMultiByte(CP_ACP, 0, buffer[i].lpLocalName, -1, NULL, 0, NULL, NULL);
-                    len *= sizeof(WCHAR);
-                    if (left < len)
-                    {
-                        ret = WN_MORE_DATA;
+                    ret = _copyStringToEnumW(buffer[i].lpLocalName, &left, &end);
+                    if (ret == WN_MORE_DATA)
                         break;
-                    }
-
-                    end = (PVOID)((ULONG_PTR)end - len);
                     curr->lpLocalName = end;
-                    memcpy(end, buffer[i].lpLocalName, len);
-                    left -= len;
 
-                    len = WideCharToMultiByte(CP_ACP, 0, buffer[i].lpRemoteName, -1, NULL, 0, NULL, NULL);
-                    len *= sizeof(WCHAR);
-                    if (left < len)
-                    {
-                        ret = WN_MORE_DATA;
+                    ret = _copyStringToEnumW(buffer[i].lpRemoteName, &left, &end);
+                    if (ret == WN_MORE_DATA)
                         break;
-                    }
-
-                    end = (PVOID)((ULONG_PTR)end - len);
                     curr->lpRemoteName = end;
-                    memcpy(end, buffer[i].lpRemoteName, len);
-                    left -= len;
 
-                    len = WideCharToMultiByte(CP_ACP, 0, buffer[i].lpProvider, -1, NULL, 0, NULL, NULL);
-                    len *= sizeof(WCHAR);
-                    if (left < len)
-                    {
-                        ret = WN_MORE_DATA;
+                    ret = _copyStringToEnumW(buffer[i].lpProvider, &left, &end);
+                    if (ret == WN_MORE_DATA)
                         break;
-                    }
-
-                    end = (PVOID)((ULONG_PTR)end - len);
                     curr->lpProvider = end;
-                    memcpy(end, buffer[i].lpProvider, len);
-                    left -= len;
 
                     ++curr;
                 }
 
-                count = *lpcCount - count;
+                if (*user_count != -1)
+                    count = *user_count - total_count;
+                else
+                    count = *user_count;
                 size = left;
             }
 
-            if (ret != WN_SUCCESS || count == 0)
-            {
+            if (ret != WN_SUCCESS || total_count == 0)
                 break;
-            }
         }
     }
 
-    if (count == 0)
+    if (total_count == 0)
         ret = WN_NO_MORE_ENTRIES;
 
-    *lpcCount = *lpcCount - count;
+    *user_count = total_count;
     if (ret != WN_MORE_DATA && ret != WN_NO_MORE_ENTRIES)
         ret = WN_SUCCESS;
 
@@ -1435,7 +1569,7 @@ DWORD WINAPI WNetEnumResourceW( HANDLE hEnum, LPDWORD lpcCount,
 DWORD WINAPI WNetCloseEnum( HANDLE hEnum )
 {
     DWORD ret, index;
-    PHANDLE handles;
+    HANDLE *handles;
 
     TRACE( "(%p)\n", hEnum );
 
@@ -1445,21 +1579,12 @@ DWORD WINAPI WNetCloseEnum( HANDLE hEnum )
 
         switch (enumerator->enumType)
         {
-            case WNET_ENUMERATOR_TYPE_CONNECTED:
-                handles = enumerator->lpBuffer;
-                for (index = 0; index < providerTable->numProviders; index++)
-                {
-                    if (providerTable->table[index].dwEnumScopes && handles[index] != 0)
-                        providerTable->table[index].closeEnum(handles[index]);
-                }
-                HeapFree(GetProcessHeap(), 0, handles);
-                ret = WN_SUCCESS;
             case WNET_ENUMERATOR_TYPE_NULL:
                 ret = WN_SUCCESS;
                 break;
             case WNET_ENUMERATOR_TYPE_GLOBAL:
-                if (enumerator->lpBuffer)
-                    _freeEnumNetResource(enumerator->lpBuffer);
+                if (enumerator->specific.net)
+                    _freeEnumNetResource(enumerator->specific.net);
                 if (enumerator->handle)
                     providerTable->table[enumerator->providerIndex].
                      closeEnum(enumerator->handle);
@@ -1471,6 +1596,16 @@ DWORD WINAPI WNetCloseEnum( HANDLE hEnum )
                      closeEnum(enumerator->handle);
                 ret = WN_SUCCESS;
                 break;
+            case WNET_ENUMERATOR_TYPE_CONNECTED:
+                handles = enumerator->specific.handles;
+                for (index = 0; index < providerTable->numProviders; index++)
+                {
+                    if (providerTable->table[index].dwEnumScopes && handles[index])
+                        providerTable->table[index].closeEnum(handles[index]);
+                }
+                HeapFree(GetProcessHeap(), 0, handles);
+                ret = WN_SUCCESS;
+                break;
             default:
                 WARN("bogus enumerator type!\n");
                 ret = WN_BAD_HANDLE;
@@ -1797,54 +1932,38 @@ static void use_connection_set_accessnameW(struct use_connection_context *ctxt,
         strcpyW(accessname, ctxt->resource->lpRemoteName);
 }
 
-static WCHAR * select_provider(struct use_connection_context *ctxt)
+static DWORD wnet_use_provider( struct use_connection_context *ctxt, NETRESOURCEW * netres, WNetProvider *provider, BOOLEAN redirect )
 {
-    DWORD ret, prov_size = 0x1000, len;
-    LPNETRESOURCEW provider;
-    WCHAR * system;
-    WCHAR * provider_name;
+    DWORD caps, ret;
 
-    provider = HeapAlloc(GetProcessHeap(), 0, prov_size);
-    if (!provider)
-    {
-        return NULL;
-    }
-
-    ret = WNetGetResourceInformationW(ctxt->resource, provider, &prov_size, &system);
-    if (ret == ERROR_MORE_DATA)
-    {
-        HeapFree(GetProcessHeap(), 0, provider);
-        provider = HeapAlloc(GetProcessHeap(), 0, prov_size);
-        if (!provider)
-        {
-            return NULL;
-        }
-
-        ret = WNetGetResourceInformationW(ctxt->resource, provider, &prov_size, &system);
-    }
+    caps = provider->getCaps(WNNC_CONNECTION);
+    if (!(caps & (WNNC_CON_ADDCONNECTION | WNNC_CON_ADDCONNECTION3)))
+        return ERROR_BAD_PROVIDER;
 
-    if (ret != NO_ERROR)
+    ret = WN_ACCESS_DENIED;
+    do
     {
-        HeapFree(GetProcessHeap(), 0, provider);
-        return NULL;
-    }
+        if ((caps & WNNC_CON_ADDCONNECTION3) && provider->addConnection3)
+            ret = provider->addConnection3(ctxt->hwndOwner, netres, ctxt->password, ctxt->userid, ctxt->flags);
+        else if ((caps & WNNC_CON_ADDCONNECTION) && provider->addConnection)
+            ret = provider->addConnection(netres, ctxt->password, ctxt->userid);
 
-    len = WideCharToMultiByte(CP_ACP, 0, provider->lpProvider, -1, NULL, 0, NULL, NULL);
-    provider_name = HeapAlloc(GetProcessHeap(), 0, len * sizeof(WCHAR));
-    if (provider_name)
-        memcpy(provider_name, provider->lpProvider, len * sizeof(WCHAR));
+        if (ret == WN_ALREADY_CONNECTED && redirect)
+            netres->lpLocalName[0] -= 1;
+    } while (redirect && ret == WN_ALREADY_CONNECTED && netres->lpLocalName[0] >= 'C');
 
-    HeapFree(GetProcessHeap(), 0, provider);
+    if (ret == WN_SUCCESS && ctxt->accessname)
+        ctxt->set_accessname(ctxt, netres->lpLocalName);
 
-    return provider_name;
+    return ret;
 }
 
 static DWORD wnet_use_connection( struct use_connection_context *ctxt )
 {
     WNetProvider *provider;
-    DWORD index, ret, caps;
-    BOOLEAN redirect = FALSE, prov = FALSE;
-    WCHAR letter[3] = {'z', ':', 0};
+    DWORD index, ret = WN_NO_NETWORK;
+    BOOL redirect = FALSE;
+    WCHAR letter[3] = {'Z', ':', 0};
     NETRESOURCEW netres;
 
     if (!providerTable || providerTable->numProviders == 0)
@@ -1857,13 +1976,11 @@ static DWORD wnet_use_connection( struct use_connection_context *ctxt )
     if (!netres.lpLocalName && (ctxt->flags & CONNECT_REDIRECT))
     {
         if (netres.dwType != RESOURCETYPE_DISK && netres.dwType != RESOURCETYPE_PRINT)
-        {
             return ERROR_BAD_DEV_TYPE;
-        }
 
         if (netres.dwType == RESOURCETYPE_PRINT)
         {
-            FIXME("Locale device selection is not implemented for printers.\n");
+            FIXME("Local device selection is not implemented for printers.\n");
             return WN_NO_NETWORK;
         }
 
@@ -1872,62 +1989,67 @@ static DWORD wnet_use_connection( struct use_connection_context *ctxt )
     }
 
     if (ctxt->flags & CONNECT_INTERACTIVE)
-    {
         return ERROR_BAD_NET_NAME;
-    }
 
-    if (ctxt->flags & CONNECT_UPDATE_PROFILE)
-        FIXME("Connection saving is not implemented\n");
+    if ((ret = ctxt->pre_set_accessname(ctxt, netres.lpLocalName)))
+        return ret;
 
-    if (!netres.lpProvider)
+    if (netres.lpProvider)
     {
-        netres.lpProvider = select_provider(ctxt);
-        if (!netres.lpProvider)
-        {
-            return ERROR_NO_NET_OR_BAD_PATH;
-        }
-
-        prov = TRUE;
-    }
+        index = _findProviderIndexW(netres.lpProvider);
+        if (index == BAD_PROVIDER_INDEX)
+            return ERROR_BAD_PROVIDER;
 
-    index = _findProviderIndexW(netres.lpProvider);
-    if (index == BAD_PROVIDER_INDEX)
-    {
-        ret = ERROR_BAD_PROVIDER;
-        goto done;
+        provider = &providerTable->table[index];
+        ret = wnet_use_provider(ctxt, &netres, provider, redirect);
     }
-
-    provider = &providerTable->table[index];
-    caps = provider->getCaps(WNNC_CONNECTION);
-    if (!(caps & (WNNC_CON_ADDCONNECTION | WNNC_CON_ADDCONNECTION3)))
+    else
     {
-        ret = ERROR_BAD_PROVIDER;
-        goto done;
+        for (index = 0; index < providerTable->numProviders; index++)
+        {
+            provider = &providerTable->table[index];
+            ret = wnet_use_provider(ctxt, &netres, provider, redirect);
+            if (ret == WN_SUCCESS || ret == WN_ALREADY_CONNECTED)
+                break;
+        }
     }
 
-    if ((ret = ctxt->pre_set_accessname(ctxt, netres.lpLocalName)))
+#ifdef __REACTOS__
+    if (ret == WN_SUCCESS && ctxt->flags & CONNECT_UPDATE_PROFILE)
     {
-        goto done;
-    }
+        HKEY user_profile;
 
-    ret = WN_ACCESS_DENIED;
-    do
-    {
-        if ((caps & WNNC_CON_ADDCONNECTION3) && provider->addConnection3)
-            ret = provider->addConnection3(ctxt->hwndOwner, &netres, ctxt->password, ctxt->userid, ctxt->flags);
-        else if ((caps & WNNC_CON_ADDCONNECTION) && provider->addConnection)
-            ret = provider->addConnection(&netres, ctxt->password, ctxt->userid);
+        if (netres.dwType == RESOURCETYPE_PRINT)
+        {
+            FIXME("Persistent connection are not supported for printers\n");
+            return ret;
+        }
 
-        if (ret != NO_ERROR && redirect)
-            letter[0] -= 1;
-    } while (redirect && ret == WN_ALREADY_CONNECTED && letter[0] >= 'c');
+        if (RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS)
+        {
+            HKEY network;
+            WCHAR subkey[10] = {'N', 'e', 't', 'w', 'o', 'r', 'k', '\\', netres.lpLocalName[0], 0};
 
-    if (ret == WN_SUCCESS && ctxt->accessname)
-        ctxt->set_accessname(ctxt, netres.lpLocalName);
+            if (RegCreateKeyExW(user_profile, subkey, 0, NULL, REG_OPTION_NON_VOLATILE, KEY_ALL_ACCESS, NULL, &network, NULL) == ERROR_SUCCESS)
+            {
+                DWORD dword_arg = RESOURCETYPE_DISK;
+                DWORD len = (strlenW(provider->name) + 1) * sizeof(WCHAR);
+
+                RegSetValueExW(network, L"ConnectionType", 0, REG_DWORD, (const BYTE *)&dword_arg, sizeof(DWORD));
+                RegSetValueExW(network, L"ProviderName", 0, REG_SZ, (const BYTE *)provider->name, len);
+                dword_arg = provider->dwNetType;
+                RegSetValueExW(network, L"ProviderType", 0, REG_DWORD, (const BYTE *)&dword_arg, sizeof(DWORD));
+                len = (strlenW(netres.lpRemoteName) + 1) * sizeof(WCHAR);
+                RegSetValueExW(network, L"RemotePath", 0, REG_SZ, (const BYTE *)netres.lpRemoteName, len);
+                len = 0;
+                RegSetValueExW(network, L"UserName", 0, REG_SZ, (const BYTE *)netres.lpRemoteName, len);
+                RegCloseKey(network);
+            }
 
-done:
-    if (prov)
-        HeapFree(GetProcessHeap(), 0, netres.lpProvider);
+            RegCloseKey(user_profile);
+        }
+    }
+#endif
 
     return ret;
 }
@@ -2109,7 +2231,7 @@ DWORD WINAPI WNetCancelConnection2W( LPCWSTR lpName, DWORD dwFlags, BOOL fForce
         for (index = 0; index < providerTable->numProviders; index++)
         {
             if(providerTable->table[index].getCaps(WNNC_CONNECTION) &
-                WNNC_CON_GETCONNECTIONS)
+                WNNC_CON_CANCELCONNECTION)
             {
                 if (providerTable->table[index].cancelConnection)
                     ret = providerTable->table[index].cancelConnection((LPWSTR)lpName, fForce);
@@ -2120,6 +2242,37 @@ DWORD WINAPI WNetCancelConnection2W( LPCWSTR lpName, DWORD dwFlags, BOOL fForce
             }
         }
     }
+#ifdef __REACTOS__
+
+    if (dwFlags & CONNECT_UPDATE_PROFILE)
+    {
+        HKEY user_profile;
+        WCHAR *coma = strchrW(lpName, ':');
+
+        if (coma && RegOpenCurrentUser(KEY_ALL_ACCESS, &user_profile) == ERROR_SUCCESS)
+        {
+            WCHAR  *subkey;
+            DWORD len;
+
+            len = (ULONG_PTR)coma - (ULONG_PTR)lpName + sizeof(L"Network\\");
+            subkey = HeapAlloc(GetProcessHeap(), 0, len);
+            if (subkey)
+            {
+                strcpyW(subkey, L"Network\\");
+                memcpy(subkey + (sizeof(L"Network\\") / sizeof(WCHAR)) - 1, lpName, (ULONG_PTR)coma - (ULONG_PTR)lpName);
+                subkey[len / sizeof(WCHAR) - 1] = 0;
+
+                TRACE("Removing: %S\n", subkey);
+
+                RegDeleteKeyW(user_profile, subkey);
+                HeapFree(GetProcessHeap(), 0, subkey);
+            }
+
+            RegCloseKey(user_profile);
+        }
+    }
+
+#endif
     return ret;
 }