[MSTSC] Add schannel based SSL implementation. CORE-9321
[reactos.git] / reactos / base / applications / mstsc / tcp.c
index ab6d7f9..ceb0170 100644 (file)
 */
 
 #include "precomp.h"
+#ifdef WITH_SSL
+#include <sspi.h>
+#include <schannel.h>
+#endif
 
 #ifdef _WIN32
 #define socklen_t int
 #endif
 
 #ifdef WITH_SSL
+typedef struct
+{
+       CtxtHandle ssl_ctx;
+       SecPkgContext_StreamSizes ssl_sizes;
+       char *ssl_buf;
+       char *extra_buf;
+       size_t extra_len;
+       char *peek_msg;
+       char *peek_msg_mem;
+       size_t peek_len;
+       DWORD security_flags;
+} netconn_t;
+static char * g_ssl_server = NULL;
 static RD_BOOL g_ssl_initialized = False;
-static SSL *g_ssl = NULL;
-static SSL_CTX *g_ssl_ctx = NULL;
+static RD_BOOL cred_handle_initialized = False;
+static RD_BOOL have_compat_cred_handle = False;
+static SecHandle cred_handle, compat_cred_handle;
+static netconn_t g_ssl1;
+static netconn_t *g_ssl = NULL;
 #endif /* WITH_SSL */
 static int g_sock;
 static struct stream g_in;
@@ -84,13 +104,171 @@ tcp_init(uint32 maxlen)
        return result;
 }
 
+#ifdef WITH_SSL
+
+RD_BOOL send_ssl_chunk(const void *msg, size_t size)
+{
+       SecBuffer bufs[4] = {
+               {g_ssl->ssl_sizes.cbHeader, SECBUFFER_STREAM_HEADER, g_ssl->ssl_buf},
+               {size,  SECBUFFER_DATA, g_ssl->ssl_buf+g_ssl->ssl_sizes.cbHeader},
+               {g_ssl->ssl_sizes.cbTrailer, SECBUFFER_STREAM_TRAILER, g_ssl->ssl_buf+g_ssl->ssl_sizes.cbHeader+size},
+               {0, SECBUFFER_EMPTY, NULL}
+       };
+       SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
+       SECURITY_STATUS res;
+       int tcp_res;
+
+       memcpy(bufs[1].pvBuffer, msg, size);
+       res = EncryptMessage(&g_ssl->ssl_ctx, 0, &buf_desc, 0);
+       if (res != SEC_E_OK)
+       {
+               error("EncryptMessage failed: %d\n", res);
+               return False;
+       }
+
+       tcp_res = send(g_sock, g_ssl->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0);
+       if (tcp_res < 1)
+       {
+               error("send failed: %d (%s)\n", tcp_res, TCP_STRERROR);
+               return False;
+       }
+
+       return True;
+}
+
+DWORD read_ssl_chunk(void *buf, SIZE_T buf_size, BOOL blocking, SIZE_T *ret_size, BOOL *eof)
+{
+       const SIZE_T ssl_buf_size = g_ssl->ssl_sizes.cbHeader+g_ssl->ssl_sizes.cbMaximumMessage+g_ssl->ssl_sizes.cbTrailer;
+       SecBuffer bufs[4];
+       SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
+       SSIZE_T size, buf_len = 0;
+       int i;
+       SECURITY_STATUS res;
+
+       //assert(conn->extra_len < ssl_buf_size);
+
+       if (g_ssl->extra_len)
+       {
+               memcpy(g_ssl->ssl_buf, g_ssl->extra_buf, g_ssl->extra_len);
+               buf_len = g_ssl->extra_len;
+               g_ssl->extra_len = 0;
+               xfree(g_ssl->extra_buf);
+               g_ssl->extra_buf = NULL;
+       }
+
+       size = recv(g_sock, g_ssl->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
+       if (size < 0)
+       {
+               if (!buf_len)
+               {
+                       if (size == -1 && TCP_BLOCKS)
+                       {
+                               return WSAEWOULDBLOCK;
+                       }
+                       error("recv failed: %d (%s)\n", size, TCP_STRERROR);
+                       return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
+               }
+       }
+       else
+       {
+               buf_len += size;
+       }
+
+       if (!buf_len)
+       {
+               *eof = TRUE;
+               *ret_size = 0;
+               return ERROR_SUCCESS;
+       }
+
+       *eof = FALSE;
+
+       do
+       {
+               memset(bufs, 0, sizeof(bufs));
+               bufs[0].BufferType = SECBUFFER_DATA;
+               bufs[0].cbBuffer = buf_len;
+               bufs[0].pvBuffer = g_ssl->ssl_buf;
+
+               res = DecryptMessage(&g_ssl->ssl_ctx, &buf_desc, 0, NULL);
+               switch (res)
+               {
+               case SEC_E_OK:
+                       break;
+               case SEC_I_CONTEXT_EXPIRED:
+                       *eof = TRUE;
+                       return ERROR_SUCCESS;
+               case SEC_E_INCOMPLETE_MESSAGE:
+                       //assert(buf_len < ssl_buf_size);
+
+                       size = recv(g_sock, g_ssl->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
+                       if (size < 1)
+                       {
+                               if (size == -1 && TCP_BLOCKS)
+                               {
+                                       /* FIXME: Optimize extra_buf usage. */
+                                       g_ssl->extra_buf = xmalloc(buf_len);
+                                       if (!g_ssl->extra_buf)
+                                               return ERROR_NOT_ENOUGH_MEMORY;
+
+                                       g_ssl->extra_len = buf_len;
+                                       memcpy(g_ssl->extra_buf, g_ssl->ssl_buf, g_ssl->extra_len);
+                                       return WSAEWOULDBLOCK;
+                               }
+
+                               error("recv failed: %d (%s)\n", size, TCP_STRERROR);
+                               return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
+                       }
+
+                       buf_len += size;
+                       continue;
+               default:
+                       error("DecryptMessage failed: %d\n", res);
+                       return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
+               }
+       }
+       while (res != SEC_E_OK);
+
+       for (i=0; i < sizeof(bufs)/sizeof(*bufs); i++)
+       {
+               if (bufs[i].BufferType == SECBUFFER_DATA)
+               {
+                       size = min(buf_size, bufs[i].cbBuffer);
+                       memcpy(buf, bufs[i].pvBuffer, size);
+                       if (size < bufs[i].cbBuffer)
+                       {
+                               //assert(!conn->peek_len);
+                               g_ssl->peek_msg_mem = g_ssl->peek_msg = xmalloc(bufs[i].cbBuffer - size);
+                               if (!g_ssl->peek_msg)
+                                       return ERROR_NOT_ENOUGH_MEMORY;
+                               g_ssl->peek_len = bufs[i].cbBuffer-size;
+                               memcpy(g_ssl->peek_msg, (char*)bufs[i].pvBuffer+size, g_ssl->peek_len);
+                       }
+
+                       *ret_size = size;
+               }
+       }
+
+       for (i=0; i < sizeof(bufs)/sizeof(*bufs); i++)
+       {
+               if (bufs[i].BufferType == SECBUFFER_EXTRA)
+               {
+                       g_ssl->extra_buf = xmalloc(bufs[i].cbBuffer);
+                       if (!g_ssl->extra_buf)
+                               return ERROR_NOT_ENOUGH_MEMORY;
+
+                       g_ssl->extra_len = bufs[i].cbBuffer;
+                       memcpy(g_ssl->extra_buf, bufs[i].pvBuffer, g_ssl->extra_len);
+               }
+       }
+
+       return ERROR_SUCCESS;
+}
+#endif /* WITH_SSL */
 /* Send TCP transport data packet */
 void
 tcp_send(STREAM s)
 {
-#ifdef WITH_SSL
-       int ssl_err;
-#endif /* WITH_SSL */
        int length = s->end - s->data;
        int sent, total = 0;
 
@@ -105,26 +283,28 @@ tcp_send(STREAM s)
 #ifdef WITH_SSL
                if (g_ssl)
                {
-                       sent = SSL_write(g_ssl, s->data + total, length - total);
-                       if (sent <= 0)
+                       const BYTE *ptr = s->data + total;
+                       size_t chunk_size;
+
+                       sent = 0;
+
+                       while (length - total)
                        {
-                               ssl_err = SSL_get_error(g_ssl, sent);
-                               if (sent < 0 && (ssl_err == SSL_ERROR_WANT_READ ||
-                                                ssl_err == SSL_ERROR_WANT_WRITE))
-                               {
-                                       TCP_SLEEP(0);
-                                       sent = 0;
-                               }
-                               else
+                               chunk_size = min(length - total, g_ssl->ssl_sizes.cbMaximumMessage);
+                               if (!send_ssl_chunk(ptr, chunk_size))
                                {
 #ifdef WITH_SCARD
                                        scard_unlock(SCARD_LOCK_TCP);
 #endif
 
-                                       error("SSL_write: %d (%s)\n", ssl_err, TCP_STRERROR);
+                                       //error("send_ssl_chunk: %d (%s)\n", sent, TCP_STRERROR);
                                        g_network_error = True;
                                        return;
                                }
+
+                               sent += chunk_size;
+                               ptr += chunk_size;
+                               length -= chunk_size;
                        }
                }
                else
@@ -144,7 +324,7 @@ tcp_send(STREAM s)
                                        scard_unlock(SCARD_LOCK_TCP);
 #endif
 
-                                       error("send: %s\n", TCP_STRERROR);
+                                       error("send: %d (%s)\n", sent, TCP_STRERROR);
                                        g_network_error = True;
                                        return;
                                }
@@ -165,9 +345,6 @@ tcp_recv(STREAM s, uint32 length)
 {
        uint32 new_length, end_offset, p_offset;
        int rcvd = 0;
-#ifdef WITH_SSL
-       int ssl_err;
-#endif /* WITH_SSL */
 
        if (g_network_error == True)
                return NULL;
@@ -201,7 +378,7 @@ tcp_recv(STREAM s, uint32 length)
        while (length > 0)
        {
 #ifdef WITH_SSL
-               if (!g_ssl || SSL_pending(g_ssl) <= 0)
+               if (!g_ssl)
 #endif /* WITH_SSL */
                {
                        if (!ui_select(g_sock))
@@ -215,33 +392,50 @@ tcp_recv(STREAM s, uint32 length)
 #ifdef WITH_SSL
                if (g_ssl)
                {
-                       rcvd = SSL_read(g_ssl, s->end, length);
-                       ssl_err = SSL_get_error(g_ssl, rcvd);
+                       SIZE_T size = 0;
+                       BOOL eof;
+                       DWORD res;
 
-                       if (ssl_err == SSL_ERROR_SSL)
+                       if (g_ssl->peek_msg)
                        {
-                               if (SSL_get_shutdown(g_ssl) & SSL_RECEIVED_SHUTDOWN)
+                               size = min(length, g_ssl->peek_len);
+                               memcpy(s->end, g_ssl->peek_msg, size);
+                               g_ssl->peek_len -= size;
+                               g_ssl->peek_msg += size;
+                               s->end += size;
+
+                               if (!g_ssl->peek_len)
                                {
-                                       error("Remote peer initiated ssl shutdown.\n");
-                                       return NULL;
+                                       xfree(g_ssl->peek_msg_mem);
+                                       g_ssl->peek_msg_mem = g_ssl->peek_msg = NULL;
                                }
 
-                               ERR_print_errors_fp(stdout);
-                               g_network_error = True;
-                               return NULL;
+                               return s;
                        }
 
-                       if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE)
-                       {
-                               rcvd = 0;
-                       }
-                       else if (ssl_err != SSL_ERROR_NONE)
+                       do
                        {
-                               error("SSL_read: %d (%s)\n", ssl_err, TCP_STRERROR);
-                               g_network_error = True;
-                               return NULL;
+                               res = read_ssl_chunk((BYTE*)s->end, length, TRUE, &size, &eof);
+                               if (res != ERROR_SUCCESS)
+                               {
+                                       if (res == WSAEWOULDBLOCK)
+                                       {
+                                               if (size)
+                                               {
+                                                       res = ERROR_SUCCESS;
+                                               }
+                                       }
+                                       else
+                                       {
+                                               error("read_ssl_chunk: %d (%s)\n", res, TCP_STRERROR);
+                                               g_network_error = True;
+                                               return NULL;
+                                       }
+                                       break;
+                               }
                        }
-
+                       while (!size && !eof);
+                       rcvd = size;
                }
                else
                {
@@ -255,7 +449,7 @@ tcp_recv(STREAM s, uint32 length)
                                }
                                else
                                {
-                                       error("recv: %s\n", TCP_STRERROR);
+                                       error("recv: %d (%s)\n", rcvd, TCP_STRERROR);
                                        g_network_error = True;
                                        return NULL;
                                }
@@ -277,78 +471,206 @@ tcp_recv(STREAM s, uint32 length)
 }
 
 #ifdef WITH_SSL
-/* Establish a SSL/TLS 1.0 connection */
+
 RD_BOOL
-tcp_tls_connect(void)
+ensure_cred_handle(void)
 {
-       int err;
-       long options;
+       SECURITY_STATUS res = SEC_E_OK;
 
-       if (!g_ssl_initialized)
+       if (!cred_handle_initialized)
        {
-               SSL_load_error_strings();
-               SSL_library_init();
-               g_ssl_initialized = True;
-       }
+               SCHANNEL_CRED cred = {SCHANNEL_CRED_VERSION};
+               SecPkgCred_SupportedProtocols prots;
 
-       /* create process context */
-       if (g_ssl_ctx == NULL)
-       {
-               g_ssl_ctx = SSL_CTX_new(TLSv1_client_method());
-               if (g_ssl_ctx == NULL)
+               res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
+                               NULL, NULL, &cred_handle, NULL);
+               if (res == SEC_E_OK)
                {
-                       error("tcp_tls_connect: SSL_CTX_new() failed to create TLS v1.0 context\n");
-                       goto fail;
+                       res = QueryCredentialsAttributesA(&cred_handle, SECPKG_ATTR_SUPPORTED_PROTOCOLS, &prots);
+                       if (res != SEC_E_OK || (prots.grbitProtocol & SP_PROT_TLS1_1PLUS_CLIENT))
+                       {
+                               cred.grbitEnabledProtocols = prots.grbitProtocol & ~SP_PROT_TLS1_1PLUS_CLIENT;
+                               res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
+                                               NULL, NULL, &compat_cred_handle, NULL);
+                               have_compat_cred_handle = res == SEC_E_OK;
+                       }
                }
 
-               options = 0;
-#ifdef SSL_OP_NO_COMPRESSION
-               options |= SSL_OP_NO_COMPRESSION;
-#endif // __SSL_OP_NO_COMPRESSION
-               options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
-               SSL_CTX_set_options(g_ssl_ctx, options);
+               cred_handle_initialized = res == SEC_E_OK;
        }
 
-       /* free old connection */
-       if (g_ssl)
-               SSL_free(g_ssl);
+       if (res != SEC_E_OK)
+       {
+               error("ensure_cred_handle failed: %ld\n", res);
+               return False;
+       }
 
-       /* create new ssl connection */
-       g_ssl = SSL_new(g_ssl_ctx);
-       if (g_ssl == NULL)
+       return True;
+}
+
+DWORD
+ssl_handshake(RD_BOOL compat_mode)
+{
+       SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
+       SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
+       SecHandle *cred = &cred_handle;
+       BYTE *read_buf;
+       SIZE_T read_buf_size = 2048;
+       ULONG attrs = 0;
+       CtxtHandle ctx;
+       SSIZE_T size;
+       SECURITY_STATUS status;
+       DWORD res = ERROR_SUCCESS;
+
+       const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
+               |ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;
+
+       if (!ensure_cred_handle())
+               return -1;
+
+       if (compat_mode) {
+               if (!have_compat_cred_handle)
+                       return -1;
+               cred = &compat_cred_handle;
+       }
+
+       read_buf = xmalloc(read_buf_size);
+       if (!read_buf)
+               return ERROR_OUTOFMEMORY;
+
+       if (!g_ssl_server)
+               return -1;
+
+       status = InitializeSecurityContextA(cred, NULL, g_ssl_server, isc_req_flags, 0, 0, NULL, 0,
+                       &ctx, &out_desc, &attrs, NULL);
+
+       //assert(status != SEC_E_OK);
+
+       while (status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE)
        {
-               error("tcp_tls_connect: SSL_new() failed\n");
-               goto fail;
+               if (out_buf.cbBuffer)
+               {
+                       //assert(status == SEC_I_CONTINUE_NEEDED);
+
+                       size = send(g_sock, out_buf.pvBuffer, out_buf.cbBuffer, 0);
+                       if (size != out_buf.cbBuffer)
+                       {
+                               error("send failed: %d (%s)\n", size, TCP_STRERROR);
+                               status = -1;
+                               break;
+                       }
+
+                       FreeContextBuffer(out_buf.pvBuffer);
+                       out_buf.pvBuffer = NULL;
+                       out_buf.cbBuffer = 0;
+               }
+
+               if (status == SEC_I_CONTINUE_NEEDED)
+               {
+                       //assert(in_bufs[1].cbBuffer < read_buf_size);
+
+                       memmove(read_buf, (BYTE*)in_bufs[0].pvBuffer+in_bufs[0].cbBuffer-in_bufs[1].cbBuffer, in_bufs[1].cbBuffer);
+                       in_bufs[0].cbBuffer = in_bufs[1].cbBuffer;
+
+                       in_bufs[1].BufferType = SECBUFFER_EMPTY;
+                       in_bufs[1].cbBuffer = 0;
+                       in_bufs[1].pvBuffer = NULL;
+               }
+
+               //assert(in_bufs[0].BufferType == SECBUFFER_TOKEN);
+               //assert(in_bufs[1].BufferType == SECBUFFER_EMPTY);
+
+               if (in_bufs[0].cbBuffer + 1024 > read_buf_size)
+               {
+                       BYTE *new_read_buf = xrealloc(read_buf, read_buf_size + 1024);
+                       if (!new_read_buf)
+                       {
+                               status = E_OUTOFMEMORY;
+                               break;
+                       }
+
+                       in_bufs[0].pvBuffer = read_buf = new_read_buf;
+                       read_buf_size += 1024;
+               }
+
+               size = recv(g_sock, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
+               if (size < 1)
+               {
+                       error("recv failed: %d (%s)\n", size, TCP_STRERROR);
+                       res = -1;
+                       break;
+               }
+
+               in_bufs[0].cbBuffer += size;
+               in_bufs[0].pvBuffer = read_buf;
+               status = InitializeSecurityContextA(cred, &ctx, g_ssl_server,  isc_req_flags, 0, 0, &in_desc,
+                               0, NULL, &out_desc, &attrs, NULL);
+
+               if (status == SEC_E_OK) {
+                       if (SecIsValidHandle(&g_ssl->ssl_ctx))
+                               DeleteSecurityContext(&g_ssl->ssl_ctx);
+                       g_ssl->ssl_ctx = ctx;
+
+                       if (in_bufs[1].BufferType == SECBUFFER_EXTRA)
+                       {
+                               //FIXME("SECBUFFER_EXTRA not supported\n");
+                       }
+
+                       status = QueryContextAttributesW(&ctx, SECPKG_ATTR_STREAM_SIZES, &g_ssl->ssl_sizes);
+                       if (status != SEC_E_OK)
+                       {
+                               //error("Can't determine ssl buffer sizes: %ld\n", status);
+                               break;
+                       }
+
+                       g_ssl->ssl_buf = xmalloc(g_ssl->ssl_sizes.cbHeader + g_ssl->ssl_sizes.cbMaximumMessage
+                                       + g_ssl->ssl_sizes.cbTrailer);
+                       if (!g_ssl->ssl_buf)
+                       {
+                               res = GetLastError();
+                               break;
+                       }
+               }
        }
 
-       if (SSL_set_fd(g_ssl, g_sock) < 1)
+       xfree(read_buf);
+
+       if (status != SEC_E_OK || res != ERROR_SUCCESS)
        {
-               error("tcp_tls_connect: SSL_set_fd() failed\n");
-               goto fail;
+               error("Failed to establish SSL connection: %08x (%u)\n", status, res);
+               xfree(g_ssl->ssl_buf);
+               g_ssl->ssl_buf = NULL;
+               return res ? res : -1;
        }
 
-       do
+       return ERROR_SUCCESS;
+}
+
+/* Establish a SSL/TLS 1.0 connection */
+RD_BOOL
+tcp_tls_connect(void)
+{
+       int err;
+       char tcp_port_rdp_s[10];
+
+       if (!g_ssl_initialized)
        {
-               err = SSL_connect(g_ssl);
+               g_ssl = &g_ssl1;
+               SecInvalidateHandle(&g_ssl->ssl_ctx);
+
+               g_ssl_initialized = True;
        }
-       while (SSL_get_error(g_ssl, err) == SSL_ERROR_WANT_READ);
 
-       if (err < 0)
+       snprintf(tcp_port_rdp_s, 10, "%d", g_tcp_port_rdp);
+       if ((err = ssl_handshake(FALSE)) != 0)
        {
-               ERR_print_errors_fp(stdout);
                goto fail;
        }
 
        return True;
 
       fail:
-       if (g_ssl)
-               SSL_free(g_ssl);
-       if (g_ssl_ctx)
-               SSL_CTX_free(g_ssl_ctx);
-
        g_ssl = NULL;
-       g_ssl_ctx = NULL;
        return False;
 }
 
@@ -356,46 +678,36 @@ tcp_tls_connect(void)
 RD_BOOL
 tcp_tls_get_server_pubkey(STREAM s)
 {
-       X509 *cert = NULL;
-       EVP_PKEY *pkey = NULL;
+       const CERT_CONTEXT *cert = NULL;
+       SECURITY_STATUS status;
 
        s->data = s->p = NULL;
        s->size = 0;
 
        if (g_ssl == NULL)
                goto out;
-
-       cert = SSL_get_peer_certificate(g_ssl);
-       if (cert == NULL)
-       {
-               error("tcp_tls_get_server_pubkey: SSL_get_peer_certificate() failed\n");
-               goto out;
-       }
-
-       pkey = X509_get_pubkey(cert);
-       if (pkey == NULL)
+       status = QueryContextAttributesW(&g_ssl->ssl_ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
+       if (status != SEC_E_OK)
        {
-               error("tcp_tls_get_server_pubkey: X509_get_pubkey() failed\n");
+               error("tcp_tls_get_server_pubkey: QueryContextAttributesW() failed %ld\n", status);
                goto out;
        }
 
-       s->size = i2d_PublicKey(pkey, NULL);
+       s->size = cert->cbCertEncoded;
        if (s->size < 1)
        {
-               error("tcp_tls_get_server_pubkey: i2d_PublicKey() failed\n");
+               error("tcp_tls_get_server_pubkey: cert->cbCertEncoded = %ld\n", cert->cbCertEncoded);
                goto out;
        }
 
-       s->data = s->p = xmalloc(s->size);
-       i2d_PublicKey(pkey, &s->p);
+       s->data = s->p = (unsigned char *)xmalloc(s->size);
+       memcpy(cert->pbCertEncoded, &s->p, s->size);
        s->p = s->data;
        s->end = s->p + s->size;
 
       out:
        if (cert)
-               X509_free(cert);
-       if (pkey)
-               EVP_PKEY_free(pkey);
+               CertFreeCertificateContext(cert);
        return (s->size != 0);
 }
 #endif /* WITH_SSL */
@@ -508,6 +820,10 @@ tcp_connect(char *server)
                g_out[i].data = (uint8 *) xmalloc(g_out[i].size);
        }
 
+#ifdef WITH_SSL
+       g_ssl_server = xmalloc(strlen(server)+1);
+#endif /* WITH_SSL */
+
        return True;
 }
 
@@ -518,12 +834,28 @@ tcp_disconnect(void)
 #ifdef WITH_SSL
        if (g_ssl)
        {
-               if (!g_network_error)
-                       (void) SSL_shutdown(g_ssl);
-               SSL_free(g_ssl);
+               xfree(g_ssl->peek_msg_mem);
+               g_ssl->peek_msg_mem = NULL;
+               g_ssl->peek_msg = NULL;
+               g_ssl->peek_len = 0;
+               xfree(g_ssl->ssl_buf);
+               g_ssl->ssl_buf = NULL;
+               xfree(g_ssl->extra_buf);
+               g_ssl->extra_buf = NULL;
+               g_ssl->extra_len = 0;
+               if (SecIsValidHandle(&g_ssl->ssl_ctx))
+                       DeleteSecurityContext(&g_ssl->ssl_ctx);
+               if (cred_handle_initialized)
+                       FreeCredentialsHandle(&cred_handle);
+               if (have_compat_cred_handle)
+                       FreeCredentialsHandle(&compat_cred_handle);
+               if (g_ssl_server)
+               {
+                       xfree(g_ssl_server);
+                       g_ssl_server = NULL;
+               }
                g_ssl = NULL;
-               SSL_CTX_free(g_ssl_ctx);
-               g_ssl_ctx = NULL;
+               g_ssl_initialized = False;
        }
 #endif /* WITH_SSL */
        TCP_CLOSE(g_sock);