Sync with trunk revision 63128.
[reactos.git] / dll / win32 / wininet / netconnection.c
1 /*
2 * Wininet - networking layer. Uses unix sockets.
3 *
4 * Copyright 2002 TransGaming Technologies Inc.
5 * Copyright 2013 Jacek Caban for CodeWeavers
6 *
7 * David Hammerton
8 *
9 * This library is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public
11 * License as published by the Free Software Foundation; either
12 * version 2.1 of the License, or (at your option) any later version.
13 *
14 * This library is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
18 *
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with this library; if not, write to the Free Software
21 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
22 */
23
24 #include "internet.h"
25
26 #include <sys/types.h>
27 #ifdef HAVE_POLL_H
28 #include <poll.h>
29 #endif
30 #ifdef HAVE_SYS_FILIO_H
31 # include <sys/filio.h>
32 #endif
33 #ifdef HAVE_NETINET_TCP_H
34 # include <netinet/tcp.h>
35 #endif
36
37 #include <errno.h>
38
#define RESPONSE_TIMEOUT 30 /* FROM internet.c */

/* recv() flag requesting a non-blocking read; 0 where MSG_DONTWAIT is not
 * available. */
#if defined(MSG_DONTWAIT)
# define WINE_MSG_DONTWAIT MSG_DONTWAIT
#else
# define WINE_MSG_DONTWAIT 0
#endif

/* FIXME: This should use winsock. Doing so requires reworking these functions
 * somewhat, since they were designed around unix sockets.
 */
51
52 static DWORD netconn_verify_cert(netconn_t *conn, PCCERT_CONTEXT cert, HCERTSTORE store)
53 {
54 BOOL ret;
55 CERT_CHAIN_PARA chainPara = { sizeof(chainPara), { 0 } };
56 PCCERT_CHAIN_CONTEXT chain;
57 char oid_server_auth[] = szOID_PKIX_KP_SERVER_AUTH;
58 char *server_auth[] = { oid_server_auth };
59 DWORD err = ERROR_SUCCESS, errors;
60
61 static const DWORD supportedErrors =
62 CERT_TRUST_IS_NOT_TIME_VALID |
63 CERT_TRUST_IS_UNTRUSTED_ROOT |
64 CERT_TRUST_IS_PARTIAL_CHAIN |
65 CERT_TRUST_IS_NOT_VALID_FOR_USAGE;
66
67 TRACE("verifying %s\n", debugstr_w(conn->server->name));
68
69 chainPara.RequestedUsage.Usage.cUsageIdentifier = 1;
70 chainPara.RequestedUsage.Usage.rgpszUsageIdentifier = server_auth;
71 if (!(ret = CertGetCertificateChain(NULL, cert, NULL, store, &chainPara, 0, NULL, &chain))) {
72 TRACE("failed\n");
73 return GetLastError();
74 }
75
76 errors = chain->TrustStatus.dwErrorStatus;
77
78 do {
79 /* This seems strange, but that's what tests show */
80 if(errors & CERT_TRUST_IS_PARTIAL_CHAIN) {
81 WARN("ERROR_INTERNET_SEC_CERT_REV_FAILED\n");
82 err = ERROR_INTERNET_SEC_CERT_REV_FAILED;
83 if(conn->mask_errors)
84 conn->security_flags |= _SECURITY_FLAG_CERT_REV_FAILED;
85 if(!(conn->security_flags & SECURITY_FLAG_IGNORE_REVOCATION))
86 break;
87 }
88
89 if (chain->TrustStatus.dwErrorStatus & ~supportedErrors) {
90 WARN("error status %x\n", chain->TrustStatus.dwErrorStatus & ~supportedErrors);
91 err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
92 errors &= supportedErrors;
93 if(!conn->mask_errors)
94 break;
95 WARN("unknown error flags\n");
96 }
97
98 if(errors & CERT_TRUST_IS_NOT_TIME_VALID) {
99 WARN("CERT_TRUST_IS_NOT_TIME_VALID\n");
100 if(!(conn->security_flags & SECURITY_FLAG_IGNORE_CERT_DATE_INVALID)) {
101 err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_CERT_DATE_INVALID;
102 if(!conn->mask_errors)
103 break;
104 conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_DATE;
105 }
106 errors &= ~CERT_TRUST_IS_NOT_TIME_VALID;
107 }
108
109 if(errors & CERT_TRUST_IS_UNTRUSTED_ROOT) {
110 WARN("CERT_TRUST_IS_UNTRUSTED_ROOT\n");
111 if(!(conn->security_flags & SECURITY_FLAG_IGNORE_UNKNOWN_CA)) {
112 err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_INVALID_CA;
113 if(!conn->mask_errors)
114 break;
115 conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CA;
116 }
117 errors &= ~CERT_TRUST_IS_UNTRUSTED_ROOT;
118 }
119
120 if(errors & CERT_TRUST_IS_PARTIAL_CHAIN) {
121 WARN("CERT_TRUST_IS_PARTIAL_CHAIN\n");
122 if(!(conn->security_flags & SECURITY_FLAG_IGNORE_UNKNOWN_CA)) {
123 err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_INVALID_CA;
124 if(!conn->mask_errors)
125 break;
126 conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CA;
127 }
128 errors &= ~CERT_TRUST_IS_PARTIAL_CHAIN;
129 }
130
131 if(errors & CERT_TRUST_IS_NOT_VALID_FOR_USAGE) {
132 WARN("CERT_TRUST_IS_NOT_VALID_FOR_USAGE\n");
133 if(!(conn->security_flags & SECURITY_FLAG_IGNORE_WRONG_USAGE)) {
134 err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
135 if(!conn->mask_errors)
136 break;
137 WARN("CERT_TRUST_IS_NOT_VALID_FOR_USAGE, unknown error flags\n");
138 }
139 errors &= ~CERT_TRUST_IS_NOT_VALID_FOR_USAGE;
140 }
141
142 if(err == ERROR_INTERNET_SEC_CERT_REV_FAILED) {
143 assert(conn->security_flags & SECURITY_FLAG_IGNORE_REVOCATION);
144 err = ERROR_SUCCESS;
145 }
146 }while(0);
147
148 if(!err || conn->mask_errors) {
149 CERT_CHAIN_POLICY_PARA policyPara;
150 SSL_EXTRA_CERT_CHAIN_POLICY_PARA sslExtraPolicyPara;
151 CERT_CHAIN_POLICY_STATUS policyStatus;
152 CERT_CHAIN_CONTEXT chainCopy;
153
154 /* Clear chain->TrustStatus.dwErrorStatus so
155 * CertVerifyCertificateChainPolicy will verify additional checks
156 * rather than stopping with an existing, ignored error.
157 */
158 memcpy(&chainCopy, chain, sizeof(chainCopy));
159 chainCopy.TrustStatus.dwErrorStatus = 0;
160 sslExtraPolicyPara.u.cbSize = sizeof(sslExtraPolicyPara);
161 sslExtraPolicyPara.dwAuthType = AUTHTYPE_SERVER;
162 sslExtraPolicyPara.pwszServerName = conn->server->name;
163 sslExtraPolicyPara.fdwChecks = conn->security_flags;
164 policyPara.cbSize = sizeof(policyPara);
165 policyPara.dwFlags = 0;
166 policyPara.pvExtraPolicyPara = &sslExtraPolicyPara;
167 ret = CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL,
168 &chainCopy, &policyPara, &policyStatus);
169 /* Any error in the policy status indicates that the
170 * policy couldn't be verified.
171 */
172 if(ret) {
173 if(policyStatus.dwError == CERT_E_CN_NO_MATCH) {
174 WARN("CERT_E_CN_NO_MATCH\n");
175 if(conn->mask_errors)
176 conn->security_flags |= _SECURITY_FLAG_CERT_INVALID_CN;
177 err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_CERT_CN_INVALID;
178 }else if(policyStatus.dwError) {
179 WARN("policyStatus.dwError %x\n", policyStatus.dwError);
180 if(conn->mask_errors)
181 WARN("unknown error flags for policy status %x\n", policyStatus.dwError);
182 err = conn->mask_errors && err ? ERROR_INTERNET_SEC_CERT_ERRORS : ERROR_INTERNET_SEC_INVALID_CERT;
183 }
184 }else {
185 err = GetLastError();
186 }
187 }
188
189 if(err) {
190 WARN("failed %u\n", err);
191 CertFreeCertificateChain(chain);
192 if(conn->server->cert_chain) {
193 CertFreeCertificateChain(conn->server->cert_chain);
194 conn->server->cert_chain = NULL;
195 }
196 if(conn->mask_errors)
197 conn->server->security_flags |= conn->security_flags & _SECURITY_ERROR_FLAGS_MASK;
198 return err;
199 }
200
201 /* FIXME: Reuse cached chain */
202 if(conn->server->cert_chain)
203 CertFreeCertificateChain(chain);
204 else
205 conn->server->cert_chain = chain;
206 return ERROR_SUCCESS;
207 }
208
209 static SecHandle cred_handle, compat_cred_handle;
210 static BOOL cred_handle_initialized, have_compat_cred_handle;
211
212 static CRITICAL_SECTION init_sechandle_cs;
213 static CRITICAL_SECTION_DEBUG init_sechandle_cs_debug = {
214 0, 0, &init_sechandle_cs,
215 { &init_sechandle_cs_debug.ProcessLocksList,
216 &init_sechandle_cs_debug.ProcessLocksList },
217 0, 0, { (DWORD_PTR)(__FILE__ ": init_sechandle_cs") }
218 };
219 static CRITICAL_SECTION init_sechandle_cs = { &init_sechandle_cs_debug, -1, 0, 0, 0, 0 };
220
221 static BOOL ensure_cred_handle(void)
222 {
223 SECURITY_STATUS res = SEC_E_OK;
224
225 EnterCriticalSection(&init_sechandle_cs);
226
227 if(!cred_handle_initialized) {
228 SCHANNEL_CRED cred = {SCHANNEL_CRED_VERSION};
229 SecPkgCred_SupportedProtocols prots;
230
231 res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
232 NULL, NULL, &cred_handle, NULL);
233 if(res == SEC_E_OK) {
234 res = QueryCredentialsAttributesA(&cred_handle, SECPKG_ATTR_SUPPORTED_PROTOCOLS, &prots);
235 if(res != SEC_E_OK || (prots.grbitProtocol & SP_PROT_TLS1_1PLUS_CLIENT)) {
236 cred.grbitEnabledProtocols = prots.grbitProtocol & ~SP_PROT_TLS1_1PLUS_CLIENT;
237 res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
238 NULL, NULL, &compat_cred_handle, NULL);
239 have_compat_cred_handle = res == SEC_E_OK;
240 }
241 }
242
243 cred_handle_initialized = res == SEC_E_OK;
244 }
245
246 LeaveCriticalSection(&init_sechandle_cs);
247
248 if(res != SEC_E_OK) {
249 WARN("Failed: %08x\n", res);
250 return FALSE;
251 }
252
253 return TRUE;
254 }
255
256 static DWORD create_netconn_socket(server_t *server, netconn_t *netconn, DWORD timeout)
257 {
258 int result;
259 ULONG flag;
260
261 assert(server->addr_len);
262 result = netconn->socket = socket(server->addr.ss_family, SOCK_STREAM, 0);
263 if(result != -1) {
264 flag = 1;
265 ioctlsocket(netconn->socket, FIONBIO, &flag);
266 result = connect(netconn->socket, (struct sockaddr*)&server->addr, server->addr_len);
267 if(result == -1)
268 {
269 if (sock_get_error(errno) == WSAEINPROGRESS) {
270 // ReactOS: use select instead of poll
271 fd_set outfd;
272 struct timeval tv;
273 int res;
274
275 FD_ZERO(&outfd);
276 FD_SET(netconn->socket, &outfd);
277 tv.tv_sec = timeout / 1000;
278 tv.tv_usec = (timeout % 1000) * 1000;
279 res = select(0, NULL, &outfd, NULL, &tv);
280 if (!res)
281 {
282 closesocket(netconn->socket);
283 return ERROR_INTERNET_CANNOT_CONNECT;
284 }
285 else if (res > 0)
286 {
287 int err;
288 socklen_t len = sizeof(err);
289 if (!getsockopt(netconn->socket, SOL_SOCKET, SO_ERROR, (void *)&err, &len) && !err)
290 result = 0;
291 }
292 }
293 }
294 if(result == -1)
295 {
296 closesocket(netconn->socket);
297 netconn->socket = -1;
298 }
299 else {
300 flag = 0;
301 ioctlsocket(netconn->socket, FIONBIO, &flag);
302 }
303 }
304 if(result == -1)
305 return ERROR_INTERNET_CANNOT_CONNECT;
306
307 #ifdef TCP_NODELAY
308 flag = 1;
309 result = setsockopt(netconn->socket, IPPROTO_TCP, TCP_NODELAY, (void*)&flag, sizeof(flag));
310 if(result < 0)
311 WARN("setsockopt(TCP_NODELAY) failed\n");
312 #endif
313
314 return ERROR_SUCCESS;
315 }
316
317 DWORD create_netconn(BOOL useSSL, server_t *server, DWORD security_flags, BOOL mask_errors, DWORD timeout, netconn_t **ret)
318 {
319 netconn_t *netconn;
320 int result;
321
322 netconn = heap_alloc_zero(sizeof(*netconn));
323 if(!netconn)
324 return ERROR_OUTOFMEMORY;
325
326 netconn->socket = -1;
327 netconn->security_flags = security_flags | server->security_flags;
328 netconn->mask_errors = mask_errors;
329 list_init(&netconn->pool_entry);
330 SecInvalidateHandle(&netconn->ssl_ctx);
331
332 result = create_netconn_socket(server, netconn, timeout);
333 if (result != ERROR_SUCCESS) {
334 heap_free(netconn);
335 return result;
336 }
337
338 server_addref(server);
339 netconn->server = server;
340 *ret = netconn;
341 return result;
342 }
343
344 BOOL is_valid_netconn(netconn_t *netconn)
345 {
346 return netconn && netconn->socket != -1;
347 }
348
349 void close_netconn(netconn_t *netconn)
350 {
351 closesocket(netconn->socket);
352 netconn->socket = -1;
353 }
354
355 void free_netconn(netconn_t *netconn)
356 {
357 server_release(netconn->server);
358
359 if (netconn->secure) {
360 heap_free(netconn->peek_msg_mem);
361 netconn->peek_msg_mem = NULL;
362 netconn->peek_msg = NULL;
363 netconn->peek_len = 0;
364 heap_free(netconn->ssl_buf);
365 netconn->ssl_buf = NULL;
366 heap_free(netconn->extra_buf);
367 netconn->extra_buf = NULL;
368 netconn->extra_len = 0;
369 if (SecIsValidHandle(&netconn->ssl_ctx))
370 DeleteSecurityContext(&netconn->ssl_ctx);
371 }
372
373 heap_free(netconn);
374 }
375
376 void NETCON_unload(void)
377 {
378 if(cred_handle_initialized)
379 FreeCredentialsHandle(&cred_handle);
380 if(have_compat_cred_handle)
381 FreeCredentialsHandle(&compat_cred_handle);
382 DeleteCriticalSection(&init_sechandle_cs);
383 }
384
385 #ifndef __REACTOS__
386 /* translate a unix error code into a winsock one */
387 int sock_get_error( int err )
388 {
389 #if !defined(__MINGW32__) && !defined (_MSC_VER)
390 switch (err)
391 {
392 case EINTR: return WSAEINTR;
393 case EBADF: return WSAEBADF;
394 case EPERM:
395 case EACCES: return WSAEACCES;
396 case EFAULT: return WSAEFAULT;
397 case EINVAL: return WSAEINVAL;
398 case EMFILE: return WSAEMFILE;
399 case EWOULDBLOCK: return WSAEWOULDBLOCK;
400 case EINPROGRESS: return WSAEINPROGRESS;
401 case EALREADY: return WSAEALREADY;
402 case ENOTSOCK: return WSAENOTSOCK;
403 case EDESTADDRREQ: return WSAEDESTADDRREQ;
404 case EMSGSIZE: return WSAEMSGSIZE;
405 case EPROTOTYPE: return WSAEPROTOTYPE;
406 case ENOPROTOOPT: return WSAENOPROTOOPT;
407 case EPROTONOSUPPORT: return WSAEPROTONOSUPPORT;
408 case ESOCKTNOSUPPORT: return WSAESOCKTNOSUPPORT;
409 case EOPNOTSUPP: return WSAEOPNOTSUPP;
410 case EPFNOSUPPORT: return WSAEPFNOSUPPORT;
411 case EAFNOSUPPORT: return WSAEAFNOSUPPORT;
412 case EADDRINUSE: return WSAEADDRINUSE;
413 case EADDRNOTAVAIL: return WSAEADDRNOTAVAIL;
414 case ENETDOWN: return WSAENETDOWN;
415 case ENETUNREACH: return WSAENETUNREACH;
416 case ENETRESET: return WSAENETRESET;
417 case ECONNABORTED: return WSAECONNABORTED;
418 case EPIPE:
419 case ECONNRESET: return WSAECONNRESET;
420 case ENOBUFS: return WSAENOBUFS;
421 case EISCONN: return WSAEISCONN;
422 case ENOTCONN: return WSAENOTCONN;
423 case ESHUTDOWN: return WSAESHUTDOWN;
424 case ETOOMANYREFS: return WSAETOOMANYREFS;
425 case ETIMEDOUT: return WSAETIMEDOUT;
426 case ECONNREFUSED: return WSAECONNREFUSED;
427 case ELOOP: return WSAELOOP;
428 case ENAMETOOLONG: return WSAENAMETOOLONG;
429 case EHOSTDOWN: return WSAEHOSTDOWN;
430 case EHOSTUNREACH: return WSAEHOSTUNREACH;
431 case ENOTEMPTY: return WSAENOTEMPTY;
432 #ifdef EPROCLIM
433 case EPROCLIM: return WSAEPROCLIM;
434 #endif
435 #ifdef EUSERS
436 case EUSERS: return WSAEUSERS;
437 #endif
438 #ifdef EDQUOT
439 case EDQUOT: return WSAEDQUOT;
440 #endif
441 #ifdef ESTALE
442 case ESTALE: return WSAESTALE;
443 #endif
444 #ifdef EREMOTE
445 case EREMOTE: return WSAEREMOTE;
446 #endif
447 default: errno=err; perror("sock_set_error"); return WSAEFAULT;
448 }
449 #endif
450 return err;
451 }
452 #endif
453
454 static void set_socket_blocking(int socket, blocking_mode_t mode)
455 {
456 #if defined(__MINGW32__) || defined (_MSC_VER)
457 ULONG arg = mode == BLOCKING_DISALLOW;
458 ioctlsocket(socket, FIONBIO, &arg);
459 #endif
460 }
461
462 static DWORD netcon_secure_connect_setup(netconn_t *connection, BOOL compat_mode)
463 {
464 SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
465 SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
466 SecHandle *cred = &cred_handle;
467 BYTE *read_buf;
468 SIZE_T read_buf_size = 2048;
469 ULONG attrs = 0;
470 CtxtHandle ctx;
471 SSIZE_T size;
472 int bits;
473 const CERT_CONTEXT *cert;
474 SECURITY_STATUS status;
475 DWORD res = ERROR_SUCCESS;
476
477 const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
478 |ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;
479
480 if(!ensure_cred_handle())
481 return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
482
483 if(compat_mode) {
484 if(!have_compat_cred_handle)
485 return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
486 cred = &compat_cred_handle;
487 }
488
489 read_buf = heap_alloc(read_buf_size);
490 if(!read_buf)
491 return ERROR_OUTOFMEMORY;
492
493 status = InitializeSecurityContextW(cred, NULL, connection->server->name, isc_req_flags, 0, 0, NULL, 0,
494 &ctx, &out_desc, &attrs, NULL);
495
496 assert(status != SEC_E_OK);
497
498 while(status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE) {
499 if(out_buf.cbBuffer) {
500 assert(status == SEC_I_CONTINUE_NEEDED);
501
502 TRACE("sending %u bytes\n", out_buf.cbBuffer);
503
504 size = send(connection->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
505 if(size != out_buf.cbBuffer) {
506 ERR("send failed\n");
507 status = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
508 break;
509 }
510
511 FreeContextBuffer(out_buf.pvBuffer);
512 out_buf.pvBuffer = NULL;
513 out_buf.cbBuffer = 0;
514 }
515
516 if(status == SEC_I_CONTINUE_NEEDED) {
517 assert(in_bufs[1].cbBuffer < read_buf_size);
518
519 memmove(read_buf, (BYTE*)in_bufs[0].pvBuffer+in_bufs[0].cbBuffer-in_bufs[1].cbBuffer, in_bufs[1].cbBuffer);
520 in_bufs[0].cbBuffer = in_bufs[1].cbBuffer;
521
522 in_bufs[1].BufferType = SECBUFFER_EMPTY;
523 in_bufs[1].cbBuffer = 0;
524 in_bufs[1].pvBuffer = NULL;
525 }
526
527 assert(in_bufs[0].BufferType == SECBUFFER_TOKEN);
528 assert(in_bufs[1].BufferType == SECBUFFER_EMPTY);
529
530 if(in_bufs[0].cbBuffer + 1024 > read_buf_size) {
531 BYTE *new_read_buf;
532
533 new_read_buf = heap_realloc(read_buf, read_buf_size + 1024);
534 if(!new_read_buf) {
535 status = E_OUTOFMEMORY;
536 break;
537 }
538
539 in_bufs[0].pvBuffer = read_buf = new_read_buf;
540 read_buf_size += 1024;
541 }
542
543 size = recv(connection->socket, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
544 if(size < 1) {
545 WARN("recv error\n");
546 res = ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
547 break;
548 }
549
550 TRACE("recv %lu bytes\n", size);
551
552 in_bufs[0].cbBuffer += size;
553 in_bufs[0].pvBuffer = read_buf;
554 status = InitializeSecurityContextW(cred, &ctx, connection->server->name, isc_req_flags, 0, 0, &in_desc,
555 0, NULL, &out_desc, &attrs, NULL);
556 TRACE("InitializeSecurityContext ret %08x\n", status);
557
558 if(status == SEC_E_OK) {
559 if(SecIsValidHandle(&connection->ssl_ctx))
560 DeleteSecurityContext(&connection->ssl_ctx);
561 connection->ssl_ctx = ctx;
562
563 if(in_bufs[1].BufferType == SECBUFFER_EXTRA)
564 FIXME("SECBUFFER_EXTRA not supported\n");
565
566 status = QueryContextAttributesW(&ctx, SECPKG_ATTR_STREAM_SIZES, &connection->ssl_sizes);
567 if(status != SEC_E_OK) {
568 WARN("Could not get sizes\n");
569 break;
570 }
571
572 status = QueryContextAttributesW(&ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
573 if(status == SEC_E_OK) {
574 res = netconn_verify_cert(connection, cert, cert->hCertStore);
575 CertFreeCertificateContext(cert);
576 if(res != ERROR_SUCCESS) {
577 WARN("cert verify failed: %u\n", res);
578 break;
579 }
580 }else {
581 WARN("Could not get cert\n");
582 break;
583 }
584
585 connection->ssl_buf = heap_alloc(connection->ssl_sizes.cbHeader + connection->ssl_sizes.cbMaximumMessage
586 + connection->ssl_sizes.cbTrailer);
587 if(!connection->ssl_buf) {
588 res = GetLastError();
589 break;
590 }
591 }
592 }
593
594 if(status != SEC_E_OK || res != ERROR_SUCCESS) {
595 WARN("Failed to establish SSL connection: %08x (%u)\n", status, res);
596 heap_free(connection->ssl_buf);
597 connection->ssl_buf = NULL;
598 return res ? res : ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
599 }
600
601 TRACE("established SSL connection\n");
602 connection->secure = TRUE;
603 connection->security_flags |= SECURITY_FLAG_SECURE;
604
605 bits = NETCON_GetCipherStrength(connection);
606 if (bits >= 128)
607 connection->security_flags |= SECURITY_FLAG_STRENGTH_STRONG;
608 else if (bits >= 56)
609 connection->security_flags |= SECURITY_FLAG_STRENGTH_MEDIUM;
610 else
611 connection->security_flags |= SECURITY_FLAG_STRENGTH_WEAK;
612
613 if(connection->mask_errors)
614 connection->server->security_flags = connection->security_flags;
615 return ERROR_SUCCESS;
616 }
617
618 /******************************************************************************
619 * NETCON_secure_connect
620 * Initiates a secure connection over an existing plaintext connection.
621 */
622 DWORD NETCON_secure_connect(netconn_t *connection, server_t *server)
623 {
624 DWORD res;
625
626 /* can't connect if we are already connected */
627 if(connection->secure) {
628 ERR("already connected\n");
629 return ERROR_INTERNET_CANNOT_CONNECT;
630 }
631
632 if(server != connection->server) {
633 server_release(connection->server);
634 server_addref(server);
635 connection->server = server;
636 }
637
638 /* connect with given TLS options */
639 res = netcon_secure_connect_setup(connection, FALSE);
640 if (res == ERROR_SUCCESS)
641 return res;
642
643 /* FIXME: when got version alert and FIN from server */
644 /* fallback to connect without TLSv1.1/TLSv1.2 */
645 if (res == ERROR_INTERNET_SECURITY_CHANNEL_ERROR && have_compat_cred_handle)
646 {
647 closesocket(connection->socket);
648 res = create_netconn_socket(connection->server, connection, 500);
649 if (res != ERROR_SUCCESS)
650 return res;
651 res = netcon_secure_connect_setup(connection, TRUE);
652 }
653 return res;
654 }
655
656 static BOOL send_ssl_chunk(netconn_t *conn, const void *msg, size_t size)
657 {
658 SecBuffer bufs[4] = {
659 {conn->ssl_sizes.cbHeader, SECBUFFER_STREAM_HEADER, conn->ssl_buf},
660 {size, SECBUFFER_DATA, conn->ssl_buf+conn->ssl_sizes.cbHeader},
661 {conn->ssl_sizes.cbTrailer, SECBUFFER_STREAM_TRAILER, conn->ssl_buf+conn->ssl_sizes.cbHeader+size},
662 {0, SECBUFFER_EMPTY, NULL}
663 };
664 SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
665 SECURITY_STATUS res;
666
667 memcpy(bufs[1].pvBuffer, msg, size);
668 res = EncryptMessage(&conn->ssl_ctx, 0, &buf_desc, 0);
669 if(res != SEC_E_OK) {
670 WARN("EncryptMessage failed\n");
671 return FALSE;
672 }
673
674 if(send(conn->socket, conn->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0) < 1) {
675 WARN("send failed\n");
676 return FALSE;
677 }
678
679 return TRUE;
680 }
681
682 /******************************************************************************
683 * NETCON_send
684 * Basically calls 'send()' unless we should use SSL
685 * number of chars send is put in *sent
686 */
687 DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags,
688 int *sent /* out */)
689 {
690 if(!connection->secure)
691 {
692 *sent = send(connection->socket, msg, len, flags);
693 if (*sent == -1)
694 return sock_get_error(errno);
695 return ERROR_SUCCESS;
696 }
697 else
698 {
699 const BYTE *ptr = msg;
700 size_t chunk_size;
701
702 *sent = 0;
703
704 while(len) {
705 chunk_size = min(len, connection->ssl_sizes.cbMaximumMessage);
706 if(!send_ssl_chunk(connection, ptr, chunk_size))
707 return ERROR_INTERNET_SECURITY_CHANNEL_ERROR;
708
709 *sent += chunk_size;
710 ptr += chunk_size;
711 len -= chunk_size;
712 }
713
714 return ERROR_SUCCESS;
715 }
716 }
717
718 static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, blocking_mode_t mode, SIZE_T *ret_size, BOOL *eof)
719 {
720 const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer;
721 SecBuffer bufs[4];
722 SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
723 SSIZE_T size, buf_len = 0;
724 blocking_mode_t tmp_mode;
725 int i;
726 SECURITY_STATUS res;
727
728 assert(conn->extra_len < ssl_buf_size);
729
730 /* BLOCKING_WAITALL is handled by caller */
731 if(mode == BLOCKING_WAITALL)
732 mode = BLOCKING_ALLOW;
733
734 if(conn->extra_len) {
735 memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len);
736 buf_len = conn->extra_len;
737 conn->extra_len = 0;
738 heap_free(conn->extra_buf);
739 conn->extra_buf = NULL;
740 }
741
742 tmp_mode = buf_len ? BLOCKING_DISALLOW : mode;
743 set_socket_blocking(conn->socket, tmp_mode);
744 size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, tmp_mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT);
745 if(size < 0) {
746 if(!buf_len) {
747 if(errno == EAGAIN || errno == EWOULDBLOCK) {
748 TRACE("would block\n");
749 return WSAEWOULDBLOCK;
750 }
751 WARN("recv failed\n");
752 return ERROR_INTERNET_CONNECTION_ABORTED;
753 }
754 }else {
755 buf_len += size;
756 }
757
758 *ret_size = buf_len;
759
760 if(!buf_len) {
761 *eof = TRUE;
762 return ERROR_SUCCESS;
763 }
764
765 *eof = FALSE;
766
767 do {
768 memset(bufs, 0, sizeof(bufs));
769 bufs[0].BufferType = SECBUFFER_DATA;
770 bufs[0].cbBuffer = buf_len;
771 bufs[0].pvBuffer = conn->ssl_buf;
772
773 res = DecryptMessage(&conn->ssl_ctx, &buf_desc, 0, NULL);
774 switch(res) {
775 case SEC_E_OK:
776 break;
777 case SEC_I_CONTEXT_EXPIRED:
778 TRACE("context expired\n");
779 *eof = TRUE;
780 return ERROR_SUCCESS;
781 case SEC_E_INCOMPLETE_MESSAGE:
782 assert(buf_len < ssl_buf_size);
783
784 set_socket_blocking(conn->socket, mode);
785 size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT);
786 if(size < 1) {
787 if(size < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
788 TRACE("would block\n");
789
790 /* FIXME: Optimize extra_buf usage. */
791 conn->extra_buf = heap_alloc(buf_len);
792 if(!conn->extra_buf)
793 return ERROR_NOT_ENOUGH_MEMORY;
794
795 conn->extra_len = buf_len;
796 memcpy(conn->extra_buf, conn->ssl_buf, conn->extra_len);
797 return WSAEWOULDBLOCK;
798 }
799
800 return ERROR_INTERNET_CONNECTION_ABORTED;
801 }
802
803 buf_len += size;
804 continue;
805 default:
806 WARN("failed: %08x\n", res);
807 return ERROR_INTERNET_CONNECTION_ABORTED;
808 }
809 } while(res != SEC_E_OK);
810
811 for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
812 if(bufs[i].BufferType == SECBUFFER_DATA) {
813 size = min(buf_size, bufs[i].cbBuffer);
814 memcpy(buf, bufs[i].pvBuffer, size);
815 if(size < bufs[i].cbBuffer) {
816 assert(!conn->peek_len);
817 conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size);
818 if(!conn->peek_msg)
819 return ERROR_NOT_ENOUGH_MEMORY;
820 conn->peek_len = bufs[i].cbBuffer-size;
821 memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len);
822 }
823
824 *ret_size = size;
825 }
826 }
827
828 for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) {
829 if(bufs[i].BufferType == SECBUFFER_EXTRA) {
830 conn->extra_buf = heap_alloc(bufs[i].cbBuffer);
831 if(!conn->extra_buf)
832 return ERROR_NOT_ENOUGH_MEMORY;
833
834 conn->extra_len = bufs[i].cbBuffer;
835 memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len);
836 }
837 }
838
839 return ERROR_SUCCESS;
840 }
841
842 /******************************************************************************
843 * NETCON_recv
844 * Basically calls 'recv()' unless we should use SSL
845 * number of chars received is put in *recvd
846 */
847 DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t mode, int *recvd)
848 {
849 *recvd = 0;
850 if (!len)
851 return ERROR_SUCCESS;
852
853 if (!connection->secure)
854 {
855 int flags = 0;
856
857 switch(mode) {
858 case BLOCKING_ALLOW:
859 break;
860 case BLOCKING_DISALLOW:
861 flags = WINE_MSG_DONTWAIT;
862 break;
863 case BLOCKING_WAITALL:
864 flags = MSG_WAITALL;
865 break;
866 }
867
868 set_socket_blocking(connection->socket, mode);
869 *recvd = recv(connection->socket, buf, len, flags);
870 return *recvd == -1 ? sock_get_error(errno) : ERROR_SUCCESS;
871 }
872 else
873 {
874 SIZE_T size = 0, cread;
875 BOOL eof;
876 DWORD res;
877
878 if(connection->peek_msg) {
879 size = min(len, connection->peek_len);
880 memcpy(buf, connection->peek_msg, size);
881 connection->peek_len -= size;
882 connection->peek_msg += size;
883
884 if(!connection->peek_len) {
885 heap_free(connection->peek_msg_mem);
886 connection->peek_msg_mem = connection->peek_msg = NULL;
887 }
888 /* check if we have enough data from the peek buffer */
889 if(mode != BLOCKING_WAITALL || size == len) {
890 *recvd = size;
891 return ERROR_SUCCESS;
892 }
893
894 mode = BLOCKING_DISALLOW;
895 }
896
897 do {
898 res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, mode, &cread, &eof);
899 if(res != ERROR_SUCCESS) {
900 if(res == WSAEWOULDBLOCK) {
901 if(size)
902 res = ERROR_SUCCESS;
903 }else {
904 WARN("read_ssl_chunk failed\n");
905 }
906 break;
907 }
908
909 if(eof) {
910 TRACE("EOF\n");
911 break;
912 }
913
914 size += cread;
915 }while(!size || (mode == BLOCKING_WAITALL && size < len));
916
917 TRACE("received %ld bytes\n", size);
918 *recvd = size;
919 return res;
920 }
921 }
922
923 /******************************************************************************
924 * NETCON_query_data_available
925 * Returns the number of bytes of peeked data plus the number of bytes of
926 * queued, but unread data.
927 */
928 BOOL NETCON_query_data_available(netconn_t *connection, DWORD *available)
929 {
930 *available = 0;
931
932 if(!connection->secure)
933 {
934 #ifdef FIONREAD
935 ULONG unread;
936 int retval = ioctlsocket(connection->socket, FIONREAD, &unread);
937 if (!retval)
938 {
939 TRACE("%d bytes of queued, but unread data\n", unread);
940 *available += unread;
941 }
942 #endif
943 }
944 else
945 {
946 *available = connection->peek_len;
947 }
948 return TRUE;
949 }
950
951 BOOL NETCON_is_alive(netconn_t *netconn)
952 {
953 #ifdef MSG_DONTWAIT
954 ssize_t len;
955 BYTE b;
956
957 len = recv(netconn->socket, &b, 1, MSG_PEEK|MSG_DONTWAIT);
958 return len == 1 || (len == -1 && errno == EWOULDBLOCK);
959 #elif defined(__MINGW32__) || defined(_MSC_VER)
960 ULONG mode;
961 int len;
962 char b;
963
964 mode = 1;
965 if(!ioctlsocket(netconn->socket, FIONBIO, &mode))
966 return FALSE;
967
968 len = recv(netconn->socket, &b, 1, MSG_PEEK);
969
970 mode = 0;
971 if(!ioctlsocket(netconn->socket, FIONBIO, &mode))
972 return FALSE;
973
974 return len == 1 || (len == -1 && errno == WSAEWOULDBLOCK);
975 #else
976 FIXME("not supported on this platform\n");
977 return TRUE;
978 #endif
979 }
980
981 LPCVOID NETCON_GetCert(netconn_t *connection)
982 {
983 const CERT_CONTEXT *ret;
984 SECURITY_STATUS res;
985
986 res = QueryContextAttributesW(&connection->ssl_ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&ret);
987 return res == SEC_E_OK ? ret : NULL;
988 }
989
990 int NETCON_GetCipherStrength(netconn_t *connection)
991 {
992 SecPkgContext_ConnectionInfo conn_info;
993 SECURITY_STATUS res;
994
995 if (!connection->secure)
996 return 0;
997
998 res = QueryContextAttributesW(&connection->ssl_ctx, SECPKG_ATTR_CONNECTION_INFO, (void*)&conn_info);
999 if(res != SEC_E_OK)
1000 WARN("QueryContextAttributesW failed: %08x\n", res);
1001 return res == SEC_E_OK ? conn_info.dwCipherStrength : 0;
1002 }
1003
1004 DWORD NETCON_set_timeout(netconn_t *connection, BOOL send, DWORD value)
1005 {
1006 int result;
1007 struct timeval tv;
1008
1009 /* value is in milliseconds, convert to struct timeval */
1010 if (value == INFINITE)
1011 {
1012 tv.tv_sec = 0;
1013 tv.tv_usec = 0;
1014 }
1015 else
1016 {
1017 tv.tv_sec = value / 1000;
1018 tv.tv_usec = (value % 1000) * 1000;
1019 }
1020 result = setsockopt(connection->socket, SOL_SOCKET,
1021 send ? SO_SNDTIMEO : SO_RCVTIMEO, (void*)&tv,
1022 sizeof(tv));
1023 if (result == -1)
1024 {
1025 WARN("setsockopt failed (%s)\n", strerror(errno));
1026 return sock_get_error(errno);
1027 }
1028 return ERROR_SUCCESS;
1029 }