[MBEDTLS] Update to version 2.16.10. CORE-17326
[reactos.git] / dll / 3rdparty / mbedtls / rsa.c
index 94ace5e..c8c23db 100644 (file)
@@ -73,6 +73,7 @@
 #include "mbedtls/rsa.h"
 #include "mbedtls/rsa_internal.h"
 #include "mbedtls/oid.h"
+#include "mbedtls/platform_util.h"
 
 #include <string.h>
 
 
 #if !defined(MBEDTLS_RSA_ALT)
 
-/* Implementation that should never be optimized out by the compiler */
-static void mbedtls_zeroize( void *v, size_t n ) {
-    volatile unsigned char *p = (unsigned char*)v; while( n-- ) *p++ = 0;
-}
+/* Parameter validation macros */
+#define RSA_VALIDATE_RET( cond )                                       \
+    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_RSA_BAD_INPUT_DATA )
+#define RSA_VALIDATE( cond )                                           \
+    MBEDTLS_INTERNAL_VALIDATE( cond )
 
 #if defined(MBEDTLS_PKCS1_V15)
 /* constant-time buffer comparison */
@@ -122,6 +124,7 @@ int mbedtls_rsa_import( mbedtls_rsa_context *ctx,
                         const mbedtls_mpi *D, const mbedtls_mpi *E )
 {
     int ret;
+    RSA_VALIDATE_RET( ctx != NULL );
 
     if( ( N != NULL && ( ret = mbedtls_mpi_copy( &ctx->N, N ) ) != 0 ) ||
         ( P != NULL && ( ret = mbedtls_mpi_copy( &ctx->P, P ) ) != 0 ) ||
@@ -146,6 +149,7 @@ int mbedtls_rsa_import_raw( mbedtls_rsa_context *ctx,
                             unsigned char const *E, size_t E_len )
 {
     int ret = 0;
+    RSA_VALIDATE_RET( ctx != NULL );
 
     if( N != NULL )
     {
@@ -269,17 +273,24 @@ static int rsa_check_context( mbedtls_rsa_context const *ctx, int is_priv,
 int mbedtls_rsa_complete( mbedtls_rsa_context *ctx )
 {
     int ret = 0;
+    int have_N, have_P, have_Q, have_D, have_E;
+#if !defined(MBEDTLS_RSA_NO_CRT)
+    int have_DP, have_DQ, have_QP;
+#endif
+    int n_missing, pq_missing, d_missing, is_pub, is_priv;
+
+    RSA_VALIDATE_RET( ctx != NULL );
 
-    const int have_N = ( mbedtls_mpi_cmp_int( &ctx->N, 0 ) != 0 );
-    const int have_P = ( mbedtls_mpi_cmp_int( &ctx->P, 0 ) != 0 );
-    const int have_Q = ( mbedtls_mpi_cmp_int( &ctx->Q, 0 ) != 0 );
-    const int have_D = ( mbedtls_mpi_cmp_int( &ctx->D, 0 ) != 0 );
-    const int have_E = ( mbedtls_mpi_cmp_int( &ctx->E, 0 ) != 0 );
+    have_N = ( mbedtls_mpi_cmp_int( &ctx->N, 0 ) != 0 );
+    have_P = ( mbedtls_mpi_cmp_int( &ctx->P, 0 ) != 0 );
+    have_Q = ( mbedtls_mpi_cmp_int( &ctx->Q, 0 ) != 0 );
+    have_D = ( mbedtls_mpi_cmp_int( &ctx->D, 0 ) != 0 );
+    have_E = ( mbedtls_mpi_cmp_int( &ctx->E, 0 ) != 0 );
 
 #if !defined(MBEDTLS_RSA_NO_CRT)
-    const int have_DP = ( mbedtls_mpi_cmp_int( &ctx->DP, 0 ) != 0 );
-    const int have_DQ = ( mbedtls_mpi_cmp_int( &ctx->DQ, 0 ) != 0 );
-    const int have_QP = ( mbedtls_mpi_cmp_int( &ctx->QP, 0 ) != 0 );
+    have_DP = ( mbedtls_mpi_cmp_int( &ctx->DP, 0 ) != 0 );
+    have_DQ = ( mbedtls_mpi_cmp_int( &ctx->DQ, 0 ) != 0 );
+    have_QP = ( mbedtls_mpi_cmp_int( &ctx->QP, 0 ) != 0 );
 #endif
 
     /*
@@ -292,13 +303,13 @@ int mbedtls_rsa_complete( mbedtls_rsa_context *ctx )
      *
      */
 
-    const int n_missing  =              have_P &&  have_Q &&  have_D && have_E;
-    const int pq_missing =   have_N && !have_P && !have_Q &&  have_D && have_E;
-    const int d_missing  =              have_P &&  have_Q && !have_D && have_E;
-    const int is_pub     =   have_N && !have_P && !have_Q && !have_D && have_E;
+    n_missing  =              have_P &&  have_Q &&  have_D && have_E;
+    pq_missing =   have_N && !have_P && !have_Q &&  have_D && have_E;
+    d_missing  =              have_P &&  have_Q && !have_D && have_E;
+    is_pub     =   have_N && !have_P && !have_Q && !have_D && have_E;
 
     /* These three alternatives are mutually exclusive */
-    const int is_priv = n_missing || pq_missing || d_missing;
+    is_priv = n_missing || pq_missing || d_missing;
 
     if( !is_priv && !is_pub )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -371,9 +382,11 @@ int mbedtls_rsa_export_raw( const mbedtls_rsa_context *ctx,
                             unsigned char *E, size_t E_len )
 {
     int ret = 0;
+    int is_priv;
+    RSA_VALIDATE_RET( ctx != NULL );
 
     /* Check if key is private or public */
-    const int is_priv =
+    is_priv =
         mbedtls_mpi_cmp_int( &ctx->N, 0 ) != 0 &&
         mbedtls_mpi_cmp_int( &ctx->P, 0 ) != 0 &&
         mbedtls_mpi_cmp_int( &ctx->Q, 0 ) != 0 &&
@@ -414,9 +427,11 @@ int mbedtls_rsa_export( const mbedtls_rsa_context *ctx,
                         mbedtls_mpi *D, mbedtls_mpi *E )
 {
     int ret;
+    int is_priv;
+    RSA_VALIDATE_RET( ctx != NULL );
 
     /* Check if key is private or public */
-    int is_priv =
+    is_priv =
         mbedtls_mpi_cmp_int( &ctx->N, 0 ) != 0 &&
         mbedtls_mpi_cmp_int( &ctx->P, 0 ) != 0 &&
         mbedtls_mpi_cmp_int( &ctx->Q, 0 ) != 0 &&
@@ -456,9 +471,11 @@ int mbedtls_rsa_export_crt( const mbedtls_rsa_context *ctx,
                             mbedtls_mpi *DP, mbedtls_mpi *DQ, mbedtls_mpi *QP )
 {
     int ret;
+    int is_priv;
+    RSA_VALIDATE_RET( ctx != NULL );
 
     /* Check if key is private or public */
-    int is_priv =
+    is_priv =
         mbedtls_mpi_cmp_int( &ctx->N, 0 ) != 0 &&
         mbedtls_mpi_cmp_int( &ctx->P, 0 ) != 0 &&
         mbedtls_mpi_cmp_int( &ctx->Q, 0 ) != 0 &&
@@ -494,6 +511,10 @@ void mbedtls_rsa_init( mbedtls_rsa_context *ctx,
                int padding,
                int hash_id )
 {
+    RSA_VALIDATE( ctx != NULL );
+    RSA_VALIDATE( padding == MBEDTLS_RSA_PKCS_V15 ||
+                  padding == MBEDTLS_RSA_PKCS_V21 );
+
     memset( ctx, 0, sizeof( mbedtls_rsa_context ) );
 
     mbedtls_rsa_set_padding( ctx, padding, hash_id );
@@ -509,8 +530,13 @@ void mbedtls_rsa_init( mbedtls_rsa_context *ctx,
 /*
  * Set padding for an existing RSA context
  */
-void mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding, int hash_id )
+void mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding,
+                              int hash_id )
 {
+    RSA_VALIDATE( ctx != NULL );
+    RSA_VALIDATE( padding == MBEDTLS_RSA_PKCS_V15 ||
+                  padding == MBEDTLS_RSA_PKCS_V21 );
+
     ctx->padding = padding;
     ctx->hash_id = hash_id;
 }
@@ -529,6 +555,9 @@ size_t mbedtls_rsa_get_len( const mbedtls_rsa_context *ctx )
 
 /*
  * Generate an RSA keypair
+ *
+ * This generation method follows the RSA key pair generation procedure of
+ * FIPS 186-4 if 2^16 < exponent < 2^256 and nbits = 2048 or nbits = 3072.
  */
 int mbedtls_rsa_gen_key( mbedtls_rsa_context *ctx,
                  int (*f_rng)(void *, unsigned char *, size_t),
@@ -536,12 +565,24 @@ int mbedtls_rsa_gen_key( mbedtls_rsa_context *ctx,
                  unsigned int nbits, int exponent )
 {
     int ret;
-    mbedtls_mpi H, G;
+    mbedtls_mpi H, G, L;
+    int prime_quality = 0;
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( f_rng != NULL );
+
+    /*
+     * If the modulus is 1024 bit long or shorter, then the security strength of
+     * the RSA algorithm is less than or equal to 80 bits and therefore an error
+     * rate of 2^-80 is sufficient.
+     */
+    if( nbits > 1024 )
+        prime_quality = MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR;
 
     mbedtls_mpi_init( &H );
     mbedtls_mpi_init( &G );
+    mbedtls_mpi_init( &L );
 
-    if( f_rng == NULL || nbits < 128 || exponent < 3 || nbits % 2 != 0 )
+    if( nbits < 128 || exponent < 3 || nbits % 2 != 0 )
     {
         ret = MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
         goto cleanup;
@@ -549,52 +590,65 @@ int mbedtls_rsa_gen_key( mbedtls_rsa_context *ctx,
 
     /*
      * find primes P and Q with Q < P so that:
-     * GCD( E, (P-1)*(Q-1) ) == 1
+     * 1.  |P-Q| > 2^( nbits / 2 - 100 )
+     * 2.  GCD( E, (P-1)*(Q-1) ) == 1
+     * 3.  E^-1 mod LCM(P-1, Q-1) > 2^( nbits / 2 )
      */
     MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &ctx->E, exponent ) );
 
     do
     {
-        MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->P, nbits >> 1, 0,
-                                                f_rng, p_rng ) );
+        MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->P, nbits >> 1,
+                                                prime_quality, f_rng, p_rng ) );
 
-        MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->Q, nbits >> 1, 0,
-                                                f_rng, p_rng ) );
+        MBEDTLS_MPI_CHK( mbedtls_mpi_gen_prime( &ctx->Q, nbits >> 1,
+                                                prime_quality, f_rng, p_rng ) );
 
-        if( mbedtls_mpi_cmp_mpi( &ctx->P, &ctx->Q ) == 0 )
+        /* make sure the difference between p and q is not too small (FIPS 186-4 §B.3.3 step 5.4) */
+        MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &H, &ctx->P, &ctx->Q ) );
+        if( mbedtls_mpi_bitlen( &H ) <= ( ( nbits >= 200 ) ? ( ( nbits >> 1 ) - 99 ) : 0 ) )
             continue;
 
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->N, &ctx->P, &ctx->Q ) );
-        if( mbedtls_mpi_bitlen( &ctx->N ) != nbits )
-            continue;
-
-        if( mbedtls_mpi_cmp_mpi( &ctx->P, &ctx->Q ) < 0 )
+        /* not required by any standards, but some users rely on the fact that P > Q */
+        if( H.s < 0 )
             mbedtls_mpi_swap( &ctx->P, &ctx->Q );
 
         /* Temporarily replace P,Q by P-1, Q-1 */
         MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &ctx->P, &ctx->P, 1 ) );
         MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &ctx->Q, &ctx->Q, 1 ) );
         MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &H, &ctx->P, &ctx->Q ) );
+
+        /* check GCD( E, (P-1)*(Q-1) ) == 1 (FIPS 186-4 §B.3.1 criterion 2(a)) */
         MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->E, &H  ) );
+        if( mbedtls_mpi_cmp_int( &G, 1 ) != 0 )
+            continue;
+
+        /* compute smallest possible D = E^-1 mod LCM(P-1, Q-1) (FIPS 186-4 §B.3.1 criterion 3(b)) */
+        MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, &ctx->P, &ctx->Q ) );
+        MBEDTLS_MPI_CHK( mbedtls_mpi_div_mpi( &L, NULL, &H, &G ) );
+        MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->D, &ctx->E, &L ) );
+
+        if( mbedtls_mpi_bitlen( &ctx->D ) <= ( ( nbits + 1 ) / 2 ) ) // (FIPS 186-4 §B.3.1 criterion 3(a))
+            continue;
+
+        break;
     }
-    while( mbedtls_mpi_cmp_int( &G, 1 ) != 0 );
+    while( 1 );
 
     /* Restore P,Q */
     MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &ctx->P,  &ctx->P, 1 ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &ctx->Q,  &ctx->Q, 1 ) );
 
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &ctx->N, &ctx->P, &ctx->Q ) );
+
     ctx->len = mbedtls_mpi_size( &ctx->N );
 
+#if !defined(MBEDTLS_RSA_NO_CRT)
     /*
-     * D  = E^-1 mod ((P-1)*(Q-1))
      * DP = D mod (P - 1)
      * DQ = D mod (Q - 1)
      * QP = Q^-1 mod P
      */
-
-    MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( &ctx->D, &ctx->E, &H  ) );
-
-#if !defined(MBEDTLS_RSA_NO_CRT)
     MBEDTLS_MPI_CHK( mbedtls_rsa_deduce_crt( &ctx->P, &ctx->Q, &ctx->D,
                                              &ctx->DP, &ctx->DQ, &ctx->QP ) );
 #endif /* MBEDTLS_RSA_NO_CRT */
@@ -606,6 +660,7 @@ cleanup:
 
     mbedtls_mpi_free( &H );
     mbedtls_mpi_free( &G );
+    mbedtls_mpi_free( &L );
 
     if( ret != 0 )
     {
@@ -625,6 +680,8 @@ cleanup:
  */
 int mbedtls_rsa_check_pubkey( const mbedtls_rsa_context *ctx )
 {
+    RSA_VALIDATE_RET( ctx != NULL );
+
     if( rsa_check_context( ctx, 0 /* public */, 0 /* no blinding */ ) != 0 )
         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
 
@@ -648,6 +705,8 @@ int mbedtls_rsa_check_pubkey( const mbedtls_rsa_context *ctx )
  */
 int mbedtls_rsa_check_privkey( const mbedtls_rsa_context *ctx )
 {
+    RSA_VALIDATE_RET( ctx != NULL );
+
     if( mbedtls_rsa_check_pubkey( ctx ) != 0 ||
         rsa_check_context( ctx, 1 /* private */, 1 /* blinding */ ) != 0 )
     {
@@ -677,6 +736,9 @@ int mbedtls_rsa_check_privkey( const mbedtls_rsa_context *ctx )
 int mbedtls_rsa_check_pub_priv( const mbedtls_rsa_context *pub,
                                 const mbedtls_rsa_context *prv )
 {
+    RSA_VALIDATE_RET( pub != NULL );
+    RSA_VALIDATE_RET( prv != NULL );
+
     if( mbedtls_rsa_check_pubkey( pub )  != 0 ||
         mbedtls_rsa_check_privkey( prv ) != 0 )
     {
@@ -702,6 +764,9 @@ int mbedtls_rsa_public( mbedtls_rsa_context *ctx,
     int ret;
     size_t olen;
     mbedtls_mpi T;
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( input != NULL );
+    RSA_VALIDATE_RET( output != NULL );
 
     if( rsa_check_context( ctx, 0 /* public */, 0 /* no blinding */ ) )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -869,6 +934,10 @@ int mbedtls_rsa_private( mbedtls_rsa_context *ctx,
      * checked result; should be the same in the end. */
     mbedtls_mpi I, C;
 
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( input  != NULL );
+    RSA_VALIDATE_RET( output != NULL );
+
     if( rsa_check_context( ctx, 1             /* private key checks */,
                                 f_rng != NULL /* blinding y/n       */ ) != 0 )
     {
@@ -1103,7 +1172,7 @@ static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
     }
 
 exit:
-    mbedtls_zeroize( mask, sizeof( mask ) );
+    mbedtls_platform_zeroize( mask, sizeof( mask ) );
 
     return( ret );
 }
@@ -1129,6 +1198,13 @@ int mbedtls_rsa_rsaes_oaep_encrypt( mbedtls_rsa_context *ctx,
     const mbedtls_md_info_t *md_info;
     mbedtls_md_context_t md_ctx;
 
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( output != NULL );
+    RSA_VALIDATE_RET( input != NULL );
+    RSA_VALIDATE_RET( label_len == 0 || label != NULL );
+
     if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
@@ -1205,11 +1281,13 @@ int mbedtls_rsa_rsaes_pkcs1_v15_encrypt( mbedtls_rsa_context *ctx,
     int ret;
     unsigned char *p = output;
 
-    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
-        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( output != NULL );
+    RSA_VALIDATE_RET( input != NULL );
 
-    // We don't check p_rng because it won't be dereferenced here
-    if( f_rng == NULL || input == NULL || output == NULL )
+    if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
     olen = ctx->len;
@@ -1223,6 +1301,9 @@ int mbedtls_rsa_rsaes_pkcs1_v15_encrypt( mbedtls_rsa_context *ctx,
     *p++ = 0;
     if( mode == MBEDTLS_RSA_PUBLIC )
     {
+        if( f_rng == NULL )
+            return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
         *p++ = MBEDTLS_RSA_CRYPT;
 
         while( nb_pad-- > 0 )
@@ -1267,6 +1348,12 @@ int mbedtls_rsa_pkcs1_encrypt( mbedtls_rsa_context *ctx,
                        const unsigned char *input,
                        unsigned char *output )
 {
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( output != NULL );
+    RSA_VALIDATE_RET( input != NULL );
+
     switch( ctx->padding )
     {
 #if defined(MBEDTLS_PKCS1_V15)
@@ -1309,6 +1396,14 @@ int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx,
     const mbedtls_md_info_t *md_info;
     mbedtls_md_context_t md_ctx;
 
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( output_max_len == 0 || output != NULL );
+    RSA_VALIDATE_RET( label_len == 0 || label != NULL );
+    RSA_VALIDATE_RET( input != NULL );
+    RSA_VALIDATE_RET( olen != NULL );
+
     /*
      * Parameters sanity checks
      */
@@ -1417,8 +1512,8 @@ int mbedtls_rsa_rsaes_oaep_decrypt( mbedtls_rsa_context *ctx,
     ret = 0;
 
 cleanup:
-    mbedtls_zeroize( buf, sizeof( buf ) );
-    mbedtls_zeroize( lhash, sizeof( lhash ) );
+    mbedtls_platform_zeroize( buf, sizeof( buf ) );
+    mbedtls_platform_zeroize( lhash, sizeof( lhash ) );
 
     return( ret );
 }
@@ -1528,11 +1623,7 @@ int mbedtls_rsa_rsaes_pkcs1_v15_decrypt( mbedtls_rsa_context *ctx,
                                  size_t output_max_len )
 {
     int ret;
-    size_t ilen = ctx->len;
-    size_t i;
-    size_t plaintext_max_size = ( output_max_len > ilen - 11 ?
-                                  ilen - 11 :
-                                  output_max_len );
+    size_t ilen, i, plaintext_max_size;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
     /* The following variables take sensitive values: their value must
      * not leak into the observable behavior of the function other than
@@ -1550,6 +1641,18 @@ int mbedtls_rsa_rsaes_pkcs1_v15_decrypt( mbedtls_rsa_context *ctx,
     size_t plaintext_size = 0;
     unsigned output_too_large;
 
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( output_max_len == 0 || output != NULL );
+    RSA_VALIDATE_RET( input != NULL );
+    RSA_VALIDATE_RET( olen != NULL );
+
+    ilen = ctx->len;
+    plaintext_max_size = ( output_max_len > ilen - 11 ?
+                           ilen - 11 :
+                           output_max_len );
+
     if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
@@ -1668,7 +1771,7 @@ int mbedtls_rsa_rsaes_pkcs1_v15_decrypt( mbedtls_rsa_context *ctx,
     *olen = plaintext_size;
 
 cleanup:
-    mbedtls_zeroize( buf, sizeof( buf ) );
+    mbedtls_platform_zeroize( buf, sizeof( buf ) );
 
     return( ret );
 }
@@ -1685,6 +1788,13 @@ int mbedtls_rsa_pkcs1_decrypt( mbedtls_rsa_context *ctx,
                        unsigned char *output,
                        size_t output_max_len)
 {
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( output_max_len == 0 || output != NULL );
+    RSA_VALIDATE_RET( input != NULL );
+    RSA_VALIDATE_RET( olen != NULL );
+
     switch( ctx->padding )
     {
 #if defined(MBEDTLS_PKCS1_V15)
@@ -1721,11 +1831,18 @@ int mbedtls_rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
     size_t olen;
     unsigned char *p = sig;
     unsigned char salt[MBEDTLS_MD_MAX_SIZE];
-    unsigned int slen, hlen, offset = 0;
+    size_t slen, min_slen, hlen, offset = 0;
     int ret;
     size_t msb;
     const mbedtls_md_info_t *md_info;
     mbedtls_md_context_t md_ctx;
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
+                        hashlen == 0 ) ||
+                      hash != NULL );
+    RSA_VALIDATE_RET( sig != NULL );
 
     if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
@@ -1750,10 +1867,20 @@ int mbedtls_rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
     hlen = mbedtls_md_get_size( md_info );
-    slen = hlen;
 
-    if( olen < hlen + slen + 2 )
+    /* Calculate the largest possible salt length. Normally this is the hash
+     * length, which is the maximum length the salt can have. If there is not
+     * enough room, use the maximum salt length that fits. The constraint is
+     * that the hash length plus the salt length plus 2 bytes must be at most
+     * the key length. This complies with FIPS 186-4 §5.5 (e) and RFC 8017
+     * (PKCS#1 v2.2) §9.1.1 step 3. */
+    min_slen = hlen - 2;
+    if( olen < hlen + min_slen + 2 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+    else if( olen >= hlen + hlen + 2 )
+        slen = hlen;
+    else
+        slen = olen - hlen - 2;
 
     memset( sig, 0, olen );
 
@@ -1763,7 +1890,7 @@ int mbedtls_rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
 
     /* Note: EMSA-PSS encoding is over the length of N - 1 bits */
     msb = mbedtls_mpi_bitlen( &ctx->N ) - 1;
-    p += olen - hlen * 2 - 2;
+    p += olen - hlen - slen - 2;
     *p++ = 0x01;
     memcpy( p, salt, slen );
     p += slen;
@@ -1799,7 +1926,7 @@ int mbedtls_rsa_rsassa_pss_sign( mbedtls_rsa_context *ctx,
     p += hlen;
     *p++ = 0xBC;
 
-    mbedtls_zeroize( salt, sizeof( salt ) );
+    mbedtls_platform_zeroize( salt, sizeof( salt ) );
 
 exit:
     mbedtls_md_free( &md_ctx );
@@ -1941,7 +2068,7 @@ static int rsa_rsassa_pkcs1_v15_encode( mbedtls_md_type_t md_alg,
      * after the initial bounds check. */
     if( p != dst + dst_len )
     {
-        mbedtls_zeroize( dst, dst_len );
+        mbedtls_platform_zeroize( dst, dst_len );
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
     }
 
@@ -1963,6 +2090,14 @@ int mbedtls_rsa_rsassa_pkcs1_v15_sign( mbedtls_rsa_context *ctx,
     int ret;
     unsigned char *sig_try = NULL, *verif = NULL;
 
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
+                        hashlen == 0 ) ||
+                      hash != NULL );
+    RSA_VALIDATE_RET( sig != NULL );
+
     if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
@@ -2032,6 +2167,14 @@ int mbedtls_rsa_pkcs1_sign( mbedtls_rsa_context *ctx,
                     const unsigned char *hash,
                     unsigned char *sig )
 {
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
+                        hashlen == 0 ) ||
+                      hash != NULL );
+    RSA_VALIDATE_RET( sig != NULL );
+
     switch( ctx->padding )
     {
 #if defined(MBEDTLS_PKCS1_V15)
@@ -2078,6 +2221,14 @@ int mbedtls_rsa_rsassa_pss_verify_ext( mbedtls_rsa_context *ctx,
     mbedtls_md_context_t md_ctx;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
 
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( sig != NULL );
+    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
+                        hashlen == 0 ) ||
+                      hash != NULL );
+
     if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V21 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
@@ -2206,7 +2357,16 @@ int mbedtls_rsa_rsassa_pss_verify( mbedtls_rsa_context *ctx,
                            const unsigned char *hash,
                            const unsigned char *sig )
 {
-    mbedtls_md_type_t mgf1_hash_id = ( ctx->hash_id != MBEDTLS_MD_NONE )
+    mbedtls_md_type_t mgf1_hash_id;
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( sig != NULL );
+    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
+                        hashlen == 0 ) ||
+                      hash != NULL );
+
+    mgf1_hash_id = ( ctx->hash_id != MBEDTLS_MD_NONE )
                              ? (mbedtls_md_type_t) ctx->hash_id
                              : md_alg;
 
@@ -2232,9 +2392,19 @@ int mbedtls_rsa_rsassa_pkcs1_v15_verify( mbedtls_rsa_context *ctx,
                                  const unsigned char *sig )
 {
     int ret = 0;
-    const size_t sig_len = ctx->len;
+    size_t sig_len;
     unsigned char *encoded = NULL, *encoded_expected = NULL;
 
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( sig != NULL );
+    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
+                        hashlen == 0 ) ||
+                      hash != NULL );
+
+    sig_len = ctx->len;
+
     if( mode == MBEDTLS_RSA_PRIVATE && ctx->padding != MBEDTLS_RSA_PKCS_V15 )
         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
 
@@ -2278,13 +2448,13 @@ cleanup:
 
     if( encoded != NULL )
     {
-        mbedtls_zeroize( encoded, sig_len );
+        mbedtls_platform_zeroize( encoded, sig_len );
         mbedtls_free( encoded );
     }
 
     if( encoded_expected != NULL )
     {
-        mbedtls_zeroize( encoded_expected, sig_len );
+        mbedtls_platform_zeroize( encoded_expected, sig_len );
         mbedtls_free( encoded_expected );
     }
 
@@ -2304,6 +2474,14 @@ int mbedtls_rsa_pkcs1_verify( mbedtls_rsa_context *ctx,
                       const unsigned char *hash,
                       const unsigned char *sig )
 {
+    RSA_VALIDATE_RET( ctx != NULL );
+    RSA_VALIDATE_RET( mode == MBEDTLS_RSA_PRIVATE ||
+                      mode == MBEDTLS_RSA_PUBLIC );
+    RSA_VALIDATE_RET( sig != NULL );
+    RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
+                        hashlen == 0 ) ||
+                      hash != NULL );
+
     switch( ctx->padding )
     {
 #if defined(MBEDTLS_PKCS1_V15)
@@ -2329,6 +2507,8 @@ int mbedtls_rsa_pkcs1_verify( mbedtls_rsa_context *ctx,
 int mbedtls_rsa_copy( mbedtls_rsa_context *dst, const mbedtls_rsa_context *src )
 {
     int ret;
+    RSA_VALIDATE_RET( dst != NULL );
+    RSA_VALIDATE_RET( src != NULL );
 
     dst->len = src->len;
 
@@ -2367,14 +2547,23 @@ cleanup:
  */
 void mbedtls_rsa_free( mbedtls_rsa_context *ctx )
 {
-    mbedtls_mpi_free( &ctx->Vi ); mbedtls_mpi_free( &ctx->Vf );
-    mbedtls_mpi_free( &ctx->RN ); mbedtls_mpi_free( &ctx->D  );
-    mbedtls_mpi_free( &ctx->Q  ); mbedtls_mpi_free( &ctx->P  );
-    mbedtls_mpi_free( &ctx->E  ); mbedtls_mpi_free( &ctx->N  );
+    if( ctx == NULL )
+        return;
+
+    mbedtls_mpi_free( &ctx->Vi );
+    mbedtls_mpi_free( &ctx->Vf );
+    mbedtls_mpi_free( &ctx->RN );
+    mbedtls_mpi_free( &ctx->D  );
+    mbedtls_mpi_free( &ctx->Q  );
+    mbedtls_mpi_free( &ctx->P  );
+    mbedtls_mpi_free( &ctx->E  );
+    mbedtls_mpi_free( &ctx->N  );
 
 #if !defined(MBEDTLS_RSA_NO_CRT)
-    mbedtls_mpi_free( &ctx->RQ ); mbedtls_mpi_free( &ctx->RP );
-    mbedtls_mpi_free( &ctx->QP ); mbedtls_mpi_free( &ctx->DQ );
+    mbedtls_mpi_free( &ctx->RQ );
+    mbedtls_mpi_free( &ctx->RP );
+    mbedtls_mpi_free( &ctx->QP );
+    mbedtls_mpi_free( &ctx->DQ );
     mbedtls_mpi_free( &ctx->DP );
 #endif /* MBEDTLS_RSA_NO_CRT */