From b10570807a47346eee8ffa010acba942f054c02e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niels=20M=C3=B6ller?= <nisse@lysator.liu.se>
Date: Mon, 6 Nov 2023 19:15:22 +0100
Subject: [PATCH] Improve side-channel silence when comparing values to zero.

---
 ChangeLog           | 25 +++++++++++++++++++++++++
 cnd-copy.c          |  2 +-
 ecc-gostdsa-sign.c  |  2 +-
 ecc-internal.h      |  8 ++++++++
 ecc-j-to-a.c        |  2 +-
 ecc-mod-arith.c     |  4 ++--
 ecc-mul-a-eh.c      | 15 ++++++---------
 ecc-mul-a.c         | 27 ++++++++++++++-------------
 ecc-mul-g.c         | 12 ++++++++----
 eddsa-decompress.c  |  2 +-
 memeql-sec.c        |  3 ++-
 nettle-internal.h   |  6 +++++-
 pkcs1-sec-decrypt.c |  4 ++--
 13 files changed, 76 insertions(+), 36 deletions(-)

diff --git a/ChangeLog b/ChangeLog
index 278de08b..4d53b896 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,28 @@
+2023-11-06  Niels Möller  <nisse@lysator.liu.se>
+
+	Avoid comparison like cnd = (x == 0) in code intended to be
+	side-channel silent, since to eliminate branches with some
+	compilers/architectures, in particular 32-bit x86 and the msvc compiler.
+	* nettle-internal.h (IS_ZERO_SMALL): New macro.
+	* memeql-sec.c (memeql_sec): Use IS_ZERO_SMALL.
+	* pkcs1-sec-decrypt.c (EQUAL): Likewise.
+
+	* cnd-copy.c (cnd_copy): Require that cnd argument is 1 or 0.
+	* ecc-mul-a.c (ecc_mul_a) [ECC_MUL_A_WBITS == 0]:
+	Rearrange loop to pass 0 or 1 to cnd_copy.
+	* ecc-mul-a-eh.c (ecc_mul_a_eh) [ECC_MUL_A_EH_WBITS == 0]:
+	Likewise.
+	* ecc-mul-a.c (ecc_mul_a) [ECC_MUL_A_WBITS > 0]: Use
+	IS_ZERO_SMALL, and pass 0 or 1 to cnd_copy.
+	* ecc-mul-g.c (ecc_mul_g): Likewise.
+
+	* ecc-internal.h (is_zero_limb): New inline function.
+	* eddsa-decompress.c (_eddsa_decompress): Likewise.
+	* ecc-gostdsa-sign.c (ecc_gostdsa_sign): Likewise.
+	* ecc-mod-arith.c (ecc_mod_zero_p): Likewise.
+	(ecc_mod_equal_p): Avoid comparison cy == 0.
+	* ecc-j-to-a.c (ecc_j_to_a): Avoid comparison cy == 0.
+
 2023-10-06  Niels Möller  <nisse@lysator.liu.se>
 
 	* testsuite/rsa-sec-decrypt-test.c (test_main): Skip side-channel
diff --git a/cnd-copy.c b/cnd-copy.c
index d24da3d0..878cd320 100644
--- a/cnd-copy.c
+++ b/cnd-copy.c
@@ -43,7 +43,7 @@ cnd_copy (int cnd, mp_limb_t *rp, const mp_limb_t *ap, mp_size_t n)
   mp_limb_t mask, keep;
   mp_size_t i;
 
-  mask = -(mp_limb_t) (cnd !=0);
+  mask = -(mp_limb_t) cnd;
   keep = ~mask;
 
   for (i = 0; i < n; i++)
diff --git a/ecc-gostdsa-sign.c b/ecc-gostdsa-sign.c
index 491a2281..c811c87e 100644
--- a/ecc-gostdsa-sign.c
+++ b/ecc-gostdsa-sign.c
@@ -91,7 +91,7 @@ ecc_gostdsa_sign (const struct ecc_curve *ecc,
    * so one subtraction should suffice. */
 
   *scratch = mpn_sub_n (tp, sp, ecc->q.m, ecc->p.size);
-  cnd_copy (*scratch == 0, sp, tp, ecc->p.size);
+  cnd_copy (is_zero_limb (*scratch), sp, tp, ecc->p.size);
 
 #undef P
 #undef hp
diff --git a/ecc-internal.h b/ecc-internal.h
index be02de5f..2a5e3ae1 100644
--- a/ecc-internal.h
+++ b/ecc-internal.h
@@ -85,6 +85,13 @@
 #define curve25519_eh_to_x _nettle_curve25519_eh_to_x
 #define curve448_eh_to_x _nettle_curve448_eh_to_x
 
+inline int
+is_zero_limb (mp_limb_t x)
+{
+  x |= (x << 1);
+  return ((x >> 1) - 1) >> (GMP_LIMB_BITS - 1);
+}
+
 extern const struct ecc_curve _nettle_secp_192r1;
 extern const struct ecc_curve _nettle_secp_224r1;
 extern const struct ecc_curve _nettle_secp_256r1;
@@ -464,6 +471,7 @@ ecc_mul_m (const struct ecc_modulo *m,
 	   mp_limb_t *qx, const uint8_t *n, const mp_limb_t *px,
 	   mp_limb_t *scratch);
 
+/* The cnd argument must be 1 or 0. */
 void
 cnd_copy (int cnd, mp_limb_t *rp, const mp_limb_t *ap, mp_size_t n);
 
diff --git a/ecc-j-to-a.c b/ecc-j-to-a.c
index ac134b51..3196cd6f 100644
--- a/ecc-j-to-a.c
+++ b/ecc-j-to-a.c
@@ -72,7 +72,7 @@ ecc_j_to_a (const struct ecc_curve *ecc,
 	     already be < 2*ecc->q, so one subtraction should
 	     suffice. */
 	  cy = mpn_sub_n (scratch, r, ecc->q.m, ecc->p.size);
-	  cnd_copy (cy == 0, r, scratch, ecc->p.size);
+	  cnd_copy (1 - cy, r, scratch, ecc->p.size);
 	}
       return;
     }
diff --git a/ecc-mod-arith.c b/ecc-mod-arith.c
index d0137864..3b9bcb47 100644
--- a/ecc-mod-arith.c
+++ b/ecc-mod-arith.c
@@ -55,7 +55,7 @@ ecc_mod_zero_p (const struct ecc_modulo *m, const mp_limb_t *xp_in)
       is_not_p |= (xp[i] ^ m->m[i]);
     }
 
-  return (is_non_zero == 0) | (is_not_p == 0);
+  return is_zero_limb (is_non_zero) | is_zero_limb (is_not_p);
 }
 
 int
@@ -65,7 +65,7 @@ ecc_mod_equal_p (const struct ecc_modulo *m, const mp_limb_t *a,
   mp_limb_t cy;
   cy = mpn_sub_n (scratch, a, ref, m->size);
   /* If cy > 0, i.e., a < ref, then they can't be equal mod m. */
-  return (cy == 0) & ecc_mod_zero_p (m, scratch);
+  return (1 - cy) & ecc_mod_zero_p (m, scratch);
 }
 
 void
diff --git a/ecc-mul-a-eh.c b/ecc-mul-a-eh.c
index 980fec3f..30130aea 100644
--- a/ecc-mul-a-eh.c
+++ b/ecc-mul-a-eh.c
@@ -66,21 +66,18 @@ ecc_mul_a_eh (const struct ecc_curve *ecc,
   
   for (i = ecc->p.size; i-- > 0; )
     {
-      mp_limb_t w = np[i];
-      mp_limb_t bit;
+      mp_limb_t w = np[i] << (GMP_LIMB_BITS - GMP_NUMB_BITS);
+      unsigned j;
 
-      for (bit = (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
-	   bit > 0;
-	   bit >>= 1)
+      for (j = 0; j < GMP_NUMB_BITS; j++, w <<= 1)
 	{
-	  int digit;
-
+	  int bit;
 	  ecc->dup (ecc, r, r, scratch_out);
 	  ecc->add_hh (ecc, tp, r, pe, scratch_out);
 
-	  digit = (w & bit) > 0;
+	  bit = w >> (GMP_LIMB_BITS - 1);
 	  /* If we had a one-bit, use the sum. */
-	  cnd_copy (digit, r, tp, 3*ecc->p.size);
+	  cnd_copy (bit, r, tp, 3*ecc->p.size);
 	}
     }
 }
diff --git a/ecc-mul-a.c b/ecc-mul-a.c
index 8e1355eb..cc2a7960 100644
--- a/ecc-mul-a.c
+++ b/ecc-mul-a.c
@@ -39,6 +39,7 @@
 
 #include "ecc.h"
 #include "ecc-internal.h"
+#include "nettle-internal.h"
 
 /* Binary algorithm needs 6*ecc->p.size + scratch for ecc_add_jja.
    Current total is 12 ecc->p.size, at most 864 bytes.
@@ -67,25 +68,23 @@ ecc_mul_a (const struct ecc_curve *ecc,
   
   for (i = ecc->p.size, is_zero = 1; i-- > 0; )
     {
-      mp_limb_t w = np[i];
-      mp_limb_t bit;
+      mp_limb_t w = np[i] << (GMP_LIMB_BITS - GMP_NUMB_BITS);
+      unsigned j;
 
-      for (bit = (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
-	   bit > 0;
-	   bit >>= 1)
+      for (j = 0; j < GMP_NUMB_BITS; j++, w <<= 1)
 	{
-	  int digit;
+	  int bit;
 
 	  ecc_dup_jj (ecc, r, r, scratch_out);
 	  ecc_add_jja (ecc, tp, r, pj, scratch_out);
 
-	  digit = (w & bit) > 0;
+	  bit = w >> (GMP_LIMB_BITS - 1);
 	  /* If is_zero is set, r is the zero point,
 	     and ecc_add_jja produced garbage. */
 	  cnd_copy (is_zero, tp, pj, 3*ecc->p.size);
-	  is_zero &= ~digit;
+	  is_zero &= 1 - bit;
 	  /* If we had a one-bit, use the sum. */
-	  cnd_copy (digit, r, tp, 3*ecc->p.size);
+	  cnd_copy (bit, r, tp, 3*ecc->p.size);
 	}
     }
 }
@@ -145,10 +144,11 @@ ecc_mul_a (const struct ecc_curve *ecc,
   assert (bits < TABLE_SIZE);
 
   mpn_sec_tabselect (r, table, 3*ecc->p.size, TABLE_SIZE, bits);
-  is_zero = (bits == 0);
+  is_zero = IS_ZERO_SMALL (bits);
 
   for (;;)
     {
+      int bits_is_zero;
       unsigned j;
       if (shift >= ECC_MUL_A_WBITS)
 	{
@@ -174,11 +174,12 @@ ecc_mul_a (const struct ecc_curve *ecc,
       mpn_sec_tabselect (tp, table, 3*ecc->p.size, TABLE_SIZE, bits);
       cnd_copy (is_zero, r, tp, 3*ecc->p.size);
       ecc_add_jjj (ecc, tp, tp, r, scratch_out);
+      bits_is_zero = IS_ZERO_SMALL (bits);
 
       /* Use the sum when valid. ecc_add_jja produced garbage if
-	 is_zero != 0 or bits == 0, . */	  
-      cnd_copy (bits & (is_zero - 1), r, tp, 3*ecc->p.size);
-      is_zero &= (bits == 0);
+	 is_zero or bits_is_zero. */
+      cnd_copy (1 - (bits_is_zero | is_zero), r, tp, 3*ecc->p.size);
+      is_zero &= bits_is_zero;
     }
 #undef table
 #undef tp
diff --git a/ecc-mul-g.c b/ecc-mul-g.c
index 677a37e7..97bbabad 100644
--- a/ecc-mul-g.c
+++ b/ecc-mul-g.c
@@ -39,6 +39,7 @@
 
 #include "ecc.h"
 #include "ecc-internal.h"
+#include "nettle-internal.h"
 
 void
 ecc_mul_g (const struct ecc_curve *ecc, mp_limb_t *r,
@@ -71,7 +72,8 @@ ecc_mul_g (const struct ecc_curve *ecc, mp_limb_t *r,
 	  /* Avoid the mp_bitcnt_t type for compatibility with older GMP
 	     versions. */
 	  unsigned bit_index;
-	  
+	  int bits_is_zero;
+
 	  /* Extract c bits from n, stride k, starting at i + kcj,
 	     ending at i + k (cj + c - 1)*/
 	  for (bits = 0, bit_index = i + k*(c*j+c); bit_index > i + k*c*j; )
@@ -96,10 +98,12 @@ ecc_mul_g (const struct ecc_curve *ecc, mp_limb_t *r,
 	  cnd_copy (is_zero, r + 2*ecc->p.size, ecc->unit, ecc->p.size);
 	  
 	  ecc_add_jja (ecc, tp, r, tp, scratch_out);
+	  bits_is_zero = IS_ZERO_SMALL (bits);
+
 	  /* Use the sum when valid. ecc_add_jja produced garbage if
-	     is_zero != 0 or bits == 0, . */	  
-	  cnd_copy (bits & (is_zero - 1), r, tp, 3*ecc->p.size);
-	  is_zero &= (bits == 0);
+	     is_zero or bits_is_zero. */
+	  cnd_copy (1 - (bits_is_zero | is_zero), r, tp, 3*ecc->p.size);
+	  is_zero &= bits_is_zero;
 	}
     }
 #undef tp
diff --git a/eddsa-decompress.c b/eddsa-decompress.c
index 8517fb7b..0718c1a0 100644
--- a/eddsa-decompress.c
+++ b/eddsa-decompress.c
@@ -83,7 +83,7 @@ _eddsa_decompress (const struct ecc_curve *ecc, mp_limb_t *p,
 
   /* Check range. */
   if (nlimbs > ecc->p.size)
-    res = (scratch[nlimbs - 1] == 0);
+    res = is_zero_limb (scratch[nlimbs - 1]);
   else
     res = 1;
 
diff --git a/memeql-sec.c b/memeql-sec.c
index b19052ed..ba8c5dd6 100644
--- a/memeql-sec.c
+++ b/memeql-sec.c
@@ -34,6 +34,7 @@
 #endif
 
 #include "memops.h"
+#include "nettle-internal.h"
 
 int
 memeql_sec (const void *a, const void *b, size_t n)
@@ -47,5 +48,5 @@ memeql_sec (const void *a, const void *b, size_t n)
   for (i = diff = 0; i < n; i++)
     diff |= (ap[i] ^ bp[i]);
 
-  return diff == 0;
+  return IS_ZERO_SMALL (diff);
 }
diff --git a/nettle-internal.h b/nettle-internal.h
index c41f3ee0..2b7dc816 100644
--- a/nettle-internal.h
+++ b/nettle-internal.h
@@ -84,6 +84,10 @@
 #define NETTLE_MAX_CIPHER_BLOCK_SIZE 32
 #define NETTLE_MAX_CIPHER_KEY_SIZE 32
 
+/* Equivalent to x == 0, but with an expression that should compile to
+   branch free code on all compilers. Requires that x is at most 31 bits. */
+#define IS_ZERO_SMALL(x) (((uint32_t) (x) - 1U) >> 31)
+
 /* Doesn't quite fit with the other algorithms, because of the weak
  * keys. Weak keys are not reported, the functions will simply crash
  * if you try to use a weak key. */
@@ -103,7 +107,7 @@ extern const struct nettle_aead nettle_chacha;
 extern const struct nettle_aead nettle_salsa20;
 extern const struct nettle_aead nettle_salsa20r12;
 
-/* All-in-one CBC encrypt fucntinos treated as AEAD with no
+/* All-in-one CBC encrypt functinos treated as AEAD with no
    authentication and no decrypt method. */
 extern const struct nettle_aead nettle_cbc_aes128;
 extern const struct nettle_aead nettle_cbc_aes192;
diff --git a/pkcs1-sec-decrypt.c b/pkcs1-sec-decrypt.c
index 942a2bd3..7cf56fe4 100644
--- a/pkcs1-sec-decrypt.c
+++ b/pkcs1-sec-decrypt.c
@@ -44,14 +44,14 @@
 
 #include "gmp-glue.h"
 #include "pkcs1-internal.h"
+#include "nettle-internal.h"
 
 /* Inputs are always cast to uint32_t values. But all values used in this
  * function should never exceed the maximum value of a uint32_t anyway.
  * these macros returns 1 on success, 0 on failure */
 #define NOT_EQUAL(a, b) \
     ((0U - ((uint32_t)(a) ^ (uint32_t)(b))) >> 31)
-#define EQUAL(a, b) \
-    ((((uint32_t)(a) ^ (uint32_t)(b)) - 1U) >> 31)
+#define EQUAL(a, b) (IS_ZERO_SMALL ((a) ^ (b)))
 #define GREATER_OR_EQUAL(a, b) \
     (1U - (((uint32_t)(a) - (uint32_t)(b)) >> 31))
 
-- 
GitLab