[CMAKE]
[reactos.git] / dll / win32 / urlmon / binding.c
index 210e8e2..e0f3624 100644 (file)
@@ -47,6 +47,7 @@ typedef struct {
     BYTE buf[1024*8];
     DWORD size;
     BOOL init;
+    HANDLE file;
     HRESULT hres;
 
     LPWSTR cache_file;
@@ -57,7 +58,7 @@ typedef struct _stgmed_obj_t stgmed_obj_t;
 typedef struct {
     void (*release)(stgmed_obj_t*);
     HRESULT (*fill_stgmed)(stgmed_obj_t*,STGMEDIUM*);
-    void *(*get_result)(stgmed_obj_t*);
+    HRESULT (*get_result)(stgmed_obj_t*,DWORD,void**);
 } stgmed_obj_vtbl;
 
 struct _stgmed_obj_t {
@@ -98,8 +99,10 @@ struct Binding {
     LPWSTR mime;
     UINT clipboard_format;
     LPWSTR url;
+    LPWSTR redirect_url;
     IID iid;
     BOOL report_mime;
+    BOOL use_cache_file;
     DWORD state;
     HRESULT hres;
     download_state_t download_state;
@@ -133,6 +136,20 @@ static void fill_stgmed_buffer(stgmed_buf_t *buf)
         buf->init = TRUE;
 }
 
+static void read_protocol_data(stgmed_buf_t *stgmed_buf)
+{
+    BYTE buf[8192];
+    DWORD read;
+    HRESULT hres;
+
+    fill_stgmed_buffer(stgmed_buf);
+    if(stgmed_buf->size < sizeof(stgmed_buf->buf))
+        return;
+
+    do hres = IInternetProtocol_Read(stgmed_buf->protocol, buf, sizeof(buf), &read);
+    while(hres == S_OK);
+}
+
 static void dump_BINDINFO(BINDINFO *bi)
 {
     static const char * const BINDINFOF_str[] = {
@@ -338,6 +355,19 @@ static void create_object(Binding *binding)
         IInternetProtocol_Terminate(binding->protocol, 0);
 }
 
+static void cache_file_available(Binding *This, const WCHAR *file_name)
+{
+    heap_free(This->stgmed_buf->cache_file);
+    This->stgmed_buf->cache_file = heap_strdupW(file_name);
+
+    if(This->use_cache_file) {
+        This->stgmed_buf->file = CreateFileW(file_name, GENERIC_READ, FILE_SHARE_READ|FILE_SHARE_WRITE, NULL,
+                OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
+        if(This->stgmed_buf->file == INVALID_HANDLE_VALUE)
+            WARN("CreateFile failed: %u\n", GetLastError());
+    }
+}
+
 #define STGMEDUNK_THIS(iface) DEFINE_THIS(stgmed_buf_t, Unknown, iface)
 
 static HRESULT WINAPI StgMedUnk_QueryInterface(IUnknown *iface, REFIID riid, void **ppv)
@@ -376,6 +406,8 @@ static ULONG WINAPI StgMedUnk_Release(IUnknown *iface)
     TRACE("(%p) ref=%d\n", This, ref);
 
     if(!ref) {
+        if(This->file != INVALID_HANDLE_VALUE)
+            CloseHandle(This->file);
         IInternetProtocol_Release(This->protocol);
         heap_free(This->cache_file);
         heap_free(This);
@@ -402,6 +434,7 @@ static stgmed_buf_t *create_stgmed_buf(IInternetProtocol *protocol)
     ret->ref = 1;
     ret->size = 0;
     ret->init = FALSE;
+    ret->file = INVALID_HANDLE_VALUE;
     ret->hres = S_OK;
     ret->cache_file = NULL;
 
@@ -487,6 +520,15 @@ static HRESULT WINAPI ProtocolStream_Read(IStream *iface, void *pv,
 
     TRACE("(%p)->(%p %d %p)\n", This, pv, cb, pcbRead);
 
+    if(This->buf->file != INVALID_HANDLE_VALUE) {
+        if (!ReadFile(This->buf->file, pv, cb, &read, NULL))
+            return INET_E_DOWNLOAD_FAILURE;
+
+        if(pcbRead)
+            *pcbRead = read;
+        return read ? S_OK : S_FALSE;
+    }
+
     if(This->buf->size) {
         read = cb;
 
@@ -636,12 +678,13 @@ static HRESULT stgmed_stream_fill_stgmed(stgmed_obj_t *obj, STGMEDIUM *stgmed)
     return S_OK;
 }
 
-static void *stgmed_stream_get_result(stgmed_obj_t *obj)
+static HRESULT stgmed_stream_get_result(stgmed_obj_t *obj, DWORD bindf, void **result)
 {
     ProtocolStream *stream = (ProtocolStream*)obj;
 
     IStream_AddRef(STREAM(stream));
-    return STREAM(stream);
+    *result = STREAM(stream);
+    return S_OK;
 }
 
 static const stgmed_obj_vtbl stgmed_stream_vtbl = {
@@ -688,16 +731,7 @@ static HRESULT stgmed_file_fill_stgmed(stgmed_obj_t *obj, STGMEDIUM *stgmed)
         return INET_E_DATA_NOT_AVAILABLE;
     }
 
-    fill_stgmed_buffer(file_obj->buf);
-    if(file_obj->buf->size == sizeof(file_obj->buf->buf)) {
-        BYTE buf[1024];
-        DWORD read;
-        HRESULT hres;
-
-        do {
-            hres = IInternetProtocol_Read(file_obj->buf->protocol, buf, sizeof(buf), &read);
-        }while(hres == S_OK);
-    }
+    read_protocol_data(file_obj->buf);
 
     stgmed->tymed = TYMED_FILE;
     stgmed->u.lpszFileName = file_obj->buf->cache_file;
@@ -706,9 +740,9 @@ static HRESULT stgmed_file_fill_stgmed(stgmed_obj_t *obj, STGMEDIUM *stgmed)
     return S_OK;
 }
 
-static void *stgmed_file_get_result(stgmed_obj_t *obj)
+static HRESULT stgmed_file_get_result(stgmed_obj_t *obj, DWORD bindf, void **result)
 {
-    return NULL;
+    return bindf & BINDF_ASYNCHRONOUS ? MK_S_ASYNCHRONOUS : S_OK;
 }
 
 static const stgmed_obj_vtbl stgmed_file_vtbl = {
@@ -829,6 +863,7 @@ static ULONG WINAPI Binding_Release(IBinding *iface)
         This->section.DebugInfo->Spare[0] = 0;
         DeleteCriticalSection(&This->section);
         heap_free(This->mime);
+        heap_free(This->redirect_url);
         heap_free(This->url);
 
         heap_free(This);
@@ -967,6 +1002,11 @@ static HRESULT WINAPI InternetProtocolSink_ReportProgress(IInternetProtocolSink
     case BINDSTATUS_CONNECTING:
         on_progress(This, 0, 0, BINDSTATUS_CONNECTING, szStatusText);
         break;
+    case BINDSTATUS_REDIRECTING:
+        heap_free(This->redirect_url);
+        This->redirect_url = heap_strdupW(szStatusText);
+        on_progress(This, 0, 0, BINDSTATUS_REDIRECTING, szStatusText);
+        break;
     case BINDSTATUS_BEGINDOWNLOADDATA:
         fill_stgmed_buffer(This->stgmed_buf);
         break;
@@ -980,8 +1020,7 @@ static HRESULT WINAPI InternetProtocolSink_ReportProgress(IInternetProtocolSink
         mime_available(This, szStatusText);
         break;
     case BINDSTATUS_CACHEFILENAMEAVAILABLE:
-        heap_free(This->stgmed_buf->cache_file);
-        This->stgmed_buf->cache_file = heap_strdupW(szStatusText);
+        cache_file_available(This, szStatusText);
         break;
     case BINDSTATUS_DECODING:
         IBindStatusCallback_OnProgress(This->callback, 0, 0, BINDSTATUS_DECODING, szStatusText);
@@ -1012,9 +1051,12 @@ static void report_data(Binding *This, DWORD bscf, ULONG progress, ULONG progres
     if(This->download_state == END_DOWNLOAD || (This->state & BINDING_STOPPED))
         return;
 
-    if(This->download_state == BEFORE_DOWNLOAD) {
+    if(This->stgmed_buf->file != INVALID_HANDLE_VALUE)
+        read_protocol_data(This->stgmed_buf);
+    else if(This->download_state == BEFORE_DOWNLOAD)
         fill_stgmed_buffer(This->stgmed_buf);
 
+    if(This->download_state == BEFORE_DOWNLOAD) {
         This->download_state = DOWNLOADING;
         sent_begindownloaddata = TRUE;
         IBindStatusCallback_OnProgress(This->callback, progress, progress_max,
@@ -1423,8 +1465,12 @@ static HRESULT Binding_Create(IMoniker *mon, Binding *binding_ctx, LPCWSTR url,
     if(to_obj)
         ret->bindinfo.dwOptions |= 0x100000;
 
-    if(!is_urlmon_protocol(url))
+    if(!(ret->bindf & BINDF_ASYNCHRONOUS)) {
         ret->bindf |= BINDF_NEEDFILE;
+        ret->use_cache_file = TRUE;
+    }else if(!is_urlmon_protocol(url)) {
+        ret->bindf |= BINDF_NEEDFILE;
+    }
 
     ret->url = heap_strdupW(url);
 
@@ -1474,6 +1520,8 @@ static HRESULT start_binding(IMoniker *mon, Binding *binding_ctx, LPCWSTR url, I
 
     if(binding_ctx) {
         set_binding_sink(binding->protocol, PROTSINK(binding));
+        if(binding_ctx->redirect_url)
+            IBindStatusCallback_OnProgress(binding->callback, 0, 0, BINDSTATUS_REDIRECTING, binding_ctx->redirect_url);
         report_data(binding, 0, 0, 0);
     }else {
         hres = IInternetProtocol_Start(binding->protocol, url, PROTSINK(binding),
@@ -1521,12 +1569,14 @@ HRESULT bind_to_storage(LPCWSTR url, IBindCtx *pbc, REFIID riid, void **ppv)
         if((binding->state & BINDING_STOPPED) && (binding->state & BINDING_LOCKED))
             IInternetProtocol_UnlockRequest(binding->protocol);
 
-        *ppv = binding->stgmed_obj->vtbl->get_result(binding->stgmed_obj);
+        hres = binding->stgmed_obj->vtbl->get_result(binding->stgmed_obj, binding->bindf, ppv);
+    }else {
+        hres = MK_S_ASYNCHRONOUS;
     }
 
     IBinding_Release(BINDING(binding));
 
-    return *ppv ? S_OK : MK_S_ASYNCHRONOUS;
+    return hres;
 }
 
 HRESULT bind_to_object(IMoniker *mon, LPCWSTR url, IBindCtx *pbc, REFIID riid, void **ppv)