[WINHTTP] Sync with Wine Staging 3.9. CORE-14656
[reactos.git] / dll / win32 / winhttp / net.c
index 225442d..e9ff86a 100644 (file)
  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
  */
 
-#include "winhttp_private.h"
+#include "config.h"
+#include "wine/port.h"
 
+#include <stdarg.h>
+#include <stdio.h>
+#include <errno.h>
 #include <assert.h>
-#include <schannel.h>
 
+#include <sys/types.h>
 #ifdef HAVE_SYS_SOCKET_H
 # include <sys/socket.h>
 #endif
 # include <poll.h>
 #endif
 
+#define NONAMELESSUNION
+
+#include "wine/debug.h"
+#include "wine/library.h"
+
+#include "windef.h"
+#include "winbase.h"
+#include "winhttp.h"
+#include "wincrypt.h"
+#include "schannel.h"
+
+#include "winhttp_private.h"
+
+/* to avoid conflicts with the Unix socket headers */
+#define USE_WS_PREFIX
+#include "winsock2.h"
+
+WINE_DEFAULT_DEBUG_CHANNEL(winhttp);
+
 #ifndef HAVE_GETADDRINFO
 
 /* critical section to protect non-reentrant gethostbyname() */
@@ -126,6 +149,28 @@ static inline int unix_ioctl(int filedes, long request, void *arg)
 #define ioctlsocket unix_ioctl
 #endif
 
+static int sock_send(int fd, const void *msg, size_t len, int flags)
+{
+    int ret;
+    do
+    {
+        if ((ret = send(fd, msg, len, flags)) == -1) WARN("send error %s\n", strerror(errno));
+    }
+    while(ret == -1 && errno == EINTR);
+    return ret;
+}
+
+static int sock_recv(int fd, void *msg, size_t len, int flags)
+{
+    int ret;
+    do
+    {
+        if ((ret = recv(fd, msg, len, flags)) == -1) WARN("recv error %s\n", strerror(errno));
+    }
+    while(ret == -1 && errno == EINTR);
+    return ret;
+}
+
 static DWORD netconn_verify_cert( PCCERT_CONTEXT cert, WCHAR *server, DWORD security_flags )
 {
     HCERTSTORE store = cert->hCertStore;
@@ -219,136 +264,129 @@ static DWORD netconn_verify_cert( PCCERT_CONTEXT cert, WCHAR *server, DWORD secu
     return err;
 }
 
-static SecHandle cred_handle;
-static BOOL cred_handle_initialized;
-
-static CRITICAL_SECTION init_sechandle_cs;
-static CRITICAL_SECTION_DEBUG init_sechandle_cs_debug = {
-    0, 0, &init_sechandle_cs,
-    { &init_sechandle_cs_debug.ProcessLocksList,
-      &init_sechandle_cs_debug.ProcessLocksList },
-    0, 0, { (DWORD_PTR)(__FILE__ ": init_sechandle_cs") }
-};
-static CRITICAL_SECTION init_sechandle_cs = { &init_sechandle_cs_debug, -1, 0, 0, 0, 0 };
-
-static BOOL ensure_cred_handle(void)
+#ifdef __REACTOS__
+static BOOL winsock_initialized = FALSE;
+BOOL netconn_init_winsock()
 {
-    BOOL ret = TRUE;
-
-    EnterCriticalSection(&init_sechandle_cs);
-
-    if(!cred_handle_initialized) {
-        SECURITY_STATUS res;
-
-        res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, NULL,
-                NULL, NULL, &cred_handle, NULL);
-        if(res == SEC_E_OK) {
-            cred_handle_initialized = TRUE;
-        }else {
-            WARN("AcquireCredentialsHandleW failed: %u\n", res);
-            ret = FALSE;
+    WSADATA wsaData;
+    int error;
+    if (!winsock_initialized)
+    {
+        error = WSAStartup(MAKEWORD(1, 1), &wsaData);
+        if (error)
+        {
+            ERR("WSAStartup failed: %d\n", error);
+            return FALSE;
         }
+        else
+            winsock_initialized = TRUE;
     }
-
-    LeaveCriticalSection(&init_sechandle_cs);
-    return ret;
+    return winsock_initialized;
 }
 
-BOOL netconn_init( netconn_t *conn )
-{
-    memset(conn, 0, sizeof(*conn));
-    conn->socket = -1;
-    return TRUE;
-}
+#endif
 
 void netconn_unload( void )
 {
-    if(cred_handle_initialized)
-        FreeCredentialsHandle(&cred_handle);
-    DeleteCriticalSection(&init_sechandle_cs);
 #ifndef HAVE_GETADDRINFO
     DeleteCriticalSection(&cs_gethostbyname);
 #endif
+#ifdef __REACTOS__
+    if(winsock_initialized)
+        WSACleanup();
+#endif
 }
 
-BOOL netconn_connected( netconn_t *conn )
+netconn_t *netconn_create( hostdata_t *host, const struct sockaddr_storage *sockaddr, int timeout )
 {
-    return (conn->socket != -1);
-}
+    netconn_t *conn;
+    unsigned int addr_len;
+    BOOL ret = FALSE;
+    int res;
+    ULONG state;
 
-BOOL netconn_create( netconn_t *conn, int domain, int type, int protocol )
-{
-    if ((conn->socket = socket( domain, type, protocol )) == -1)
+    conn = heap_alloc_zero(sizeof(*conn));
+    if (!conn) return NULL;
+    conn->host = host;
+    conn->sockaddr = *sockaddr;
+    if ((conn->socket = socket( sockaddr->ss_family, SOCK_STREAM, 0 )) == -1)
     {
         WARN("unable to create socket (%s)\n", strerror(errno));
         set_last_error( sock_get_error( errno ) );
-        return FALSE;
+        heap_free(conn);
+        return NULL;
     }
-    return TRUE;
-}
-
-BOOL netconn_close( netconn_t *conn )
-{
-    int res;
 
-    if (conn->secure)
-    {
-        heap_free( conn->peek_msg_mem );
-        conn->peek_msg_mem = NULL;
-        conn->peek_msg = NULL;
-        conn->peek_len = 0;
-        heap_free(conn->ssl_buf);
-        conn->ssl_buf = NULL;
-        heap_free(conn->extra_buf);
-        conn->extra_buf = NULL;
-        conn->extra_len = 0;
-        DeleteSecurityContext(&conn->ssl_ctx);
-        conn->secure = FALSE;
-    }
-    res = closesocket( conn->socket );
-    conn->socket = -1;
-    if (res == -1)
+    switch (conn->sockaddr.ss_family)
     {
-        set_last_error( sock_get_error( errno ) );
-        return FALSE;
+    case AF_INET:
+        addr_len = sizeof(struct sockaddr_in);
+        break;
+    case AF_INET6:
+        addr_len = sizeof(struct sockaddr_in6);
+        break;
+    default:
+        assert(0);
     }
-    return TRUE;
-}
-
-BOOL netconn_connect( netconn_t *conn, const struct sockaddr *sockaddr, unsigned int addr_len, int timeout )
-{
-    BOOL ret = FALSE;
-    int res = 0;
-    ULONG state;
 
     if (timeout > 0)
     {
         state = 1;
         ioctlsocket( conn->socket, FIONBIO, &state );
     }
-    if (connect( conn->socket, sockaddr, addr_len ) < 0)
+
+    for (;;)
     {
-        res = sock_get_error( errno );
-        if (res == WSAEWOULDBLOCK || res == WSAEINPROGRESS)
+        res = 0;
+        if (connect( conn->socket, (const struct sockaddr *)&conn->sockaddr, addr_len ) < 0)
         {
-            // ReactOS: use select instead of poll
-            fd_set outfd;
-            struct timeval tv;
+            res = sock_get_error( errno );
+            if (res == WSAEWOULDBLOCK || res == WSAEINPROGRESS)
+            {
+#ifdef __REACTOS__
+                /* ReactOS: use select instead of poll */
+                fd_set outfd;
+                struct timeval tv;
 
-            FD_ZERO(&outfd);
-            FD_SET(conn->socket, &outfd);
+                FD_ZERO(&outfd);
+                FD_SET(conn->socket, &outfd);
 
-            tv.tv_sec = 0;
-            tv.tv_usec = timeout * 1000;
+                tv.tv_sec = 0;
+                tv.tv_usec = timeout * 1000;
+                for (;;)
+                {
+                    res = 0;
 
-            if (select( 0, NULL, &outfd, NULL, &tv ) > 0)
-                ret = TRUE;
-            else
-                res = sock_get_error( errno );
+                    if (select( 0, NULL, &outfd, NULL, &tv ) > 0)
+#else
+                struct pollfd pfd;
+
+                pfd.fd = conn->socket;
+                pfd.events = POLLOUT;
+                for (;;)
+                {
+                    res = 0;
+                    if (poll( &pfd, 1, timeout ) > 0)
+#endif
+                    {
+                        ret = TRUE;
+                        break;
+                    }
+                    else
+                    {
+                        res = sock_get_error( errno );
+                        if (res != WSAEINTR) break;
+                    }
+                }
+            }
+            if (res != WSAEINTR) break;
+        }
+        else
+        {
+            ret = TRUE;
+            break;
         }
     }
-    else
-        ret = TRUE;
     if (timeout > 0)
     {
         state = 0;
@@ -358,11 +396,35 @@ BOOL netconn_connect( netconn_t *conn, const struct sockaddr *sockaddr, unsigned
     {
         WARN("unable to connect to host (%d)\n", res);
         set_last_error( res );
+        netconn_close( conn );
+        return NULL;
     }
-    return ret;
+    return conn;
+}
+
+BOOL netconn_close( netconn_t *conn )
+{
+    int res;
+
+    if (conn->secure)
+    {
+        heap_free( conn->peek_msg_mem );
+        heap_free(conn->ssl_buf);
+        heap_free(conn->extra_buf);
+        DeleteSecurityContext(&conn->ssl_ctx);
+    }
+    res = closesocket( conn->socket );
+    release_host( conn->host );
+    heap_free(conn);
+    if (res == -1)
+    {
+        set_last_error( sock_get_error( errno ) );
+        return FALSE;
+    }
+    return TRUE;
 }
 
-BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
+BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname, DWORD security_flags, CredHandle *cred_handle )
 {
     SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
     SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
@@ -378,14 +440,11 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
     const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
         |ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;
 
-    if(!ensure_cred_handle())
-        return FALSE;
-
     read_buf = heap_alloc(read_buf_size);
     if(!read_buf)
         return FALSE;
 
-    status = InitializeSecurityContextW(&cred_handle, NULL, hostname, isc_req_flags, 0, 0, NULL, 0,
+    status = InitializeSecurityContextW(cred_handle, NULL, hostname, isc_req_flags, 0, 0, NULL, 0,
             &ctx, &out_desc, &attrs, NULL);
 
     assert(status != SEC_E_OK);
@@ -396,7 +455,7 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
 
             TRACE("sending %u bytes\n", out_buf.cbBuffer);
 
-            size = send(conn->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
+            size = sock_send(conn->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
             if(size != out_buf.cbBuffer) {
                 ERR("send failed\n");
                 res = ERROR_WINHTTP_SECURE_CHANNEL_ERROR;
@@ -435,9 +494,8 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
             read_buf_size += 1024;
         }
 
-        size = recv(conn->socket, (char *)(read_buf+in_bufs[0].cbBuffer), read_buf_size-in_bufs[0].cbBuffer, 0);
+        size = sock_recv(conn->socket, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
         if(size < 1) {
-            WARN("recv error\n");
             status = ERROR_WINHTTP_SECURE_CHANNEL_ERROR;
             break;
         }
@@ -446,7 +504,7 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
 
         in_bufs[0].cbBuffer += size;
         in_bufs[0].pvBuffer = read_buf;
-        status = InitializeSecurityContextW(&cred_handle, &ctx, hostname,  isc_req_flags, 0, 0, &in_desc,
+        status = InitializeSecurityContextW(cred_handle, &ctx, hostname,  isc_req_flags, 0, 0, &in_desc,
                 0, NULL, &out_desc, &attrs, NULL);
         TRACE("InitializeSecurityContext ret %08x\n", status);
 
@@ -462,7 +520,7 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
 
             status = QueryContextAttributesW(&ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
             if(status == SEC_E_OK) {
-                res = netconn_verify_cert(cert, hostname, conn->security_flags);
+                res = netconn_verify_cert(cert, hostname, security_flags);
                 CertFreeCertificateContext(cert);
                 if(res != ERROR_SUCCESS) {
                     WARN("cert verify failed: %u\n", res);
@@ -481,6 +539,7 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
         }
     }
 
+    heap_free(read_buf);
 
     if(status != SEC_E_OK || res != ERROR_SUCCESS) {
         WARN("Failed to initialize security context failed: %08x\n", status);
@@ -516,7 +575,7 @@ static BOOL send_ssl_chunk(netconn_t *conn, const void *msg, size_t size)
         return FALSE;
     }
 
-    if(send(conn->socket, conn->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0) < 1) {
+    if(sock_send(conn->socket, conn->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0) < 1) {
         WARN("send failed\n");
         return FALSE;
     }
@@ -526,7 +585,6 @@ static BOOL send_ssl_chunk(netconn_t *conn, const void *msg, size_t size)
 
 BOOL netconn_send( netconn_t *conn, const void *msg, size_t len, int *sent )
 {
-    if (!netconn_connected( conn )) return FALSE;
     if (conn->secure)
     {
         const BYTE *ptr = msg;
@@ -546,7 +604,7 @@ BOOL netconn_send( netconn_t *conn, const void *msg, size_t len, int *sent )
 
         return TRUE;
     }
-    if ((*sent = send( conn->socket, msg, len, 0 )) == -1)
+    if ((*sent = sock_send( conn->socket, msg, len, 0 )) == -1)
     {
         set_last_error( sock_get_error( errno ) );
         return FALSE;
@@ -572,11 +630,9 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
         heap_free(conn->extra_buf);
         conn->extra_buf = NULL;
     }else {
-        buf_len = recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0);
-        if(buf_len < 0) {
-            WARN("recv failed\n");
+        buf_len = sock_recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0);
+        if(buf_len < 0)
             return FALSE;
-        }
 
         if(!buf_len) {
             *eof = TRUE;
@@ -604,7 +660,7 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
         case SEC_E_INCOMPLETE_MESSAGE:
             assert(buf_len < ssl_buf_size);
 
-            size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
+            size = sock_recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
             if(size < 1)
                 return FALSE;
 
@@ -650,7 +706,6 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
 BOOL netconn_recv( netconn_t *conn, void *buf, size_t len, int flags, int *recvd )
 {
     *recvd = 0;
-    if (!netconn_connected( conn )) return FALSE;
     if (!len) return TRUE;
 
     if (conn->secure)
@@ -697,7 +752,7 @@ BOOL netconn_recv( netconn_t *conn, void *buf, size_t len, int flags, int *recvd
         *recvd = size;
         return TRUE;
     }
-    if ((*recvd = recv( conn->socket, buf, len, flags )) == -1)
+    if ((*recvd = sock_recv( conn->socket, buf, len, flags )) == -1)
     {
         set_last_error( sock_get_error( errno ) );
         return FALSE;
@@ -707,21 +762,7 @@ BOOL netconn_recv( netconn_t *conn, void *buf, size_t len, int flags, int *recvd
 
 ULONG netconn_query_data_available( netconn_t *conn )
 {
-    if(!netconn_connected(conn))
-        return 0;
-
-    if(conn->secure) {
-        return conn->peek_len;
-    }else {
-#ifdef FIONREAD
-        ULONG unread;
-
-        if(!ioctlsocket(conn->socket, FIONREAD, &unread))
-            return unread;
-#endif
-    }
-
-    return 0;
+    return conn->secure ? conn->peek_len : 0;
 }
 
 DWORD netconn_set_timeout( netconn_t *netconn, BOOL send, int value )
@@ -740,7 +781,37 @@ DWORD netconn_set_timeout( netconn_t *netconn, BOOL send, int value )
     return ERROR_SUCCESS;
 }
 
-static DWORD resolve_hostname( const WCHAR *hostnameW, INTERNET_PORT port, struct sockaddr *sa, socklen_t *sa_len )
+BOOL netconn_is_alive( netconn_t *netconn )
+{
+#ifdef MSG_DONTWAIT
+    ssize_t len;
+    BYTE b;
+
+    len = recv( netconn->socket, &b, 1, MSG_PEEK | MSG_DONTWAIT );
+    return len == 1 || (len == -1 && errno == EWOULDBLOCK);
+#elif defined(__MINGW32__) || defined(_MSC_VER)
+    ULONG mode;
+    int len;
+    char b;
+
+    mode = 1;
+    if(!ioctlsocket(netconn->socket, FIONBIO, &mode))
+        return FALSE;
+
+    len = recv(netconn->socket, &b, 1, MSG_PEEK);
+
+    mode = 0;
+    if(!ioctlsocket(netconn->socket, FIONBIO, &mode))
+        return FALSE;
+
+    return len == 1 || (len == -1 && WSAGetLastError() == WSAEWOULDBLOCK);
+#else
+    FIXME("not supported on this platform\n");
+    return TRUE;
+#endif
+}
+
+static DWORD resolve_hostname( const WCHAR *hostnameW, INTERNET_PORT port, struct sockaddr_storage *sa )
 {
     char *hostname;
 #ifdef HAVE_GETADDRINFO
@@ -774,13 +845,6 @@ static DWORD resolve_hostname( const WCHAR *hostnameW, INTERNET_PORT port, struc
         }
     }
     heap_free( hostname );
-    if (*sa_len < res->ai_addrlen)
-    {
-        WARN("address too small\n");
-        freeaddrinfo( res );
-        return ERROR_WINHTTP_NAME_NOT_RESOLVED;
-    }
-    *sa_len = res->ai_addrlen;
     memcpy( sa, res->ai_addr, res->ai_addrlen );
     /* Copy port */
     switch (res->ai_family)
@@ -806,13 +870,6 @@ static DWORD resolve_hostname( const WCHAR *hostnameW, INTERNET_PORT port, struc
         LeaveCriticalSection( &cs_gethostbyname );
         return ERROR_WINHTTP_NAME_NOT_RESOLVED;
     }
-    if (*sa_len < sizeof(struct sockaddr_in))
-    {
-        WARN("address too small\n");
-        LeaveCriticalSection( &cs_gethostbyname );
-        return ERROR_WINHTTP_NAME_NOT_RESOLVED;
-    }
-    *sa_len = sizeof(struct sockaddr_in);
     memset( sa, 0, sizeof(struct sockaddr_in) );
     memcpy( &sin->sin_addr, he->h_addr, he->h_length );
     sin->sin_family = he->h_addrtype;
@@ -825,19 +882,18 @@ static DWORD resolve_hostname( const WCHAR *hostnameW, INTERNET_PORT port, struc
 
 struct resolve_args
 {
-    const WCHAR     *hostname;
-    INTERNET_PORT    port;
-    struct sockaddr *sa;
-    socklen_t       *sa_len;
+    const WCHAR             *hostname;
+    INTERNET_PORT            port;
+    struct sockaddr_storage *sa;
 };
 
 static DWORD CALLBACK resolve_proc( LPVOID arg )
 {
     struct resolve_args *ra = arg;
-    return resolve_hostname( ra->hostname, ra->port, ra->sa, ra->sa_len );
+    return resolve_hostname( ra->hostname, ra->port, ra->sa );
 }
 
-BOOL netconn_resolve( WCHAR *hostname, INTERNET_PORT port, struct sockaddr *sa, socklen_t *sa_len, int timeout )
+BOOL netconn_resolve( WCHAR *hostname, INTERNET_PORT port, struct sockaddr_storage *sa, int timeout )
 {
     DWORD ret;
 
@@ -850,7 +906,6 @@ BOOL netconn_resolve( WCHAR *hostname, INTERNET_PORT port, struct sockaddr *sa,
         ra.hostname = hostname;
         ra.port     = port;
         ra.sa       = sa;
-        ra.sa_len   = sa_len;
 
         thread = CreateThread( NULL, 0, resolve_proc, &ra, 0, NULL );
         if (!thread) return FALSE;
@@ -860,7 +915,7 @@ BOOL netconn_resolve( WCHAR *hostname, INTERNET_PORT port, struct sockaddr *sa,
         else ret = ERROR_WINHTTP_TIMEOUT;
         CloseHandle( thread );
     }
-    else ret = resolve_hostname( hostname, port, sa, sa_len );
+    else ret = resolve_hostname( hostname, port, sa );
 
     if (ret)
     {