From 8b6cd994fe5a4d88a467fa93ab1596e1b445582a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niels=20M=C3=B6ller?= <nisse@lysator.liu.se>
Date: Tue, 23 Sep 2014 14:04:25 +0200
Subject: [PATCH] curve25519: Use powering to compute modp inverses, 5.5 times
 faster than ecc_mod_inv.

---
 ChangeLog                   |  20 +++++
 ecc-192.c                   |   4 +
 ecc-224.c                   |   4 +
 ecc-25519.c                 | 160 +++++++++++++++++++++---------------
 ecc-256.c                   |   4 +
 ecc-384.c                   |   4 +
 ecc-521.c                   |   4 +
 ecc-eh-to-a.c               |   7 +-
 ecc-internal.h              |   3 +-
 testsuite/ecc-modinv-test.c |  30 +++++--
 10 files changed, 165 insertions(+), 75 deletions(-)

diff --git a/ChangeLog b/ChangeLog
index 6aab4252..c50ee4b0 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,25 @@
 2014-09-23  Niels Möller  <nisse@lysator.liu.se>
 
+	* testsuite/ecc-modinv-test.c (zero_p): New function, checking for
+	zero modulo p.
+	(test_modulo): Use zero_p. Switch to dynamic allocation. Updated
+	for larger modinv result area, and use invert_itch.
+
+	* ecc-25519.c (ecc_mod_pow_2kp1): Renamed, and take a struct
+	ecc_modulo * as argument.
+	(ecc_modp_powm_2kp1): ... old name.
+	(ecc_mod_pow_252m3): New function, extracted from ecc_25519_sqrt.
+	(ecc_25519_inv): New modp invert function, about 5.5 times faster
+	then ecc_mod_inv.
+	(ecc_25519_sqrt): Use ecc_mod_pow_252m3.
+	(nettle_curve25519): Point to ecc_25519_inv. Updated p.invert_itch
+	and h_to_a_itch.
+
+	* ecc-internal.h (struct ecc_modulo): New field invert_itch.
+	Updated all implementations.
+	(ECC_EH_TO_A_ITCH): Updated, and take invert itch as an argument.
+	* ecc-eh-to-a.c (ecc_eh_to_a_itch): Take invert scratch into account.
+
 	* testsuite/testutils.c (test_ecc_mul_h): Use ecc->h_to_a_itch.
 
 	* ecc-mod-inv.c (ecc_mod_inv): Interface change, make ap input
diff --git a/ecc-192.c b/ecc-192.c
index 6d4010f2..af31d376 100644
--- a/ecc-192.c
+++ b/ecc-192.c
@@ -117,6 +117,8 @@ const struct ecc_curve nettle_secp_192r1 =
     ECC_LIMB_SIZE,
     ECC_BMODP_SIZE,
     ECC_REDC_SIZE,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_p,
     ecc_Bmodp,
     ecc_Bmodp_shifted,    
@@ -132,6 +134,8 @@ const struct ecc_curve nettle_secp_192r1 =
     ECC_LIMB_SIZE,
     ECC_BMODQ_SIZE,
     0,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_q,
     ecc_Bmodq,
     ecc_Bmodq_shifted,
diff --git a/ecc-224.c b/ecc-224.c
index ecd3f217..a0794ffe 100644
--- a/ecc-224.c
+++ b/ecc-224.c
@@ -69,6 +69,8 @@ const struct ecc_curve nettle_secp_224r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODP_SIZE,
     -ECC_REDC_SIZE,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_p,
     ecc_Bmodp,
     ecc_Bmodp_shifted,
@@ -84,6 +86,8 @@ const struct ecc_curve nettle_secp_224r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODQ_SIZE,
     0,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_q,
     ecc_Bmodq,
     ecc_Bmodq_shifted,
diff --git a/ecc-25519.c b/ecc-25519.c
index 5416b74f..39e80260 100644
--- a/ecc-25519.c
+++ b/ecc-25519.c
@@ -103,32 +103,103 @@ ecc_25519_modq (const struct ecc_modulo *q, mp_limb_t *rp)
 /* Needs 2*ecc->size limbs at rp, and 2*ecc->size additional limbs of
    scratch space. No overlap allowed. */
 static void
-ecc_modp_powm_2kp1 (const struct ecc_curve *ecc,
-		    mp_limb_t *rp, const mp_limb_t *xp,
-		    unsigned k, mp_limb_t *tp)
+ecc_mod_pow_2kp1 (const struct ecc_modulo *m,
+		  mp_limb_t *rp, const mp_limb_t *xp,
+		  unsigned k, mp_limb_t *tp)
 {
   if (k & 1)
     {
-      ecc_modp_sqr (ecc, tp, xp);
+      ecc_mod_sqr (m, tp, xp);
       k--;
     }
   else
     {
-      ecc_modp_sqr (ecc, rp, xp);
-      ecc_modp_sqr (ecc, tp, rp);
+      ecc_mod_sqr (m, rp, xp);
+      ecc_mod_sqr (m, tp, rp);
       k -= 2;
     }
   while (k > 0)
     {
-      ecc_modp_sqr (ecc, rp, tp);
-      ecc_modp_sqr (ecc, tp, rp);
+      ecc_mod_sqr (m, rp, tp);
+      ecc_mod_sqr (m, tp, rp);
       k -= 2;
     }
-  ecc_modp_mul (ecc, rp, tp, xp);
+  ecc_mod_mul (m, rp, tp, xp);
+}
+
+/* Computes a^{2^{252-3}} mod m. Needs 5 * n scratch space. */
+static void
+ecc_mod_pow_252m3 (const struct ecc_modulo *m,
+		   mp_limb_t *rp, const mp_limb_t *ap, mp_limb_t *scratch)
+{
+#define a7 scratch
+#define t0 (scratch + ECC_LIMB_SIZE)
+#define t1 (scratch + 3*ECC_LIMB_SIZE)
+
+  /* a^{2^252 - 3} = a^{(p-5)/8}, using the addition chain
+     2^252 - 3
+     = 1 + (2^252-4)
+     = 1 + 4 (2^250-1)
+     = 1 + 4 (2^125+1)(2^125-1)
+     = 1 + 4 (2^125+1)(1+2(2^124-1))
+     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^62-1))
+     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(2^31-1))
+     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^28-1)))
+     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^14-1)))
+     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^7+1)(2^7-1)))
+     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^7+1)(1+2(2^6-1))))
+     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^7+1)(1+2(2^3+1)*7)))
+  */ 
+     
+  ecc_mod_pow_2kp1 (m, t0, ap, 1, t1);	/* a^3 */
+  ecc_mod_sqr (m, rp, t0);		/* a^6 */
+  ecc_mod_mul (m, a7, rp, ap);		/* a^7 */
+  ecc_mod_pow_2kp1 (m, rp, a7, 3, t0);	/* a^63 = a^{2^6-1} */
+  ecc_mod_sqr (m, t0, rp);		/* a^{2^7-2} */
+  ecc_mod_mul (m, rp, t0, ap);		/* a^{2^7-1} */
+  ecc_mod_pow_2kp1 (m, t0, rp, 7, t1);	/* a^{2^14-1}*/
+  ecc_mod_pow_2kp1 (m, rp, t0, 14, t1);	/* a^{2^28-1} */
+  ecc_mod_sqr (m, t0, rp);		/* a^{2^29-2} */
+  ecc_mod_sqr (m, t1, t0);		/* a^{2^30-4} */
+  ecc_mod_sqr (m, t0, t1);		/* a^{2^31-8} */
+  ecc_mod_mul (m, rp, t0, a7);		/* a^{2^31-1} */
+  ecc_mod_pow_2kp1 (m, t0, rp, 31, t1);	/* a^{2^62-1} */  
+  ecc_mod_pow_2kp1 (m, rp, t0, 62, t1);	/* a^{2^124-1}*/
+  ecc_mod_sqr (m, t0, rp);		/* a^{2^125-2} */
+  ecc_mod_mul (m, rp, t0, ap);		/* a^{2^125-1} */
+  ecc_mod_pow_2kp1 (m, t0, rp, 125, t1);/* a^{2^250-1} */
+  ecc_mod_sqr (m, rp, t0);		/* a^{2^251-2} */
+  ecc_mod_sqr (m, t0, rp);		/* a^{2^252-4} */
+  ecc_mod_mul (m, rp, t0, ap);	    	/* a^{2^252-3} */
+#undef t0
 #undef t1
-#undef t2
+#undef a7
 }
 
+/* Needs 5*ECC_LIMB_SIZE scratch space. */
+#define ECC_25519_INV_ITCH (5*ECC_LIMB_SIZE)
+
+static void ecc_25519_inv (const struct ecc_modulo *p,
+			   mp_limb_t *rp, const mp_limb_t *ap,
+			   mp_limb_t *scratch)
+{
+#define t0 scratch
+
+  /* Addition chain
+     
+       p - 2 = 2^{255} - 21
+             = 1 + 2 (1 + 4 (2^{252}-3))	     
+  */
+  ecc_mod_pow_252m3 (p, rp, ap, t0);
+  ecc_mod_sqr (p, t0, rp);
+  ecc_mod_sqr (p, rp, t0);
+  ecc_mod_mul (p, t0, ap, rp);
+  ecc_mod_sqr (p, rp, t0);
+  ecc_mod_mul (p, t0, ap, rp);
+  mpn_copyi (rp, t0, ECC_LIMB_SIZE); /* FIXME: Eliminate copy? */
+#undef t0
+}
+ 
 /* Compute x such that x^2 = a (mod p). Returns one on success, zero
    on failure. using the e == 2 special case of the Shanks-Tonelli
    algorithm (see http://www.math.vt.edu/people/brown/doc/sqrts.pdf,
@@ -146,63 +217,24 @@ ecc_25519_sqrt(mp_limb_t *rp, const mp_limb_t *ap)
   mp_size_t itch;
   mp_limb_t *scratch;
   int res;
-  const struct ecc_curve *ecc = &nettle_curve25519;
+  const struct ecc_modulo *p = &nettle_curve25519.p;
 
   itch = 7*ECC_LIMB_SIZE;
   scratch = gmp_alloc_limbs (itch);
 
 #define t0 scratch
-#define a7 (scratch + 2*ECC_LIMB_SIZE)
-#define t1 (scratch + 3*ECC_LIMB_SIZE)
-#define t2 (scratch + 5*ECC_LIMB_SIZE)
-#define scratch_out (scratch + 3*ECC_LIMB_SIZE) /* overlap t1, t2 */
-
 #define xp (scratch + ECC_LIMB_SIZE)
 #define bp (scratch + 2*ECC_LIMB_SIZE)
 
-  /* a^{2^252 - 3} = a^{(p-5)/8}, using the addition chain
-     2^252 - 3
-     = 1 + (2^252-4)
-     = 1 + 4 (2^250-1)
-     = 1 + 4 (2^125+1)(2^125-1)
-     = 1 + 4 (2^125+1)(1+2(2^124-1))
-     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^62-1))
-     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(2^31-1))
-     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^28-1)))
-     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^14-1)))
-     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^7+1)(2^7-1)))
-     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^7+1)(1+2(2^6-1))))
-     = 1 + 4 (2^125+1)(1+2(2^62+1)(2^31+1)(7+8(2^14+1)(2^7+1)(1+2(2^3+1)*7)))
-  */ 
-     
-  ecc_modp_powm_2kp1 (ecc, t1, ap, 1, t2);  /* a^3 */
-  ecc_modp_sqr (ecc, t0, t1);		    /* a^6 */
-  ecc_modp_mul (ecc, a7, t0, ap);	    /* a^7 */
-  ecc_modp_powm_2kp1 (ecc, t0, a7, 3, t1);  /* a^63 = a^{2^6-1} */
-  ecc_modp_sqr (ecc, t1, t0);		    /* a^{2^7-2} */
-  ecc_modp_mul (ecc, t0, t1, ap);	    /* a^{2^7-1} */
-  ecc_modp_powm_2kp1 (ecc, t1, t0, 7, t2);  /* a^{2^14-1}*/
-  ecc_modp_powm_2kp1 (ecc, t0, t1, 14, t2); /* a^{2^28-1} */
-  ecc_modp_sqr (ecc, t1, t0);		    /* a^{2^29-2} */
-  ecc_modp_sqr (ecc, t2, t1);		    /* a^{2^30-4} */
-  ecc_modp_sqr (ecc, t1, t2);		    /* a^{2^31-8} */
-  ecc_modp_mul (ecc, t0, t1, a7);	    /* a^{2^31-1} */
-  ecc_modp_powm_2kp1 (ecc, t1, t0, 31, t2); /* a^{2^62-1} */  
-  ecc_modp_powm_2kp1 (ecc, t0, t1, 62, t2); /* a^{2^124-1}*/
-  ecc_modp_sqr (ecc, t1, t0);		    /* a^{2^125-2} */
-  ecc_modp_mul (ecc, t0, t1, ap);	    /* a^{2^125-1} */
-  ecc_modp_powm_2kp1 (ecc, t1, t0, 125, t2); /* a^{2^250-1} */
-  ecc_modp_sqr (ecc, t0, t1);		    /* a^{2^251-2} */
-  ecc_modp_sqr (ecc, t1, t0);		    /* a^{2^252-4} */
-  ecc_modp_mul (ecc, t0, t1, ap);	    /* a^{2^252-3} */
+  ecc_mod_pow_252m3 (p, t0, ap, t0 + 2*ECC_LIMB_SIZE);
 
   /* Compute candidate root x and fudgefactor b. */
-  ecc_modp_mul (ecc, xp, t0, ap); /* a^{(p+3)/8 */
-  ecc_modp_mul (ecc, bp, t0, xp); /* a^{(p-1)/4} */
+  ecc_mod_mul (p, xp, t0, ap); /* a^{(p+3)/8 */
+  ecc_mod_mul (p, bp, t0, xp); /* a^{(p-1)/4} */
   /* Check if b == 1 (mod p) */
-  if (mpn_cmp (bp, ecc->p.m, ECC_LIMB_SIZE) >= 0)
-    mpn_sub_n (bp, bp, ecc->p.m, ECC_LIMB_SIZE);
-  if (mpn_cmp (bp, ecc->unit, ECC_LIMB_SIZE) == 0)
+  if (mpn_cmp (bp, p->m, ECC_LIMB_SIZE) >= 0)
+    mpn_sub_n (bp, bp, p->m, ECC_LIMB_SIZE);
+  if (mpn_cmp (bp, ecc_unit, ECC_LIMB_SIZE) == 0)
     {
       mpn_copyi (rp, xp, ECC_LIMB_SIZE);
       res = 1;
@@ -210,9 +242,9 @@ ecc_25519_sqrt(mp_limb_t *rp, const mp_limb_t *ap)
   else
     {
       mpn_add_1 (bp, bp, ECC_LIMB_SIZE, 1);
-      if (mpn_cmp (bp, ecc->p.m, ECC_LIMB_SIZE) == 0)
+      if (mpn_cmp (bp, p->m, ECC_LIMB_SIZE) == 0)
 	{
-	  ecc_modp_mul (&nettle_curve25519, bp, xp, ecc_sqrt_z);
+	  ecc_mod_mul (p, bp, xp, ecc_sqrt_z);
 	  mpn_copyi (rp, bp, ECC_LIMB_SIZE);
 	  res = 1;
 	}
@@ -222,12 +254,8 @@ ecc_25519_sqrt(mp_limb_t *rp, const mp_limb_t *ap)
   gmp_free_limbs (scratch, itch);
   return res;
 #undef t0
-#undef t1
-#undef t2
-#undef a7
 #undef xp
 #undef bp
-#undef scratch_out
 }
 
 const struct ecc_curve nettle_curve25519 =
@@ -237,6 +265,8 @@ const struct ecc_curve nettle_curve25519 =
     ECC_LIMB_SIZE,
     ECC_BMODP_SIZE,
     0,
+    ECC_25519_INV_ITCH,
+
     ecc_p,
     ecc_Bmodp,
     ecc_Bmodp_shifted,
@@ -245,13 +275,15 @@ const struct ecc_curve nettle_curve25519 =
 
     ecc_25519_modp,
     ecc_25519_modp,
-    ecc_mod_inv,
+    ecc_25519_inv,
   },
   {
     253,
     ECC_LIMB_SIZE,
     ECC_BMODQ_SIZE,
     0,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_q,
     ecc_Bmodq,  
     ecc_mBmodq_shifted, /* Use q - 2^{252} instead. */
@@ -270,7 +302,7 @@ const struct ecc_curve nettle_curve25519 =
   ECC_ADD_EHH_ITCH (ECC_LIMB_SIZE),
   ECC_MUL_A_EH_ITCH (ECC_LIMB_SIZE),
   ECC_MUL_G_EH_ITCH (ECC_LIMB_SIZE),
-  ECC_EH_TO_A_ITCH (ECC_LIMB_SIZE),
+  ECC_EH_TO_A_ITCH (ECC_LIMB_SIZE, ECC_25519_INV_ITCH),
 
   ecc_add_ehh,
   ecc_mul_a_eh,
diff --git a/ecc-256.c b/ecc-256.c
index 95b719ed..4516f87b 100644
--- a/ecc-256.c
+++ b/ecc-256.c
@@ -232,6 +232,8 @@ const struct ecc_curve nettle_secp_256r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODP_SIZE,
     ECC_REDC_SIZE,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_p,
     ecc_Bmodp,
     ecc_Bmodp_shifted,
@@ -247,6 +249,8 @@ const struct ecc_curve nettle_secp_256r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODQ_SIZE,
     0,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_q,
     ecc_Bmodq,
     ecc_Bmodq_shifted,
diff --git a/ecc-384.c b/ecc-384.c
index e6a744c8..e8ced0ab 100644
--- a/ecc-384.c
+++ b/ecc-384.c
@@ -154,6 +154,8 @@ const struct ecc_curve nettle_secp_384r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODP_SIZE,
     ECC_REDC_SIZE,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_p,
     ecc_Bmodp,
     ecc_Bmodp_shifted,
@@ -169,6 +171,8 @@ const struct ecc_curve nettle_secp_384r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODQ_SIZE,
     0,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_q,
     ecc_Bmodq,
     ecc_Bmodq_shifted,
diff --git a/ecc-521.c b/ecc-521.c
index ae6f29f2..ff13e3a5 100644
--- a/ecc-521.c
+++ b/ecc-521.c
@@ -82,6 +82,8 @@ const struct ecc_curve nettle_secp_521r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODP_SIZE,
     ECC_REDC_SIZE,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_p,
     ecc_Bmodp,
     ecc_Bmodp_shifted,
@@ -97,6 +99,8 @@ const struct ecc_curve nettle_secp_521r1 =
     ECC_LIMB_SIZE,    
     ECC_BMODQ_SIZE,
     0,
+    ECC_MOD_INV_ITCH (ECC_LIMB_SIZE),
+
     ecc_q,
     ecc_Bmodq,
     ecc_Bmodq_shifted,
diff --git a/ecc-eh-to-a.c b/ecc-eh-to-a.c
index 4cfcad3b..3b7aff22 100644
--- a/ecc-eh-to-a.c
+++ b/ecc-eh-to-a.c
@@ -41,10 +41,11 @@
 mp_size_t
 ecc_eh_to_a_itch (const struct ecc_curve *ecc)
 {
-  /* Needs ecc->p.size + scratch for ecc_modq_inv */
-  return ECC_EH_TO_A_ITCH (ecc->p.size);
+  /* Needs 2*ecc->p.size + scratch for ecc_modq_inv */  
+  return ECC_EH_TO_A_ITCH (ecc->p.size, ecc->p.invert_itch);
 }
 
+
 /* Convert from homogeneous coordinates on the Edwards curve to affine
    coordinates. */
 void
@@ -63,7 +64,7 @@ ecc_eh_to_a (const struct ecc_curve *ecc,
 
   mp_limb_t cy;
 
-  /* Needs 2*size scratch */
+  /* Needs 2*size + scratch for the invert call. */
   ecc->p.invert (&ecc->p, izp, zp, tp + ecc->p.size);
 
   ecc_modp_mul (ecc, tp, xp, izp);
diff --git a/ecc-internal.h b/ecc-internal.h
index fe9acfb3..c07fdcfc 100644
--- a/ecc-internal.h
+++ b/ecc-internal.h
@@ -107,6 +107,7 @@ struct ecc_modulo
   unsigned short size;
   unsigned short B_size;
   unsigned short redc_size;
+  unsigned short invert_itch;
 
   const mp_limb_t *m;
   /* B^size mod m. Expected to have at least 32 leading zeros
@@ -265,7 +266,7 @@ curve25519_eh_to_x (mp_limb_t *xp, const mp_limb_t *p,
 /* Current scratch needs: */
 #define ECC_MOD_INV_ITCH(size) (2*(size))
 #define ECC_J_TO_A_ITCH(size) (5*(size))
-#define ECC_EH_TO_A_ITCH(size) (4*(size))
+#define ECC_EH_TO_A_ITCH(size, inv) (2*(size)+(inv))
 #define ECC_DUP_JJ_ITCH(size) (5*(size))
 #define ECC_DUP_EH_ITCH(size) (5*(size))
 #define ECC_ADD_JJA_ITCH(size) (6*(size))
diff --git a/testsuite/ecc-modinv-test.c b/testsuite/ecc-modinv-test.c
index 2faefc86..a685f59c 100644
--- a/testsuite/ecc-modinv-test.c
+++ b/testsuite/ecc-modinv-test.c
@@ -34,6 +34,13 @@ ref_modinv (mp_limb_t *rp, const mp_limb_t *ap, const mp_limb_t *mp, mp_size_t m
   return 1;
 }
 
+static int
+zero_p (const struct ecc_modulo *m, const mp_limb_t *xp)
+{
+  return mpn_zero_p (xp, m->size)
+    || mpn_cmp (xp, m->m, m->size) == 0;
+}
+
 #define MAX_ECC_SIZE (1 + 521 / GMP_NUMB_BITS)
 #define COUNT 500
 
@@ -41,20 +48,25 @@ static void
 test_modulo (gmp_randstate_t rands, const char *name,
 	     const struct ecc_modulo *m)
 {
-  mp_limb_t a[MAX_ECC_SIZE];
-  mp_limb_t ai[2*MAX_ECC_SIZE];
-  mp_limb_t ref[MAX_ECC_SIZE];
-  mp_limb_t scratch[ECC_MOD_INV_ITCH (MAX_ECC_SIZE)];
+  mp_limb_t *a;
+  mp_limb_t *ai;
+  mp_limb_t *ref;
+  mp_limb_t *scratch;
   unsigned j;
   mpz_t r;
 
   mpz_init (r);
-  
+
+  a = xalloc_limbs (m->size);
+  ai = xalloc_limbs (2*m->size);
+  ref = xalloc_limbs (m->size);;
+  scratch = xalloc_limbs (m->invert_itch);
+
   /* Check behaviour for zero input */
   mpn_zero (a, m->size);
   memset (ai, 17, m->size * sizeof(*ai));
   m->invert (m, ai, a, scratch);
-  if (!mpn_zero_p (ai, m->size))
+  if (!zero_p (m, ai))
     {
       fprintf (stderr, "%s->invert failed for zero input (bit size %u):\n",
 	       name, m->bit_size);
@@ -68,7 +80,7 @@ test_modulo (gmp_randstate_t rands, const char *name,
   /* Check behaviour for a = m */
   memset (ai, 17, m->size * sizeof(*ai));
   m->invert (m, ai, m->m, scratch);
-  if (!mpn_zero_p (ai, m->size))
+  if (!zero_p (m, ai))
     {
       fprintf (stderr, "%s->invert failed for a = p input (bit size %u):\n",
 	       name, m->bit_size);
@@ -112,6 +124,10 @@ test_modulo (gmp_randstate_t rands, const char *name,
 	  
     }
   mpz_clear (r);
+  free (a);
+  free (ai);
+  free (ref);
+  free (scratch);
 }
 
 void
-- 
GitLab