Sync with trunk r63174.
[reactos.git] / dll / win32 / urlmon / protocol.c
index b423c67..84b8448 100644 (file)
 
 #include "urlmon_main.h"
 
-#include "wine/debug.h"
-
-WINE_DEFAULT_DEBUG_CHANNEL(urlmon);
-
-/* Flags are needed for, among other things, return HRESULTs from the Read function
- * to conform to native. For example, Read returns:
- *
- * 1. E_PENDING if called before the request has completed,
- *        (flags = 0)
- * 2. S_FALSE after all data has been read and S_OK has been reported,
- *        (flags = FLAG_REQUEST_COMPLETE | FLAG_ALL_DATA_READ | FLAG_RESULT_REPORTED)
- * 3. INET_E_DATA_NOT_AVAILABLE if InternetQueryDataAvailable fails. The first time
- *    this occurs, INET_E_DATA_NOT_AVAILABLE will also be reported to the sink,
- *        (flags = FLAG_REQUEST_COMPLETE)
- *    but upon subsequent calls to Read no reporting will take place, yet
- *    InternetQueryDataAvailable will still be called, and, on failure,
- *    INET_E_DATA_NOT_AVAILABLE will still be returned.
- *        (flags = FLAG_REQUEST_COMPLETE | FLAG_RESULT_REPORTED)
- *
- * FLAG_FIRST_DATA_REPORTED and FLAG_LAST_DATA_REPORTED are needed for proper
- * ReportData reporting. For example, if OnResponse returns S_OK, Continue will
- * report BSCF_FIRSTDATANOTIFICATION, and when all data has been read Read will
- * report BSCF_INTERMEDIATEDATANOTIFICATION|BSCF_LASTDATANOTIFICATION. However,
- * if OnResponse does not return S_OK, Continue will not report data, and Read
- * will report BSCF_FIRSTDATANOTIFICATION|BSCF_LASTDATANOTIFICATION when all
- * data has been read.
- */
-#define FLAG_REQUEST_COMPLETE         0x0001
-#define FLAG_FIRST_CONTINUE_COMPLETE  0x0002
-#define FLAG_FIRST_DATA_REPORTED      0x0004
-#define FLAG_ALL_DATA_READ            0x0008
-#define FLAG_LAST_DATA_REPORTED       0x0010
-#define FLAG_RESULT_REPORTED          0x0020
-
 static inline HRESULT report_progress(Protocol *protocol, ULONG status_code, LPCWSTR status_text)
 {
     return IInternetProtocolSink_ReportProgress(protocol->protocol_sink, status_code, status_text);
@@ -100,29 +66,82 @@ static void all_data_read(Protocol *protocol)
     report_result(protocol, S_OK);
 }
 
-static void request_complete(Protocol *protocol, INTERNET_ASYNC_RESULT *ar)
+static HRESULT start_downloading(Protocol *protocol)
 {
-    PROTOCOLDATA data;
+    HRESULT hres;
 
-    if(!ar->dwResult) {
-        WARN("request failed: %d\n", ar->dwError);
-        return;
+    hres = protocol->vtbl->start_downloading(protocol);
+    if(FAILED(hres)) {
+        protocol_close_connection(protocol);
+        report_result(protocol, hres);
+        return hres;
     }
 
-    protocol->flags |= FLAG_REQUEST_COMPLETE;
+    if(protocol->bindf & BINDF_NEEDFILE) {
+        WCHAR cache_file[MAX_PATH];
+        DWORD buflen = sizeof(cache_file);
 
-    if(!protocol->request) {
-        TRACE("setting request handle %p\n", (HINTERNET)ar->dwResult);
-        protocol->request = (HINTERNET)ar->dwResult;
+        if(InternetQueryOptionW(protocol->request, INTERNET_OPTION_DATAFILE_NAME, cache_file, &buflen)) {
+            report_progress(protocol, BINDSTATUS_CACHEFILENAMEAVAILABLE, cache_file);
+        }else {
+            FIXME("Could not get cache file\n");
+        }
     }
 
+    protocol->flags |= FLAG_FIRST_CONTINUE_COMPLETE;
+    return S_OK;
+}
+
+HRESULT protocol_syncbinding(Protocol *protocol)
+{
+    BOOL res;
+    HRESULT hres;
+
+    protocol->flags |= FLAG_SYNC_READ;
+
+    hres = start_downloading(protocol);
+    if(FAILED(hres))
+        return hres;
+
+    res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
+    if(res)
+        protocol->available_bytes = protocol->query_available;
+    else
+        WARN("InternetQueryDataAvailable failed: %u\n", GetLastError());
+
+    protocol->flags |= FLAG_FIRST_DATA_REPORTED|FLAG_LAST_DATA_REPORTED;
+    IInternetProtocolSink_ReportData(protocol->protocol_sink, BSCF_LASTDATANOTIFICATION|BSCF_DATAFULLYAVAILABLE,
+            protocol->available_bytes, protocol->content_length);
+    return S_OK;
+}
+
+static void request_complete(Protocol *protocol, INTERNET_ASYNC_RESULT *ar)
+{
+    PROTOCOLDATA data;
+
+    TRACE("(%p)->(%p)\n", protocol, ar);
+
     /* PROTOCOLDATA same as native */
     memset(&data, 0, sizeof(data));
     data.dwState = 0xf1000000;
-    if(protocol->flags & FLAG_FIRST_CONTINUE_COMPLETE)
-        data.pData = (LPVOID)BINDSTATUS_ENDDOWNLOADCOMPONENTS;
-    else
-        data.pData = (LPVOID)BINDSTATUS_DOWNLOADINGDATA;
+
+    if(ar->dwResult) {
+        protocol->flags |= FLAG_REQUEST_COMPLETE;
+
+        if(!protocol->request) {
+            TRACE("setting request handle %p\n", (HINTERNET)ar->dwResult);
+            protocol->request = (HINTERNET)ar->dwResult;
+        }
+
+        if(protocol->flags & FLAG_FIRST_CONTINUE_COMPLETE)
+            data.pData = UlongToPtr(BINDSTATUS_ENDDOWNLOADCOMPONENTS);
+        else
+            data.pData = UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
+
+    }else {
+        protocol->flags |= FLAG_ERROR;
+        data.pData = UlongToPtr(ar->dwError);
+    }
 
     if (protocol->bindf & BINDF_FROMURLMON)
         IInternetProtocolSink_Switch(protocol->protocol_sink, &data);
@@ -141,10 +160,19 @@ static void WINAPI internet_status_callback(HINTERNET internet, DWORD_PTR contex
         report_progress(protocol, BINDSTATUS_FINDINGRESOURCE, (LPWSTR)status_info);
         break;
 
-    case INTERNET_STATUS_CONNECTING_TO_SERVER:
-        TRACE("%p INTERNET_STATUS_CONNECTING_TO_SERVER\n", protocol);
-        report_progress(protocol, BINDSTATUS_CONNECTING, (LPWSTR)status_info);
+    case INTERNET_STATUS_CONNECTING_TO_SERVER: {
+        WCHAR *info;
+
+        TRACE("%p INTERNET_STATUS_CONNECTING_TO_SERVER %s\n", protocol, (const char*)status_info);
+
+        info = heap_strdupAtoW(status_info);
+        if(!info)
+            return;
+
+        report_progress(protocol, BINDSTATUS_CONNECTING, info);
+        heap_free(info);
         break;
+    }
 
     case INTERNET_STATUS_SENDING_REQUEST:
         TRACE("%p INTERNET_STATUS_SENDING_REQUEST\n", protocol);
@@ -191,6 +219,42 @@ static void WINAPI internet_status_callback(HINTERNET internet, DWORD_PTR contex
     }
 }
 
+static HRESULT write_post_stream(Protocol *protocol)
+{
+    BYTE buf[0x20000];
+    DWORD written;
+    ULONG size;
+    BOOL res;
+    HRESULT hres;
+
+    protocol->flags &= ~FLAG_REQUEST_COMPLETE;
+
+    while(1) {
+        size = 0;
+        hres = IStream_Read(protocol->post_stream, buf, sizeof(buf), &size);
+        if(FAILED(hres) || !size)
+            break;
+        res = InternetWriteFile(protocol->request, buf, size, &written);
+        if(!res) {
+            FIXME("InternetWriteFile failed: %u\n", GetLastError());
+            hres = E_FAIL;
+            break;
+        }
+    }
+
+    if(SUCCEEDED(hres)) {
+        IStream_Release(protocol->post_stream);
+        protocol->post_stream = NULL;
+
+        hres = protocol->vtbl->end_request(protocol);
+    }
+
+    if(FAILED(hres))
+        return report_result(protocol, hres);
+
+    return S_OK;
+}
+
 static HINTERNET create_internet_session(IInternetBindInfo *bind_info)
 {
     LPWSTR global_user_agent = NULL;
@@ -276,12 +340,10 @@ HRESULT protocol_start(Protocol *protocol, IInternetProtocol *prot, IUri *uri,
 
 HRESULT protocol_continue(Protocol *protocol, PROTOCOLDATA *data)
 {
+    BOOL is_start;
     HRESULT hres;
 
-    if (!data) {
-        WARN("Expected pProtocolData to be non-NULL\n");
-        return S_OK;
-    }
+    is_start = !data || data->pData == UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
 
     if(!protocol->request) {
         WARN("Expected request to be non-NULL\n");
@@ -293,45 +355,58 @@ HRESULT protocol_continue(Protocol *protocol, PROTOCOLDATA *data)
         return S_OK;
     }
 
-    if(data->pData == (LPVOID)BINDSTATUS_DOWNLOADINGDATA) {
-        hres = protocol->vtbl->start_downloading(protocol);
-        if(FAILED(hres)) {
-            protocol_close_connection(protocol);
-            report_result(protocol, hres);
-            return S_OK;
-        }
+    if(protocol->flags & FLAG_ERROR) {
+        protocol->flags &= ~FLAG_ERROR;
+        protocol->vtbl->on_error(protocol, PtrToUlong(data->pData));
+        return S_OK;
+    }
 
-        if(protocol->bindf & BINDF_NEEDFILE) {
-            WCHAR cache_file[MAX_PATH];
-            DWORD buflen = sizeof(cache_file);
+    if(protocol->post_stream)
+        return write_post_stream(protocol);
 
-            if(InternetQueryOptionW(protocol->request, INTERNET_OPTION_DATAFILE_NAME,
-                    cache_file, &buflen)) {
-                report_progress(protocol, BINDSTATUS_CACHEFILENAMEAVAILABLE, cache_file);
-            }else {
-                FIXME("Could not get cache file\n");
-            }
-        }
-
-        protocol->flags |= FLAG_FIRST_CONTINUE_COMPLETE;
+    if(is_start) {
+        hres = start_downloading(protocol);
+        if(FAILED(hres))
+            return S_OK;
     }
 
-    if(data->pData >= (LPVOID)BINDSTATUS_DOWNLOADINGDATA) {
-        BOOL res;
+    if(!data || data->pData >= UlongToPtr(BINDSTATUS_DOWNLOADINGDATA)) {
+        if(!protocol->available_bytes) {
+            if(protocol->query_available) {
+                protocol->available_bytes = protocol->query_available;
+            }else {
+                BOOL res;
+
+                /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
+                 * read, so clear the flag _before_ calling so it does not incorrectly get cleared
+                 * after the status callback is called */
+                protocol->flags &= ~FLAG_REQUEST_COMPLETE;
+                res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
+                if(res) {
+                    TRACE("available %u bytes\n", protocol->query_available);
+                    if(!protocol->query_available) {
+                        if(is_start) {
+                            TRACE("empty file\n");
+                            all_data_read(protocol);
+                        }else {
+                            WARN("unexpected end of file?\n");
+                            report_result(protocol, INET_E_DOWNLOAD_FAILURE);
+                        }
+                        return S_OK;
+                    }
+                    protocol->available_bytes = protocol->query_available;
+                }else if(GetLastError() != ERROR_IO_PENDING) {
+                    protocol->flags |= FLAG_REQUEST_COMPLETE;
+                    WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
+                    report_result(protocol, INET_E_DATA_NOT_AVAILABLE);
+                    return S_OK;
+                }
+            }
 
-        /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
-         * read, so clear the flag _before_ calling so it does not incorrectly get cleared
-         * after the status callback is called */
-        protocol->flags &= ~FLAG_REQUEST_COMPLETE;
-        res = InternetQueryDataAvailable(protocol->request, &protocol->available_bytes, 0, 0);
-        if(res) {
             protocol->flags |= FLAG_REQUEST_COMPLETE;
-            report_data(protocol);
-        }else if(GetLastError() != ERROR_IO_PENDING) {
-            protocol->flags |= FLAG_REQUEST_COMPLETE;
-            WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
-            report_result(protocol, INET_E_DATA_NOT_AVAILABLE);
         }
+
+        report_data(protocol);
     }
 
     return S_OK;
@@ -348,38 +423,40 @@ HRESULT protocol_read(Protocol *protocol, void *buf, ULONG size, ULONG *read_ret
         return S_FALSE;
     }
 
-    if(!(protocol->flags & FLAG_REQUEST_COMPLETE)) {
+    if(!(protocol->flags & FLAG_SYNC_READ) && (!(protocol->flags & FLAG_REQUEST_COMPLETE) || !protocol->available_bytes)) {
         *read_ret = 0;
         return E_PENDING;
     }
 
-    while(read < size) {
-        if(protocol->available_bytes) {
-            ULONG len;
+    while(read < size && protocol->available_bytes) {
+        ULONG len;
 
-            res = InternetReadFile(protocol->request, ((BYTE *)buf)+read,
-                    protocol->available_bytes > size-read ? size-read : protocol->available_bytes, &len);
-            if(!res) {
-                WARN("InternetReadFile failed: %d\n", GetLastError());
-                hres = INET_E_DOWNLOAD_FAILURE;
-                report_result(protocol, hres);
-                break;
-            }
+        res = InternetReadFile(protocol->request, ((BYTE *)buf)+read,
+                protocol->available_bytes > size-read ? size-read : protocol->available_bytes, &len);
+        if(!res) {
+            WARN("InternetReadFile failed: %d\n", GetLastError());
+            hres = INET_E_DOWNLOAD_FAILURE;
+            report_result(protocol, hres);
+            break;
+        }
 
-            if(!len) {
-                all_data_read(protocol);
-                break;
-            }
+        if(!len) {
+            all_data_read(protocol);
+            break;
+        }
 
-            read += len;
-            protocol->current_position += len;
-            protocol->available_bytes -= len;
-        }else {
+        read += len;
+        protocol->current_position += len;
+        protocol->available_bytes -= len;
+
+        TRACE("current_position %d, available_bytes %d\n", protocol->current_position, protocol->available_bytes);
+
+        if(!protocol->available_bytes) {
             /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
              * read, so clear the flag _before_ calling so it does not incorrectly get cleared
              * after the status callback is called */
             protocol->flags &= ~FLAG_REQUEST_COMPLETE;
-            res = InternetQueryDataAvailable(protocol->request, &protocol->available_bytes, 0, 0);
+            res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
             if(!res) {
                 if (GetLastError() == ERROR_IO_PENDING) {
                     hres = E_PENDING;
@@ -391,10 +468,12 @@ HRESULT protocol_read(Protocol *protocol, void *buf, ULONG size, ULONG *read_ret
                 break;
             }
 
-            if(!protocol->available_bytes) {
+            if(!protocol->query_available) {
                 all_data_read(protocol);
                 break;
             }
+
+            protocol->available_bytes = protocol->query_available;
         }
     }
 
@@ -428,6 +507,19 @@ HRESULT protocol_unlock_request(Protocol *protocol)
     return S_OK;
 }
 
+HRESULT protocol_abort(Protocol *protocol, HRESULT reason)
+{
+    if(!protocol->protocol_sink)
+        return S_OK;
+
+    /* NOTE: IE10 returns S_OK here */
+    if(protocol->flags & FLAG_RESULT_REPORTED)
+        return INET_E_RESULT_DISPATCHED;
+
+    report_result(protocol, reason);
+    return S_OK;
+}
+
 void protocol_close_connection(Protocol *protocol)
 {
     protocol->vtbl->close_connection(protocol);
@@ -438,5 +530,10 @@ void protocol_close_connection(Protocol *protocol)
     if(protocol->connection)
         InternetCloseHandle(protocol->connection);
 
+    if(protocol->post_stream) {
+        IStream_Release(protocol->post_stream);
+        protocol->post_stream = NULL;
+    }
+
     protocol->flags = 0;
 }