[CMAKE]
[reactos.git] / dll / win32 / urlmon / bindctx.c
index 328254c..0b3efa6 100644 (file)
@@ -28,32 +28,38 @@ static WCHAR BSCBHolder[] = { '_','B','S','C','B','_','H','o','l','d','e','r','_
 extern IID IID_IBindStatusCallbackHolder;
 
 typedef struct {
-    const IBindStatusCallbackVtbl  *lpBindStatusCallbackVtbl;
-    const IServiceProviderVtbl     *lpServiceProviderVtbl;
-    const IHttpNegotiate2Vtbl      *lpHttpNegotiate2Vtbl;
-    const IAuthenticateVtbl        *lpAuthenticateVtbl;
+    const IBindStatusCallbackExVtbl  *lpBindStatusCallbackExVtbl;
+    const IServiceProviderVtbl       *lpServiceProviderVtbl;
+    const IHttpNegotiate2Vtbl        *lpHttpNegotiate2Vtbl;
+    const IAuthenticateVtbl          *lpAuthenticateVtbl;
 
     LONG ref;
 
     IBindStatusCallback *callback;
     IServiceProvider *serv_prov;
-
-    IHttpNegotiate *http_negotiate;
-    BOOL init_http_negotiate;
-    IHttpNegotiate2 *http_negotiate2;
-    BOOL init_http_negotiate2;
-    IAuthenticate *authenticate;
-    BOOL init_authenticate;
 } BindStatusCallback;
 
-#define STATUSCLB(x)     ((IBindStatusCallback*)  &(x)->lpBindStatusCallbackVtbl)
+#define STATUSCLB(x)     ((IBindStatusCallback*)  &(x)->lpBindStatusCallbackExVtbl)
+#define STATUSCLBEX(x)   ((IBindStatusCallbackEx*)&(x)->lpBindStatusCallbackExVtbl)
 #define SERVPROV(x)      ((IServiceProvider*)     &(x)->lpServiceProviderVtbl)
 #define HTTPNEG2(x)      ((IHttpNegotiate2*)      &(x)->lpHttpNegotiate2Vtbl)
 #define AUTHENTICATE(x)  ((IAuthenticate*)        &(x)->lpAuthenticateVtbl)
 
-#define STATUSCLB_THIS(iface) DEFINE_THIS(BindStatusCallback, BindStatusCallback, iface)
+static void *get_callback_iface(BindStatusCallback *This, REFIID riid)
+{
+    void *ret;
+    HRESULT hres;
+
+    hres = IBindStatusCallback_QueryInterface(This->callback, riid, (void**)&ret);
+    if(FAILED(hres) && This->serv_prov)
+        hres = IServiceProvider_QueryService(This->serv_prov, riid, riid, &ret);
 
-static HRESULT WINAPI BindStatusCallback_QueryInterface(IBindStatusCallback *iface,
+    return SUCCEEDED(hres) ? ret : NULL;
+}
+
+#define STATUSCLB_THIS(iface) DEFINE_THIS(BindStatusCallback, BindStatusCallbackEx, iface)
+
+static HRESULT WINAPI BindStatusCallback_QueryInterface(IBindStatusCallbackEx *iface,
         REFIID riid, void **ppv)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
@@ -66,6 +72,9 @@ static HRESULT WINAPI BindStatusCallback_QueryInterface(IBindStatusCallback *ifa
     }else if(IsEqualGUID(&IID_IBindStatusCallback, riid)) {
         TRACE("(%p)->(IID_IBindStatusCallback, %p)\n", This, ppv);
         *ppv = STATUSCLB(This);
+    }else if(IsEqualGUID(&IID_IBindStatusCallbackEx, riid)) {
+        TRACE("(%p)->(IID_IBindStatusCallback, %p)\n", This, ppv);
+        *ppv = STATUSCLBEX(This);
     }else if(IsEqualGUID(&IID_IBindStatusCallbackHolder, riid)) {
         TRACE("(%p)->(IID_IBindStatusCallbackHolder, %p)\n", This, ppv);
         *ppv = This;
@@ -92,7 +101,7 @@ static HRESULT WINAPI BindStatusCallback_QueryInterface(IBindStatusCallback *ifa
     return E_NOINTERFACE;
 }
 
-static ULONG WINAPI BindStatusCallback_AddRef(IBindStatusCallback *iface)
+static ULONG WINAPI BindStatusCallback_AddRef(IBindStatusCallbackEx *iface)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
     LONG ref = InterlockedIncrement(&This->ref);
@@ -102,7 +111,7 @@ static ULONG WINAPI BindStatusCallback_AddRef(IBindStatusCallback *iface)
     return ref;
 }
 
-static ULONG WINAPI BindStatusCallback_Release(IBindStatusCallback *iface)
+static ULONG WINAPI BindStatusCallback_Release(IBindStatusCallbackEx *iface)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
     LONG ref = InterlockedDecrement(&This->ref);
@@ -112,12 +121,6 @@ static ULONG WINAPI BindStatusCallback_Release(IBindStatusCallback *iface)
     if(!ref) {
         if(This->serv_prov)
             IServiceProvider_Release(This->serv_prov);
-        if(This->http_negotiate)
-            IHttpNegotiate_Release(This->http_negotiate);
-        if(This->http_negotiate2)
-            IHttpNegotiate2_Release(This->http_negotiate2);
-        if(This->authenticate)
-            IAuthenticate_Release(This->authenticate);
         IBindStatusCallback_Release(This->callback);
         heap_free(This);
     }
@@ -125,7 +128,7 @@ static ULONG WINAPI BindStatusCallback_Release(IBindStatusCallback *iface)
     return ref;
 }
 
-static HRESULT WINAPI BindStatusCallback_OnStartBinding(IBindStatusCallback *iface,
+static HRESULT WINAPI BindStatusCallback_OnStartBinding(IBindStatusCallbackEx *iface,
         DWORD dwReserved, IBinding *pbind)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
@@ -135,7 +138,7 @@ static HRESULT WINAPI BindStatusCallback_OnStartBinding(IBindStatusCallback *ifa
     return IBindStatusCallback_OnStartBinding(This->callback, 0xff, pbind);
 }
 
-static HRESULT WINAPI BindStatusCallback_GetPriority(IBindStatusCallback *iface, LONG *pnPriority)
+static HRESULT WINAPI BindStatusCallback_GetPriority(IBindStatusCallbackEx *iface, LONG *pnPriority)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
 
@@ -144,7 +147,7 @@ static HRESULT WINAPI BindStatusCallback_GetPriority(IBindStatusCallback *iface,
     return IBindStatusCallback_GetPriority(This->callback, pnPriority);
 }
 
-static HRESULT WINAPI BindStatusCallback_OnLowResource(IBindStatusCallback *iface, DWORD reserved)
+static HRESULT WINAPI BindStatusCallback_OnLowResource(IBindStatusCallbackEx *iface, DWORD reserved)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
 
@@ -153,7 +156,7 @@ static HRESULT WINAPI BindStatusCallback_OnLowResource(IBindStatusCallback *ifac
     return IBindStatusCallback_OnLowResource(This->callback, reserved);
 }
 
-static HRESULT WINAPI BindStatusCallback_OnProgress(IBindStatusCallback *iface, ULONG ulProgress,
+static HRESULT WINAPI BindStatusCallback_OnProgress(IBindStatusCallbackEx *iface, ULONG ulProgress,
         ULONG ulProgressMax, ULONG ulStatusCode, LPCWSTR szStatusText)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
@@ -165,7 +168,7 @@ static HRESULT WINAPI BindStatusCallback_OnProgress(IBindStatusCallback *iface,
             ulProgressMax, ulStatusCode, szStatusText);
 }
 
-static HRESULT WINAPI BindStatusCallback_OnStopBinding(IBindStatusCallback *iface,
+static HRESULT WINAPI BindStatusCallback_OnStopBinding(IBindStatusCallbackEx *iface,
         HRESULT hresult, LPCWSTR szError)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
@@ -175,17 +178,29 @@ static HRESULT WINAPI BindStatusCallback_OnStopBinding(IBindStatusCallback *ifac
     return IBindStatusCallback_OnStopBinding(This->callback, hresult, szError);
 }
 
-static HRESULT WINAPI BindStatusCallback_GetBindInfo(IBindStatusCallback *iface,
+static HRESULT WINAPI BindStatusCallback_GetBindInfo(IBindStatusCallbackEx *iface,
         DWORD *grfBINDF, BINDINFO *pbindinfo)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
+    IBindStatusCallbackEx *bscex;
+    HRESULT hres;
 
     TRACE("(%p)->(%p %p)\n", This, grfBINDF, pbindinfo);
 
-    return IBindStatusCallback_GetBindInfo(This->callback, grfBINDF, pbindinfo);
+    hres = IBindStatusCallback_QueryInterface(This->callback, &IID_IBindStatusCallbackEx, (void**)&bscex);
+    if(SUCCEEDED(hres)) {
+        DWORD bindf2 = 0, reserv = 0;
+
+        hres = IBindStatusCallbackEx_GetBindInfoEx(bscex, grfBINDF, pbindinfo, &bindf2, &reserv);
+        IBindStatusCallbackEx_Release(bscex);
+    }else {
+        hres = IBindStatusCallback_GetBindInfo(This->callback, grfBINDF, pbindinfo);
+    }
+
+    return hres;
 }
 
-static HRESULT WINAPI BindStatusCallback_OnDataAvailable(IBindStatusCallback *iface,
+static HRESULT WINAPI BindStatusCallback_OnDataAvailable(IBindStatusCallbackEx *iface,
         DWORD grfBSCF, DWORD dwSize, FORMATETC *pformatetc, STGMEDIUM *pstgmed)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
@@ -195,7 +210,7 @@ static HRESULT WINAPI BindStatusCallback_OnDataAvailable(IBindStatusCallback *if
     return IBindStatusCallback_OnDataAvailable(This->callback, grfBSCF, dwSize, pformatetc, pstgmed);
 }
 
-static HRESULT WINAPI BindStatusCallback_OnObjectAvailable(IBindStatusCallback *iface,
+static HRESULT WINAPI BindStatusCallback_OnObjectAvailable(IBindStatusCallbackEx *iface,
         REFIID riid, IUnknown *punk)
 {
     BindStatusCallback *This = STATUSCLB_THIS(iface);
@@ -205,9 +220,29 @@ static HRESULT WINAPI BindStatusCallback_OnObjectAvailable(IBindStatusCallback *
     return IBindStatusCallback_OnObjectAvailable(This->callback, riid, punk);
 }
 
+static HRESULT WINAPI BindStatusCallback_GetBindInfoEx(IBindStatusCallbackEx *iface, DWORD *grfBINDF,
+        BINDINFO *pbindinfo, DWORD *grfBINDF2, DWORD *pdwReserved)
+{
+    BindStatusCallback *This = STATUSCLB_THIS(iface);
+    IBindStatusCallbackEx *bscex;
+    HRESULT hres;
+
+    TRACE("(%p)->(%p %p %p %p)\n", This, grfBINDF, pbindinfo, grfBINDF2, pdwReserved);
+
+    hres = IBindStatusCallback_QueryInterface(This->callback, &IID_IBindStatusCallbackEx, (void**)&bscex);
+    if(SUCCEEDED(hres)) {
+        hres = IBindStatusCallbackEx_GetBindInfoEx(bscex, grfBINDF, pbindinfo, grfBINDF2, pdwReserved);
+        IBindStatusCallbackEx_Release(bscex);
+    }else {
+        hres = IBindStatusCallback_GetBindInfo(This->callback, grfBINDF, pbindinfo);
+    }
+
+    return hres;
+}
+
 #undef STATUSCLB_THIS
 
-static const IBindStatusCallbackVtbl BindStatusCallbackVtbl = {
+static const IBindStatusCallbackExVtbl BindStatusCallbackExVtbl = {
     BindStatusCallback_QueryInterface,
     BindStatusCallback_AddRef,
     BindStatusCallback_Release,
@@ -218,7 +253,8 @@ static const IBindStatusCallbackVtbl BindStatusCallbackVtbl = {
     BindStatusCallback_OnStopBinding,
     BindStatusCallback_GetBindInfo,
     BindStatusCallback_OnDataAvailable,
-    BindStatusCallback_OnObjectAvailable
+    BindStatusCallback_OnObjectAvailable,
+    BindStatusCallback_GetBindInfoEx
 };
 
 #define SERVPROV_THIS(iface) DEFINE_THIS(BindStatusCallback, ServiceProvider, iface)
@@ -250,46 +286,16 @@ static HRESULT WINAPI BSCServiceProvider_QueryService(IServiceProvider *iface,
 
     if(IsEqualGUID(&IID_IHttpNegotiate, guidService)) {
         TRACE("(%p)->(IID_IHttpNegotiate %s %p)\n", This, debugstr_guid(riid), ppv);
-
-        if(!This->init_http_negotiate) {
-            This->init_http_negotiate = TRUE;
-            hres = IBindStatusCallback_QueryInterface(This->callback, &IID_IHttpNegotiate,
-                    (void**)&This->http_negotiate);
-            if(FAILED(hres) && This->serv_prov)
-                IServiceProvider_QueryService(This->serv_prov, &IID_IHttpNegotiate,
-                        &IID_IHttpNegotiate, (void**)&This->http_negotiate);
-        }
-
         return IBindStatusCallback_QueryInterface(STATUSCLB(This), riid, ppv);
     }
 
     if(IsEqualGUID(&IID_IHttpNegotiate2, guidService)) {
         TRACE("(%p)->(IID_IHttpNegotiate2 %s %p)\n", This, debugstr_guid(riid), ppv);
-
-        if(!This->init_http_negotiate2) {
-            This->init_http_negotiate2 = TRUE;
-            hres = IBindStatusCallback_QueryInterface(This->callback, &IID_IHttpNegotiate2,
-                    (void**)&This->http_negotiate2);
-            if(FAILED(hres) && This->serv_prov)
-                IServiceProvider_QueryService(This->serv_prov, &IID_IHttpNegotiate2,
-                        &IID_IHttpNegotiate2, (void**)&This->http_negotiate2);
-        }
-
         return IBindStatusCallback_QueryInterface(STATUSCLB(This), riid, ppv);
     }
 
     if(IsEqualGUID(&IID_IAuthenticate, guidService)) {
         TRACE("(%p)->(IID_IAuthenticate %s %p)\n", This, debugstr_guid(riid), ppv);
-
-        if(!This->init_authenticate) {
-            This->init_authenticate = TRUE;
-            hres = IBindStatusCallback_QueryInterface(This->callback, &IID_IAuthenticate,
-                    (void**)&This->authenticate);
-            if(FAILED(hres) && This->serv_prov)
-                IServiceProvider_QueryService(This->serv_prov, &IID_IAuthenticate,
-                        &IID_IAuthenticate, (void**)&This->authenticate);
-        }
-
         return IBindStatusCallback_QueryInterface(STATUSCLB(This), riid, ppv);
     }
 
@@ -342,17 +348,22 @@ static HRESULT WINAPI BSCHttpNegotiate_BeginningTransaction(IHttpNegotiate2 *ifa
         LPCWSTR szURL, LPCWSTR szHeaders, DWORD dwReserved, LPWSTR *pszAdditionalHeaders)
 {
     BindStatusCallback *This = HTTPNEG2_THIS(iface);
+    IHttpNegotiate *http_negotiate;
+    HRESULT hres = S_OK;
 
     TRACE("(%p)->(%s %s %d %p)\n", This, debugstr_w(szURL), debugstr_w(szHeaders), dwReserved,
           pszAdditionalHeaders);
 
     *pszAdditionalHeaders = NULL;
 
-    if(!This->http_negotiate)
-        return S_OK;
+    http_negotiate = get_callback_iface(This, &IID_IHttpNegotiate);
+    if(http_negotiate) {
+        hres = IHttpNegotiate_BeginningTransaction(http_negotiate, szURL, szHeaders,
+                dwReserved, pszAdditionalHeaders);
+        IHttpNegotiate_Release(http_negotiate);
+    }
 
-    return IHttpNegotiate_BeginningTransaction(This->http_negotiate, szURL, szHeaders,
-                                               dwReserved, pszAdditionalHeaders);
+    return hres;
 }
 
 static HRESULT WINAPI BSCHttpNegotiate_OnResponse(IHttpNegotiate2 *iface, DWORD dwResponseCode,
@@ -361,14 +372,18 @@ static HRESULT WINAPI BSCHttpNegotiate_OnResponse(IHttpNegotiate2 *iface, DWORD
 {
     BindStatusCallback *This = HTTPNEG2_THIS(iface);
     LPWSTR additional_headers = NULL;
+    IHttpNegotiate *http_negotiate;
     HRESULT hres = S_OK;
 
     TRACE("(%p)->(%d %s %s %p)\n", This, dwResponseCode, debugstr_w(szResponseHeaders),
           debugstr_w(szRequestHeaders), pszAdditionalRequestHeaders);
 
-    if(This->http_negotiate)
-        hres = IHttpNegotiate_OnResponse(This->http_negotiate, dwResponseCode, szResponseHeaders,
-                                         szRequestHeaders, &additional_headers);
+    http_negotiate = get_callback_iface(This, &IID_IHttpNegotiate);
+    if(http_negotiate) {
+        hres = IHttpNegotiate_OnResponse(http_negotiate, dwResponseCode, szResponseHeaders,
+                szRequestHeaders, &additional_headers);
+        IHttpNegotiate_Release(http_negotiate);
+    }
 
     if(pszAdditionalRequestHeaders)
         *pszAdditionalRequestHeaders = additional_headers;
@@ -382,14 +397,19 @@ static HRESULT WINAPI BSCHttpNegotiate_GetRootSecurityId(IHttpNegotiate2 *iface,
         BYTE *pbSecurityId, DWORD *pcbSecurityId, DWORD_PTR dwReserved)
 {
     BindStatusCallback *This = HTTPNEG2_THIS(iface);
+    IHttpNegotiate2 *http_negotiate2;
+    HRESULT hres = E_FAIL;
 
     TRACE("(%p)->(%p %p %ld)\n", This, pbSecurityId, pcbSecurityId, dwReserved);
 
-    if(!This->http_negotiate2)
-        return E_FAIL;
+    http_negotiate2 = get_callback_iface(This, &IID_IHttpNegotiate2);
+    if(http_negotiate2) {
+        hres = IHttpNegotiate2_GetRootSecurityId(http_negotiate2, pbSecurityId,
+                pcbSecurityId, dwReserved);
+        IHttpNegotiate2_Release(http_negotiate2);
+    }
 
-    return IHttpNegotiate2_GetRootSecurityId(This->http_negotiate2, pbSecurityId,
-                                             pcbSecurityId, dwReserved);
+    return hres;
 }
 
 #undef HTTPNEG2_THIS
@@ -444,7 +464,7 @@ static IBindStatusCallback *create_bsc(IBindStatusCallback *bsc)
 {
     BindStatusCallback *ret = heap_alloc_zero(sizeof(BindStatusCallback));
 
-    ret->lpBindStatusCallbackVtbl = &BindStatusCallbackVtbl;
+    ret->lpBindStatusCallbackExVtbl = &BindStatusCallbackExVtbl;
     ret->lpServiceProviderVtbl    = &BSCServiceProviderVtbl;
     ret->lpHttpNegotiate2Vtbl     = &BSCHttpNegotiateVtbl;
     ret->lpAuthenticateVtbl       = &BSCAuthenticateVtbl;
@@ -511,7 +531,8 @@ HRESULT WINAPI RegisterBindStatusCallback(IBindCtx *pbc, IBindStatusCallback *pb
     hres = IBindCtx_RegisterObjectParam(pbc, BSCBHolder, (IUnknown*)bsc);
     IBindStatusCallback_Release(bsc);
     if(FAILED(hres)) {
-        IBindStatusCallback_Release(prev);
+        if(prev)
+            IBindStatusCallback_Release(prev);
         return hres;
     }