[WINHTTP] Properly initialize winsock. Based on wine tests CORE-12104
[reactos.git] / reactos / dll / win32 / winhttp / net.c
index cbf642d..dcf6aee 100644 (file)
  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
  */
 
-#define WIN32_NO_STATUS
-#define _INC_WINDOWS
-#define COM_NO_WINDOWS_H
-
-#include <config.h>
-//#include <wine/port.h>
+#include "winhttp_private.h"
 
-//#include <stdarg.h>
-//#include <stdio.h>
-//#include <errno.h>
 #include <assert.h>
+#include <schannel.h>
 
-//#include <sys/types.h>
 #ifdef HAVE_SYS_SOCKET_H
 # include <sys/socket.h>
 #endif
 # include <poll.h>
 #endif
 
-#define NONAMELESSUNION
-
-#include <wine/debug.h>
-#include <wine/library.h>
-
-//#include "windef.h"
-//#include "winbase.h"
-#include <winhttp.h>
-//#include "wincrypt.h"
-#include <schannel.h>
-
-#include "winhttp_private.h"
-
-/* to avoid conflicts with the Unix socket headers */
-#define USE_WS_PREFIX
-//#include "winsock2.h"
-
-WINE_DEFAULT_DEBUG_CHANNEL(winhttp);
-
 #ifndef HAVE_GETADDRINFO
 
 /* critical section to protect non-reentrant gethostbyname() */
@@ -153,6 +126,28 @@ static inline int unix_ioctl(int filedes, long request, void *arg)
 #define ioctlsocket unix_ioctl
 #endif
 
+static int sock_send(int fd, const void *msg, size_t len, int flags)
+{
+    int ret;
+    do
+    {
+        ret = send(fd, msg, len, flags);
+    }
+    while(ret == -1 && errno == EINTR);
+    return ret;
+}
+
+static int sock_recv(int fd, void *msg, size_t len, int flags)
+{
+    int ret;
+    do
+    {
+        ret = recv(fd, msg, len, flags);
+    }
+    while(ret == -1 && errno == EINTR);
+    return ret;
+}
+
 static DWORD netconn_verify_cert( PCCERT_CONTEXT cert, WCHAR *server, DWORD security_flags )
 {
     HCERTSTORE store = cert->hCertStore;
@@ -281,6 +276,28 @@ static BOOL ensure_cred_handle(void)
     return ret;
 }
 
+#ifdef __REACTOS__
+static BOOL winsock_initialized = FALSE;
+BOOL netconn_init_winsock()
+{
+    WSADATA wsaData;
+    int error;
+    if (!winsock_initialized)
+    {
+        error = WSAStartup(MAKEWORD(1, 1), &wsaData);
+        if (error)
+        {
+            ERR("WSAStartup failed: %d\n", error);
+            return FALSE;
+        }
+        else
+            winsock_initialized = TRUE;
+    }
+    return winsock_initialized;
+}
+
+#endif
+
 BOOL netconn_init( netconn_t *conn )
 {
     memset(conn, 0, sizeof(*conn));
@@ -296,6 +313,10 @@ void netconn_unload( void )
 #ifndef HAVE_GETADDRINFO
     DeleteCriticalSection(&cs_gethostbyname);
 #endif
+#ifdef __REACTOS__
+    if(winsock_initialized)
+        WSACleanup();
+#endif
 }
 
 BOOL netconn_connected( netconn_t *conn )
@@ -358,7 +379,8 @@ BOOL netconn_connect( netconn_t *conn, const struct sockaddr *sockaddr, unsigned
         res = sock_get_error( errno );
         if (res == WSAEWOULDBLOCK || res == WSAEINPROGRESS)
         {
-            // ReactOS: use select instead of poll
+#ifdef __REACTOS__
+            /* ReactOS: use select instead of poll */
             fd_set outfd;
             struct timeval tv;
 
@@ -369,6 +391,13 @@ BOOL netconn_connect( netconn_t *conn, const struct sockaddr *sockaddr, unsigned
             tv.tv_usec = timeout * 1000;
 
             if (select( 0, NULL, &outfd, NULL, &tv ) > 0)
+#else
+            struct pollfd pfd;
+
+            pfd.fd = conn->socket;
+            pfd.events = POLLOUT;
+            if (poll( &pfd, 1, timeout ) > 0)
+#endif
                 ret = TRUE;
             else
                 res = sock_get_error( errno );
@@ -423,7 +452,7 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
 
             TRACE("sending %u bytes\n", out_buf.cbBuffer);
 
-            size = send(conn->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
+            size = sock_send(conn->socket, out_buf.pvBuffer, out_buf.cbBuffer, 0);
             if(size != out_buf.cbBuffer) {
                 ERR("send failed\n");
                 res = ERROR_WINHTTP_SECURE_CHANNEL_ERROR;
@@ -462,7 +491,7 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
             read_buf_size += 1024;
         }
 
-        size = recv(conn->socket, (char *)(read_buf+in_bufs[0].cbBuffer), read_buf_size-in_bufs[0].cbBuffer, 0);
+        size = sock_recv(conn->socket, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
         if(size < 1) {
             WARN("recv error\n");
             status = ERROR_WINHTTP_SECURE_CHANNEL_ERROR;
@@ -508,6 +537,7 @@ BOOL netconn_secure_connect( netconn_t *conn, WCHAR *hostname )
         }
     }
 
+    heap_free(read_buf);
 
     if(status != SEC_E_OK || res != ERROR_SUCCESS) {
         WARN("Failed to initialize security context failed: %08x\n", status);
@@ -543,7 +573,7 @@ static BOOL send_ssl_chunk(netconn_t *conn, const void *msg, size_t size)
         return FALSE;
     }
 
-    if(send(conn->socket, conn->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0) < 1) {
+    if(sock_send(conn->socket, conn->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0) < 1) {
         WARN("send failed\n");
         return FALSE;
     }
@@ -573,7 +603,7 @@ BOOL netconn_send( netconn_t *conn, const void *msg, size_t len, int *sent )
 
         return TRUE;
     }
-    if ((*sent = send( conn->socket, msg, len, 0 )) == -1)
+    if ((*sent = sock_send( conn->socket, msg, len, 0 )) == -1)
     {
         set_last_error( sock_get_error( errno ) );
         return FALSE;
@@ -599,7 +629,7 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
         heap_free(conn->extra_buf);
         conn->extra_buf = NULL;
     }else {
-        buf_len = recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0);
+        buf_len = sock_recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0);
         if(buf_len < 0) {
             WARN("recv failed\n");
             return FALSE;
@@ -631,7 +661,7 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *
         case SEC_E_INCOMPLETE_MESSAGE:
             assert(buf_len < ssl_buf_size);
 
-            size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
+            size = sock_recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
             if(size < 1)
                 return FALSE;
 
@@ -724,7 +754,7 @@ BOOL netconn_recv( netconn_t *conn, void *buf, size_t len, int flags, int *recvd
         *recvd = size;
         return TRUE;
     }
-    if ((*recvd = recv( conn->socket, buf, len, flags )) == -1)
+    if ((*recvd = sock_recv( conn->socket, buf, len, flags )) == -1)
     {
         set_last_error( sock_get_error( errno ) );
         return FALSE;
@@ -732,24 +762,15 @@ BOOL netconn_recv( netconn_t *conn, void *buf, size_t len, int flags, int *recvd
     return TRUE;
 }
 
-BOOL netconn_query_data_available( netconn_t *conn, DWORD *available )
+ULONG netconn_query_data_available( netconn_t *conn )
 {
-#ifdef FIONREAD
-    int ret;
-    ULONG unread;
-#endif
-    *available = 0;
-    if (!netconn_connected( conn )) return FALSE;
+    if(!netconn_connected(conn))
+        return 0;
 
-    if (conn->secure)
-    {
-        *available = conn->peek_len;
-        return TRUE;
-    }
-#ifdef FIONREAD
-    if (!(ret = ioctlsocket( conn->socket, FIONREAD, &unread ))) *available = unread;
-#endif
-    return TRUE;
+    if(conn->secure)
+        return conn->peek_len;
+
+    return 0;
 }
 
 DWORD netconn_set_timeout( netconn_t *netconn, BOOL send, int value )