[MPR]
[reactos.git] / reactos / dll / win32 / mpr / wnet.c
index 045e197..3e117e6 100644 (file)
@@ -48,6 +48,10 @@ typedef struct _WNetProvider
     PF_NPGetResourceInformation getResourceInformation;
     PF_NPAddConnection addConnection;
     PF_NPAddConnection3 addConnection3;
+    PF_NPCancelConnection cancelConnection;
+#ifdef __REACTOS__
+    PF_NPGetConnection getConnection;
+#endif
 } WNetProvider, *PWNetProvider;
 
 typedef struct _WNetProviderTable
@@ -196,8 +200,13 @@ static void _tryLoadProvider(PCWSTR provider)
                         }
                         provider->addConnection = MPR_GETPROC(NPAddConnection);
                         provider->addConnection3 = MPR_GETPROC(NPAddConnection3);
+                        provider->cancelConnection = MPR_GETPROC(NPCancelConnection);
+#ifdef __REACTOS__
+                        provider->getConnection = MPR_GETPROC(NPGetConnection);
+#endif
                         TRACE("NPAddConnection %p\n", provider->addConnection);
                         TRACE("NPAddConnection3 %p\n", provider->addConnection3);
+                        TRACE("NPCancelConnection %p\n", provider->cancelConnection);
                         providerTable->numProviders++;
                     }
                     else
@@ -1607,10 +1616,53 @@ static void use_connection_set_accessnameW(struct use_connection_context *ctxt)
         strcpyW(accessname, ctxt->resource->lpRemoteName);
 }
 
+static WCHAR * select_provider(struct use_connection_context *ctxt)
+{
+    DWORD ret, prov_size = 0x1000, len;
+    LPNETRESOURCEW provider;
+    WCHAR * system;
+
+    provider = HeapAlloc(GetProcessHeap(), 0, prov_size);
+    if (!provider)
+    {
+        return NULL;
+    }
+
+    ret = WNetGetResourceInformationW(ctxt->resource, provider, &prov_size, &system);
+    if (ret == ERROR_MORE_DATA)
+    {
+        HeapFree(GetProcessHeap(), 0, provider);
+        provider = HeapAlloc(GetProcessHeap(), 0, prov_size);
+        if (!provider)
+        {
+            return NULL;
+        }
+
+        ret = WNetGetResourceInformationW(ctxt->resource, provider, &prov_size, &system);
+    }
+
+    if (ret != NO_ERROR)
+    {
+        HeapFree(GetProcessHeap(), 0, provider);
+        return NULL;
+    }
+
+    len = WideCharToMultiByte(CP_ACP, 0, provider->lpProvider, -1, NULL, 0, NULL, NULL);
+    ctxt->resource->lpProvider = HeapAlloc(GetProcessHeap(), 0, len * sizeof(WCHAR));
+    if (ctxt->resource->lpProvider)
+        memcpy(ctxt->resource->lpProvider, provider->lpProvider, len * sizeof(WCHAR));
+
+    HeapFree(GetProcessHeap(), 0, provider);
+
+    return ctxt->resource->lpProvider;
+}
+
 static DWORD wnet_use_connection( struct use_connection_context *ctxt )
 {
     WNetProvider *provider;
     DWORD index, ret, caps;
+    BOOLEAN redirect = FALSE, prov = FALSE;
+    WCHAR letter[3] = {'z', ':', 0};
 
     if (!providerTable || providerTable->numProviders == 0)
         return WN_NO_NETWORK;
@@ -1618,42 +1670,86 @@ static DWORD wnet_use_connection( struct use_connection_context *ctxt )
     if (!ctxt->resource)
         return ERROR_INVALID_PARAMETER;
 
-    if (!ctxt->resource->lpProvider)
+    if (!ctxt->resource->lpLocalName && (ctxt->flags & CONNECT_REDIRECT))
     {
-        FIXME("Networking provider selection is not implemented.\n");
-        return WN_NO_NETWORK;
+        if (ctxt->resource->dwType != RESOURCETYPE_DISK && ctxt->resource->dwType != RESOURCETYPE_PRINT)
+        {
+            return ERROR_BAD_DEV_TYPE;
+        }
+
+        if (ctxt->resource->dwType == RESOURCETYPE_PRINT)
+        {
+            FIXME("Locale device selection is not implemented for printers.\n");
+            return WN_NO_NETWORK;
+        }
+
+        redirect = TRUE;
+        ctxt->resource->lpLocalName = letter;
     }
 
-    if (!ctxt->resource->lpLocalName && (ctxt->flags & CONNECT_REDIRECT))
+    if (ctxt->flags & CONNECT_INTERACTIVE)
     {
-        FIXME("Locale device selection is not implemented.\n");
-        return WN_NO_NETWORK;
+        ret = ERROR_BAD_NET_NAME;
+        goto done;
     }
 
-    if (ctxt->flags & CONNECT_INTERACTIVE)
-        return ERROR_BAD_NET_NAME;
+    if (!ctxt->resource->lpProvider)
+    {
+        ctxt->resource->lpProvider = select_provider(ctxt);
+        if (!ctxt->resource->lpProvider)
+        {
+            ret = ERROR_NO_NET_OR_BAD_PATH;
+            goto done;
+        }
+
+        prov = TRUE;
+    }
 
     index = _findProviderIndexW(ctxt->resource->lpProvider);
     if (index == BAD_PROVIDER_INDEX)
-        return ERROR_BAD_PROVIDER;
+    {
+        ret = ERROR_BAD_PROVIDER;
+        goto done;
+    }
 
     provider = &providerTable->table[index];
     caps = provider->getCaps(WNNC_CONNECTION);
     if (!(caps & (WNNC_CON_ADDCONNECTION | WNNC_CON_ADDCONNECTION3)))
-        return ERROR_BAD_PROVIDER;
+    {
+        ret = ERROR_BAD_PROVIDER;
+        goto done;
+    }
 
     if ((ret = ctxt->pre_set_accessname(ctxt)))
-        return ret;
+    {
+        goto done;
+    }
 
     ret = WN_ACCESS_DENIED;
-    if ((caps & WNNC_CON_ADDCONNECTION3) && provider->addConnection3)
-        ret = provider->addConnection3(ctxt->hwndOwner, ctxt->resource, ctxt->password, ctxt->userid, ctxt->flags);
-    else if ((caps & WNNC_CON_ADDCONNECTION) && provider->addConnection)
-        ret = provider->addConnection(ctxt->resource, ctxt->password, ctxt->userid);
+    do
+    {
+        if ((caps & WNNC_CON_ADDCONNECTION3) && provider->addConnection3)
+            ret = provider->addConnection3(ctxt->hwndOwner, ctxt->resource, ctxt->password, ctxt->userid, ctxt->flags);
+        else if ((caps & WNNC_CON_ADDCONNECTION) && provider->addConnection)
+            ret = provider->addConnection(ctxt->resource, ctxt->password, ctxt->userid);
+
+        if (redirect)
+            letter[0] -= 1;
+    } while (redirect && ret == WN_ALREADY_CONNECTED && letter[0] >= 'c');
 
     if (ret == WN_SUCCESS && ctxt->accessname)
         ctxt->set_accessname(ctxt);
 
+done:
+    if (prov)
+    {
+        HeapFree(GetProcessHeap(), 0, ctxt->resource->lpProvider);
+        ctxt->resource->lpProvider = NULL;
+    }
+
+    if (redirect)
+        ctxt->resource->lpLocalName = NULL;
+
     return ret;
 }
 
@@ -1820,9 +1916,26 @@ DWORD WINAPI WNetCancelConnection2A( LPCSTR lpName, DWORD dwFlags, BOOL fForce )
  */
 DWORD WINAPI WNetCancelConnection2W( LPCWSTR lpName, DWORD dwFlags, BOOL fForce )
 {
-    FIXME( "(%s, %08X, %d), stub\n", debugstr_w(lpName), dwFlags, fForce );
+    DWORD ret = WN_NO_NETWORK;
+    DWORD index;
 
-    return WN_SUCCESS;
+    if (providerTable != NULL)
+    {
+        for (index = 0; index < providerTable->numProviders; index++)
+        {
+            if(providerTable->table[index].getCaps(WNNC_CONNECTION) &
+                WNNC_CON_GETCONNECTIONS)
+            {
+                if (providerTable->table[index].cancelConnection)
+                    ret = providerTable->table[index].cancelConnection((LPWSTR)lpName, fForce);
+                else
+                    ret = WN_NO_NETWORK;
+                if (ret == WN_SUCCESS || ret == WN_OPEN_FILES)
+                    break;
+            }
+        }
+    }
+    return ret;
 }
 
 /*****************************************************************
@@ -1949,6 +2062,7 @@ DWORD WINAPI WNetGetConnectionA( LPCSTR lpLocalName,
 /* find the network connection for a given drive; helper for WNetGetConnection */
 static DWORD get_drive_connection( WCHAR letter, LPWSTR remote, LPDWORD size )
 {
+#ifndef __REACTOS__
     char buffer[1024];
     struct mountmgr_unix_drive *data = (struct mountmgr_unix_drive *)buffer;
     HANDLE mgr;
@@ -1991,6 +2105,32 @@ static DWORD get_drive_connection( WCHAR letter, LPWSTR remote, LPDWORD size )
     }
     CloseHandle( mgr );
     return ret;
+#else
+    DWORD ret = WN_NO_NETWORK;
+    DWORD index;
+    WCHAR local[3] = {letter, ':', 0};
+
+    if (providerTable != NULL)
+    {
+        for (index = 0; index < providerTable->numProviders; index++)
+        {
+            if(providerTable->table[index].getCaps(WNNC_CONNECTION) &
+                WNNC_CON_GETCONNECTIONS)
+            {
+                if (providerTable->table[index].getConnection)
+                    ret = providerTable->table[index].getConnection(
+                        local, remote, size);
+                else
+                    ret = WN_NO_NETWORK;
+                if (ret == WN_SUCCESS || ret == WN_MORE_DATA)
+                    break;
+            }
+        }
+    }
+    if (ret)
+        SetLastError(ret);
+    return ret;
+#endif
 }
 
 /**************************************************************************