[rshell]
[reactos.git] / dll / win32 / urlmon / protocol.c
1 /*
2 * Copyright 2007 Misha Koshelev
3 * Copyright 2009 Jacek Caban for CodeWeavers
4 *
5 * This library is free software; you can redistribute it and/or
6 * modify it under the terms of the GNU Lesser General Public
7 * License as published by the Free Software Foundation; either
8 * version 2.1 of the License, or (at your option) any later version.
9 *
10 * This library is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 * Lesser General Public License for more details.
14 *
15 * You should have received a copy of the GNU Lesser General Public
16 * License along with this library; if not, write to the Free Software
17 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
18 */
19
20 #include "urlmon_main.h"
21
22 static inline HRESULT report_progress(Protocol *protocol, ULONG status_code, LPCWSTR status_text)
23 {
24 return IInternetProtocolSink_ReportProgress(protocol->protocol_sink, status_code, status_text);
25 }
26
27 static inline HRESULT report_result(Protocol *protocol, HRESULT hres)
28 {
29 if (!(protocol->flags & FLAG_RESULT_REPORTED) && protocol->protocol_sink) {
30 protocol->flags |= FLAG_RESULT_REPORTED;
31 IInternetProtocolSink_ReportResult(protocol->protocol_sink, hres, 0, NULL);
32 }
33
34 return hres;
35 }
36
37 static void report_data(Protocol *protocol)
38 {
39 DWORD bscf;
40
41 if((protocol->flags & FLAG_LAST_DATA_REPORTED) || !protocol->protocol_sink)
42 return;
43
44 if(protocol->flags & FLAG_FIRST_DATA_REPORTED) {
45 bscf = BSCF_INTERMEDIATEDATANOTIFICATION;
46 }else {
47 protocol->flags |= FLAG_FIRST_DATA_REPORTED;
48 bscf = BSCF_FIRSTDATANOTIFICATION;
49 }
50
51 if(protocol->flags & FLAG_ALL_DATA_READ && !(protocol->flags & FLAG_LAST_DATA_REPORTED)) {
52 protocol->flags |= FLAG_LAST_DATA_REPORTED;
53 bscf |= BSCF_LASTDATANOTIFICATION;
54 }
55
56 IInternetProtocolSink_ReportData(protocol->protocol_sink, bscf,
57 protocol->current_position+protocol->available_bytes,
58 protocol->content_length);
59 }
60
61 static void all_data_read(Protocol *protocol)
62 {
63 protocol->flags |= FLAG_ALL_DATA_READ;
64
65 report_data(protocol);
66 report_result(protocol, S_OK);
67 }
68
69 static HRESULT start_downloading(Protocol *protocol)
70 {
71 HRESULT hres;
72
73 hres = protocol->vtbl->start_downloading(protocol);
74 if(FAILED(hres)) {
75 protocol_close_connection(protocol);
76 report_result(protocol, hres);
77 return hres;
78 }
79
80 if(protocol->bindf & BINDF_NEEDFILE) {
81 WCHAR cache_file[MAX_PATH];
82 DWORD buflen = sizeof(cache_file);
83
84 if(InternetQueryOptionW(protocol->request, INTERNET_OPTION_DATAFILE_NAME, cache_file, &buflen)) {
85 report_progress(protocol, BINDSTATUS_CACHEFILENAMEAVAILABLE, cache_file);
86 }else {
87 FIXME("Could not get cache file\n");
88 }
89 }
90
91 protocol->flags |= FLAG_FIRST_CONTINUE_COMPLETE;
92 return S_OK;
93 }
94
95 HRESULT protocol_syncbinding(Protocol *protocol)
96 {
97 BOOL res;
98 HRESULT hres;
99
100 protocol->flags |= FLAG_SYNC_READ;
101
102 hres = start_downloading(protocol);
103 if(FAILED(hres))
104 return hres;
105
106 res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
107 if(res)
108 protocol->available_bytes = protocol->query_available;
109 else
110 WARN("InternetQueryDataAvailable failed: %u\n", GetLastError());
111
112 protocol->flags |= FLAG_FIRST_DATA_REPORTED|FLAG_LAST_DATA_REPORTED;
113 IInternetProtocolSink_ReportData(protocol->protocol_sink, BSCF_LASTDATANOTIFICATION|BSCF_DATAFULLYAVAILABLE,
114 protocol->available_bytes, protocol->content_length);
115 return S_OK;
116 }
117
118 static void request_complete(Protocol *protocol, INTERNET_ASYNC_RESULT *ar)
119 {
120 PROTOCOLDATA data;
121
122 TRACE("(%p)->(%p)\n", protocol, ar);
123
124 /* PROTOCOLDATA same as native */
125 memset(&data, 0, sizeof(data));
126 data.dwState = 0xf1000000;
127
128 if(ar->dwResult) {
129 protocol->flags |= FLAG_REQUEST_COMPLETE;
130
131 if(!protocol->request) {
132 TRACE("setting request handle %p\n", (HINTERNET)ar->dwResult);
133 protocol->request = (HINTERNET)ar->dwResult;
134 }
135
136 if(protocol->flags & FLAG_FIRST_CONTINUE_COMPLETE)
137 data.pData = UlongToPtr(BINDSTATUS_ENDDOWNLOADCOMPONENTS);
138 else
139 data.pData = UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
140
141 }else {
142 protocol->flags |= FLAG_ERROR;
143 data.pData = UlongToPtr(ar->dwError);
144 }
145
146 if (protocol->bindf & BINDF_FROMURLMON)
147 IInternetProtocolSink_Switch(protocol->protocol_sink, &data);
148 else
149 protocol_continue(protocol, &data);
150 }
151
152 static void WINAPI internet_status_callback(HINTERNET internet, DWORD_PTR context,
153 DWORD internet_status, LPVOID status_info, DWORD status_info_len)
154 {
155 Protocol *protocol = (Protocol*)context;
156
157 switch(internet_status) {
158 case INTERNET_STATUS_RESOLVING_NAME:
159 TRACE("%p INTERNET_STATUS_RESOLVING_NAME\n", protocol);
160 report_progress(protocol, BINDSTATUS_FINDINGRESOURCE, (LPWSTR)status_info);
161 break;
162
163 case INTERNET_STATUS_CONNECTING_TO_SERVER: {
164 WCHAR *info;
165
166 TRACE("%p INTERNET_STATUS_CONNECTING_TO_SERVER %s\n", protocol, (const char*)status_info);
167
168 info = heap_strdupAtoW(status_info);
169 if(!info)
170 return;
171
172 report_progress(protocol, BINDSTATUS_CONNECTING, info);
173 heap_free(info);
174 break;
175 }
176
177 case INTERNET_STATUS_SENDING_REQUEST:
178 TRACE("%p INTERNET_STATUS_SENDING_REQUEST\n", protocol);
179 report_progress(protocol, BINDSTATUS_SENDINGREQUEST, (LPWSTR)status_info);
180 break;
181
182 case INTERNET_STATUS_REDIRECT:
183 TRACE("%p INTERNET_STATUS_REDIRECT\n", protocol);
184 report_progress(protocol, BINDSTATUS_REDIRECTING, (LPWSTR)status_info);
185 break;
186
187 case INTERNET_STATUS_REQUEST_COMPLETE:
188 request_complete(protocol, status_info);
189 break;
190
191 case INTERNET_STATUS_HANDLE_CREATED:
192 TRACE("%p INTERNET_STATUS_HANDLE_CREATED\n", protocol);
193 IInternetProtocol_AddRef(protocol->protocol);
194 break;
195
196 case INTERNET_STATUS_HANDLE_CLOSING:
197 TRACE("%p INTERNET_STATUS_HANDLE_CLOSING\n", protocol);
198
199 if(*(HINTERNET *)status_info == protocol->request) {
200 protocol->request = NULL;
201 if(protocol->protocol_sink) {
202 IInternetProtocolSink_Release(protocol->protocol_sink);
203 protocol->protocol_sink = NULL;
204 }
205
206 if(protocol->bind_info.cbSize) {
207 ReleaseBindInfo(&protocol->bind_info);
208 memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
209 }
210 }else if(*(HINTERNET *)status_info == protocol->connection) {
211 protocol->connection = NULL;
212 }
213
214 IInternetProtocol_Release(protocol->protocol);
215 break;
216
217 default:
218 WARN("Unhandled Internet status callback %d\n", internet_status);
219 }
220 }
221
222 static HRESULT write_post_stream(Protocol *protocol)
223 {
224 BYTE buf[0x20000];
225 DWORD written;
226 ULONG size;
227 BOOL res;
228 HRESULT hres;
229
230 protocol->flags &= ~FLAG_REQUEST_COMPLETE;
231
232 while(1) {
233 size = 0;
234 hres = IStream_Read(protocol->post_stream, buf, sizeof(buf), &size);
235 if(FAILED(hres) || !size)
236 break;
237 res = InternetWriteFile(protocol->request, buf, size, &written);
238 if(!res) {
239 FIXME("InternetWriteFile failed: %u\n", GetLastError());
240 hres = E_FAIL;
241 break;
242 }
243 }
244
245 if(SUCCEEDED(hres)) {
246 IStream_Release(protocol->post_stream);
247 protocol->post_stream = NULL;
248
249 hres = protocol->vtbl->end_request(protocol);
250 }
251
252 if(FAILED(hres))
253 return report_result(protocol, hres);
254
255 return S_OK;
256 }
257
258 static HINTERNET create_internet_session(IInternetBindInfo *bind_info)
259 {
260 LPWSTR global_user_agent = NULL;
261 LPOLESTR user_agent = NULL;
262 ULONG size = 0;
263 HINTERNET ret;
264 HRESULT hres;
265
266 hres = IInternetBindInfo_GetBindString(bind_info, BINDSTRING_USER_AGENT, &user_agent, 1, &size);
267 if(hres != S_OK || !size)
268 global_user_agent = get_useragent();
269
270 ret = InternetOpenW(user_agent ? user_agent : global_user_agent, 0, NULL, NULL, INTERNET_FLAG_ASYNC);
271 heap_free(global_user_agent);
272 CoTaskMemFree(user_agent);
273 if(!ret) {
274 WARN("InternetOpen failed: %d\n", GetLastError());
275 return NULL;
276 }
277
278 InternetSetStatusCallbackW(ret, internet_status_callback);
279 return ret;
280 }
281
282 static HINTERNET internet_session;
283
284 HINTERNET get_internet_session(IInternetBindInfo *bind_info)
285 {
286 HINTERNET new_session;
287
288 if(internet_session)
289 return internet_session;
290
291 if(!bind_info)
292 return NULL;
293
294 new_session = create_internet_session(bind_info);
295 if(new_session && InterlockedCompareExchangePointer((void**)&internet_session, new_session, NULL))
296 InternetCloseHandle(new_session);
297
298 return internet_session;
299 }
300
301 HRESULT protocol_start(Protocol *protocol, IInternetProtocol *prot, IUri *uri,
302 IInternetProtocolSink *protocol_sink, IInternetBindInfo *bind_info)
303 {
304 DWORD request_flags;
305 HRESULT hres;
306
307 protocol->protocol = prot;
308
309 IInternetProtocolSink_AddRef(protocol_sink);
310 protocol->protocol_sink = protocol_sink;
311
312 memset(&protocol->bind_info, 0, sizeof(protocol->bind_info));
313 protocol->bind_info.cbSize = sizeof(BINDINFO);
314 hres = IInternetBindInfo_GetBindInfo(bind_info, &protocol->bindf, &protocol->bind_info);
315 if(hres != S_OK) {
316 WARN("GetBindInfo failed: %08x\n", hres);
317 return report_result(protocol, hres);
318 }
319
320 if(!(protocol->bindf & BINDF_FROMURLMON))
321 report_progress(protocol, BINDSTATUS_DIRECTBIND, NULL);
322
323 if(!get_internet_session(bind_info))
324 return report_result(protocol, INET_E_NO_SESSION);
325
326 request_flags = INTERNET_FLAG_KEEP_CONNECTION;
327 if(protocol->bindf & BINDF_NOWRITECACHE)
328 request_flags |= INTERNET_FLAG_NO_CACHE_WRITE;
329 if(protocol->bindf & BINDF_NEEDFILE)
330 request_flags |= INTERNET_FLAG_NEED_FILE;
331
332 hres = protocol->vtbl->open_request(protocol, uri, request_flags, internet_session, bind_info);
333 if(FAILED(hres)) {
334 protocol_close_connection(protocol);
335 return report_result(protocol, hres);
336 }
337
338 return S_OK;
339 }
340
341 HRESULT protocol_continue(Protocol *protocol, PROTOCOLDATA *data)
342 {
343 BOOL is_start;
344 HRESULT hres;
345
346 is_start = !data || data->pData == UlongToPtr(BINDSTATUS_DOWNLOADINGDATA);
347
348 if(!protocol->request) {
349 WARN("Expected request to be non-NULL\n");
350 return S_OK;
351 }
352
353 if(!protocol->protocol_sink) {
354 WARN("Expected IInternetProtocolSink pointer to be non-NULL\n");
355 return S_OK;
356 }
357
358 if(protocol->flags & FLAG_ERROR) {
359 protocol->flags &= ~FLAG_ERROR;
360 protocol->vtbl->on_error(protocol, PtrToUlong(data->pData));
361 return S_OK;
362 }
363
364 if(protocol->post_stream)
365 return write_post_stream(protocol);
366
367 if(is_start) {
368 hres = start_downloading(protocol);
369 if(FAILED(hres))
370 return S_OK;
371 }
372
373 if(!data || data->pData >= UlongToPtr(BINDSTATUS_DOWNLOADINGDATA)) {
374 if(!protocol->available_bytes) {
375 if(protocol->query_available) {
376 protocol->available_bytes = protocol->query_available;
377 }else {
378 BOOL res;
379
380 /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
381 * read, so clear the flag _before_ calling so it does not incorrectly get cleared
382 * after the status callback is called */
383 protocol->flags &= ~FLAG_REQUEST_COMPLETE;
384 res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
385 if(res) {
386 TRACE("available %u bytes\n", protocol->query_available);
387 if(!protocol->query_available) {
388 if(is_start) {
389 TRACE("empty file\n");
390 all_data_read(protocol);
391 }else {
392 WARN("unexpected end of file?\n");
393 report_result(protocol, INET_E_DOWNLOAD_FAILURE);
394 }
395 return S_OK;
396 }
397 protocol->available_bytes = protocol->query_available;
398 }else if(GetLastError() != ERROR_IO_PENDING) {
399 protocol->flags |= FLAG_REQUEST_COMPLETE;
400 WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
401 report_result(protocol, INET_E_DATA_NOT_AVAILABLE);
402 return S_OK;
403 }
404 }
405
406 protocol->flags |= FLAG_REQUEST_COMPLETE;
407 }
408
409 report_data(protocol);
410 }
411
412 return S_OK;
413 }
414
415 HRESULT protocol_read(Protocol *protocol, void *buf, ULONG size, ULONG *read_ret)
416 {
417 ULONG read = 0;
418 BOOL res;
419 HRESULT hres = S_FALSE;
420
421 if(protocol->flags & FLAG_ALL_DATA_READ) {
422 *read_ret = 0;
423 return S_FALSE;
424 }
425
426 if(!(protocol->flags & FLAG_SYNC_READ) && (!(protocol->flags & FLAG_REQUEST_COMPLETE) || !protocol->available_bytes)) {
427 *read_ret = 0;
428 return E_PENDING;
429 }
430
431 while(read < size && protocol->available_bytes) {
432 ULONG len;
433
434 res = InternetReadFile(protocol->request, ((BYTE *)buf)+read,
435 protocol->available_bytes > size-read ? size-read : protocol->available_bytes, &len);
436 if(!res) {
437 WARN("InternetReadFile failed: %d\n", GetLastError());
438 hres = INET_E_DOWNLOAD_FAILURE;
439 report_result(protocol, hres);
440 break;
441 }
442
443 if(!len) {
444 all_data_read(protocol);
445 break;
446 }
447
448 read += len;
449 protocol->current_position += len;
450 protocol->available_bytes -= len;
451
452 TRACE("current_position %d, available_bytes %d\n", protocol->current_position, protocol->available_bytes);
453
454 if(!protocol->available_bytes) {
455 /* InternetQueryDataAvailable may immediately fork and perform its asynchronous
456 * read, so clear the flag _before_ calling so it does not incorrectly get cleared
457 * after the status callback is called */
458 protocol->flags &= ~FLAG_REQUEST_COMPLETE;
459 res = InternetQueryDataAvailable(protocol->request, &protocol->query_available, 0, 0);
460 if(!res) {
461 if (GetLastError() == ERROR_IO_PENDING) {
462 hres = E_PENDING;
463 }else {
464 WARN("InternetQueryDataAvailable failed: %d\n", GetLastError());
465 hres = INET_E_DATA_NOT_AVAILABLE;
466 report_result(protocol, hres);
467 }
468 break;
469 }
470
471 if(!protocol->query_available) {
472 all_data_read(protocol);
473 break;
474 }
475
476 protocol->available_bytes = protocol->query_available;
477 }
478 }
479
480 *read_ret = read;
481
482 if (hres != E_PENDING)
483 protocol->flags |= FLAG_REQUEST_COMPLETE;
484 if(FAILED(hres))
485 return hres;
486
487 return read ? S_OK : S_FALSE;
488 }
489
490 HRESULT protocol_lock_request(Protocol *protocol)
491 {
492 if (!InternetLockRequestFile(protocol->request, &protocol->lock))
493 WARN("InternetLockRequest failed: %d\n", GetLastError());
494
495 return S_OK;
496 }
497
498 HRESULT protocol_unlock_request(Protocol *protocol)
499 {
500 if(!protocol->lock)
501 return S_OK;
502
503 if(!InternetUnlockRequestFile(protocol->lock))
504 WARN("InternetUnlockRequest failed: %d\n", GetLastError());
505 protocol->lock = 0;
506
507 return S_OK;
508 }
509
510 HRESULT protocol_abort(Protocol *protocol, HRESULT reason)
511 {
512 if(!protocol->protocol_sink)
513 return S_OK;
514
515 if(protocol->flags & FLAG_RESULT_REPORTED)
516 return INET_E_RESULT_DISPATCHED;
517
518 report_result(protocol, reason);
519 return S_OK;
520 }
521
522 void protocol_close_connection(Protocol *protocol)
523 {
524 protocol->vtbl->close_connection(protocol);
525
526 if(protocol->request)
527 InternetCloseHandle(protocol->request);
528
529 if(protocol->connection)
530 InternetCloseHandle(protocol->connection);
531
532 if(protocol->post_stream) {
533 IStream_Release(protocol->post_stream);
534 protocol->post_stream = NULL;
535 }
536
537 protocol->flags = 0;
538 }