From ef1c033545c5ce87348d5e6cd4f774481d5a836f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niels=20M=C3=B6ller?= <nisse@lysator.liu.se>
Date: Sun, 21 Jan 2024 19:02:26 +0100
Subject: [PATCH] Change order of aes decryption subkeys, update C and x86_64
 implementations.

---
 ChangeLog                       |  8 ++++++++
 aes-decrypt-internal.c          | 18 ++++++++---------
 aes-invert-internal.c           | 34 +++++++++------------------------
 x86_64/aes-decrypt-internal.asm | 13 ++++++++-----
 x86_64/aesni/aes128-decrypt.asm | 20 +++++++++----------
 x86_64/aesni/aes192-decrypt.asm | 24 +++++++++++------------
 x86_64/aesni/aes256-decrypt.asm | 32 +++++++++++++++----------------
 7 files changed, 72 insertions(+), 77 deletions(-)

diff --git a/ChangeLog b/ChangeLog
index 1a891e48..1e665796 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -19,6 +19,14 @@
 
 2024-01-21  Niels Möller  <nisse@lysator.liu.se>
 
+	* aes-invert-internal.c (_nettle_aes_invert): Don't reorder the subkeys.
+	* aes-decrypt-internal.c (_nettle_aes_decrypt): Updated to process
+	subkeys starting from the end.
+	* x86_64/aes-decrypt-internal.asm: Likewise.
+	* x86_64/aesni/aes128-decrypt.asm: Likewise.
+	* x86_64/aesni/aes192-decrypt.asm: Likewise.
+	* x86_64/aesni/aes256-decrypt.asm: Likewise.
+
 	* powerpc64/machine.m4 (OPN_XXY, OPN_XXXY): New macros.
 	* powerpc64/p8/aes-encrypt-internal.asm: Use macros for repeated
 	instruction patterns.
diff --git a/aes-decrypt-internal.c b/aes-decrypt-internal.c
index 9e8cf34a..fcd7289a 100644
--- a/aes-decrypt-internal.c
+++ b/aes-decrypt-internal.c
@@ -65,12 +65,12 @@ _nettle_aes_decrypt(unsigned rounds, const uint32_t *keys,
       /* Get clear text, using little-endian byte order.
        * Also XOR with the first subkey. */
 
-      w0 = LE_READ_UINT32(src)      ^ keys[0];
-      w1 = LE_READ_UINT32(src + 4)  ^ keys[1];
-      w2 = LE_READ_UINT32(src + 8)  ^ keys[2];
-      w3 = LE_READ_UINT32(src + 12) ^ keys[3];
+      w0 = LE_READ_UINT32(src)      ^ keys[4*rounds];
+      w1 = LE_READ_UINT32(src + 4)  ^ keys[4*rounds + 1];
+      w2 = LE_READ_UINT32(src + 8)  ^ keys[4*rounds + 2];
+      w3 = LE_READ_UINT32(src + 12) ^ keys[4*rounds + 3];
 
-      for (i = 1; i < rounds; i++)
+      for (i = rounds - 1; i > 0; i--)
 	{
 	  t0 = AES_ROUND(T, w0, w3, w2, w1, keys[4*i]);
 	  t1 = AES_ROUND(T, w1, w0, w3, w2, keys[4*i + 1]);
@@ -88,10 +88,10 @@ _nettle_aes_decrypt(unsigned rounds, const uint32_t *keys,
 
       /* Final round */
 
-      t0 = AES_FINAL_ROUND(T, w0, w3, w2, w1, keys[4*i]);
-      t1 = AES_FINAL_ROUND(T, w1, w0, w3, w2, keys[4*i + 1]);
-      t2 = AES_FINAL_ROUND(T, w2, w1, w0, w3, keys[4*i + 2]);
-      t3 = AES_FINAL_ROUND(T, w3, w2, w1, w0, keys[4*i + 3]);
+      t0 = AES_FINAL_ROUND(T, w0, w3, w2, w1, keys[0]);
+      t1 = AES_FINAL_ROUND(T, w1, w0, w3, w2, keys[1]);
+      t2 = AES_FINAL_ROUND(T, w2, w1, w0, w3, keys[2]);
+      t3 = AES_FINAL_ROUND(T, w3, w2, w1, w0, keys[3]);
 
       LE_WRITE_UINT32(dst, t0);
       LE_WRITE_UINT32(dst + 4, t1);
diff --git a/aes-invert-internal.c b/aes-invert-internal.c
index a2faefa4..7364616c 100644
--- a/aes-invert-internal.c
+++ b/aes-invert-internal.c
@@ -111,9 +111,9 @@ static const uint32_t mtable[0x100] =
   0xbe805d9f,0xb58d5491,0xa89a4f83,0xa397468d,
 };
 
-#define MIX_COLUMN(T, key) do { \
+#define MIX_COLUMN(T, out, in) do {		\
     uint32_t _k, _nk, _t;	\
-    _k = (key);			\
+    _k = (in);			\
     _nk = T[_k & 0xff];		\
     _k >>= 8;			\
     _t = T[_k & 0xff];		\
@@ -124,7 +124,7 @@ static const uint32_t mtable[0x100] =
     _k >>= 8;			\
     _t = T[_k & 0xff];		\
     _nk ^= ROTL32(24, _t);	\
-    (key) = _nk;		\
+    (out) = _nk;		\
   } while(0)
   
 
@@ -136,29 +136,13 @@ _nettle_aes_invert(unsigned rounds, uint32_t *dst, const uint32_t *src)
 {
   unsigned i;
 
-  /* Reverse the order of subkeys, in groups of 4. */
-  /* FIXME: Instead of reordering the subkeys, change the access order
-     of aes_decrypt, since it's a separate function anyway? */
-  if (src == dst)
-    {
-      unsigned j, k;
+  /* Transform all subkeys but the first and last. */
+  for (i = 4; i < 4 * rounds; i++)
+    MIX_COLUMN (mtable, dst[i], src[i]);
 
-      for (i = 0, j = rounds * 4;
-	   i < j;
-	   i += 4, j -= 4)
-	for (k = 0; k<4; k++)
-	  SWAP(dst[i+k], dst[j+k]);
-    }
-  else
+  if (src != dst)
     {
-      unsigned k;
-
-      for (i = 0; i <= rounds * 4; i += 4)
-	for (k = 0; k < 4; k++)
-	  dst[i+k] = src[rounds * 4 - i + k];
+      dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
+      dst[i] = src[i]; dst[i+1] = src[i+1]; dst[i+2] = src[i+2]; dst[i+3] = src[i+3];
     }
-
-  /* Transform all subkeys but the first and last. */
-  for (i = 4; i < 4 * rounds; i++)
-    MIX_COLUMN (mtable, dst[i]);
 }
diff --git a/x86_64/aes-decrypt-internal.asm b/x86_64/aes-decrypt-internal.asm
index d3bedc25..ed753a2c 100644
--- a/x86_64/aes-decrypt-internal.asm
+++ b/x86_64/aes-decrypt-internal.asm
@@ -83,8 +83,10 @@ PROLOGUE(_nettle_aes_decrypt)
 	push	%r15	
 
 	subl	$1, XREG(ROUNDS)
-	push	ROUNDS		C Rounds at (%rsp) 
-	
+	push	ROUNDS			C Rounds stored at (%rsp)
+	shl	$4, XREG(ROUNDS)	C Zero-extends
+	lea	16(KEYS, ROUNDS), KEYS
+
 	mov	PARAM_TABLE, TABLE
 	mov	PARAM_LENGTH, LENGTH
 	shr	$4, LENGTH
@@ -92,11 +94,12 @@ PROLOGUE(_nettle_aes_decrypt)
 	mov	KEYS, KEY
 	
 	AES_LOAD(SA, SB, SC, SD, SRC, KEY)
-	add	$16, SRC	C Increment src pointer
 
+	add	$16, SRC	C  increment src pointer
 	movl	(%rsp), XREG(ROUNDS)
 
-	add	$16, KEY	C  point to next key
+	sub	$16, KEY	C  point to next key
+
 	ALIGN(16)
 .Lround_loop:
 	AES_ROUND(TABLE, SA,SD,SC,SB, TA, TMP)
@@ -113,7 +116,7 @@ PROLOGUE(_nettle_aes_decrypt)
 	xorl	8(KEY),SC
 	xorl	12(KEY),SD
 
-	add	$16, KEY	C  point to next key
+	sub	$16, KEY	C  point to next key
 	decl	XREG(ROUNDS)
 	jnz	.Lround_loop
 
diff --git a/x86_64/aesni/aes128-decrypt.asm b/x86_64/aesni/aes128-decrypt.asm
index 79111e47..b2009894 100644
--- a/x86_64/aesni/aes128-decrypt.asm
+++ b/x86_64/aesni/aes128-decrypt.asm
@@ -64,17 +64,17 @@ PROLOGUE(nettle_aes128_decrypt)
 	test	LENGTH, LENGTH
 	jz	.Lend
 
-	movups	(CTX), KEY0
-	movups	16(CTX), KEY1
-	movups	32(CTX), KEY2
-	movups	48(CTX), KEY3
-	movups	64(CTX), KEY4
+	movups	160(CTX), KEY0
+	movups	144(CTX), KEY1
+	movups	128(CTX), KEY2
+	movups	112(CTX), KEY3
+	movups	96(CTX), KEY4
 	movups	80(CTX), KEY5
-	movups	96(CTX), KEY6
-	movups	112(CTX), KEY7
-	movups	128(CTX), KEY8
-	movups	144(CTX), KEY9
-	movups	160(CTX), KEY10
+	movups	64(CTX), KEY6
+	movups	48(CTX), KEY7
+	movups	32(CTX), KEY8
+	movups	16(CTX), KEY9
+	movups	(CTX), KEY10
 	shr	LENGTH
 	jnc	.Lblock_loop
 
diff --git a/x86_64/aesni/aes192-decrypt.asm b/x86_64/aesni/aes192-decrypt.asm
index 399f89b6..24c17827 100644
--- a/x86_64/aesni/aes192-decrypt.asm
+++ b/x86_64/aesni/aes192-decrypt.asm
@@ -66,19 +66,19 @@ PROLOGUE(nettle_aes192_decrypt)
 	test	LENGTH, LENGTH
 	jz	.Lend
 
-	movups	(CTX), KEY0
-	movups	16(CTX), KEY1
-	movups	32(CTX), KEY2
-	movups	48(CTX), KEY3
-	movups	64(CTX), KEY4
-	movups	80(CTX), KEY5
+	movups	192(CTX), KEY0
+	movups	176(CTX), KEY1
+	movups	160(CTX), KEY2
+	movups	144(CTX), KEY3
+	movups	128(CTX), KEY4
+	movups	112(CTX), KEY5
 	movups	96(CTX), KEY6
-	movups	112(CTX), KEY7
-	movups	128(CTX), KEY8
-	movups	144(CTX), KEY9
-	movups	160(CTX), KEY10
-	movups	176(CTX), KEY11
-	movups	192(CTX), KEY12
+	movups	80(CTX), KEY7
+	movups	64(CTX), KEY8
+	movups	48(CTX), KEY9
+	movups	32(CTX), KEY10
+	movups	16(CTX), KEY11
+	movups	(CTX), KEY12
 	shr	LENGTH
 	jnc	.Lblock_loop
 
diff --git a/x86_64/aesni/aes256-decrypt.asm b/x86_64/aesni/aes256-decrypt.asm
index 0fc5ad2a..247655a3 100644
--- a/x86_64/aesni/aes256-decrypt.asm
+++ b/x86_64/aesni/aes256-decrypt.asm
@@ -67,20 +67,20 @@ PROLOGUE(nettle_aes256_decrypt)
 	test	LENGTH, LENGTH
 	jz	.Lend
 
-	movups	(CTX), KEY0_7
-	movups	16(CTX), KEY1
-	movups	32(CTX), KEY2
-	movups	48(CTX), KEY3
-	movups	64(CTX), KEY4
-	movups	80(CTX), KEY5
-	movups	96(CTX), KEY6
-	movups	128(CTX), KEY8
-	movups	144(CTX), KEY9
-	movups	160(CTX), KEY10
-	movups	176(CTX), KEY11
-	movups	192(CTX), KEY12
-	movups	208(CTX), KEY13
-	movups	224(CTX), KEY14
+	movups	224(CTX), KEY0_7
+	movups	208(CTX), KEY1
+	movups	192(CTX), KEY2
+	movups	176(CTX), KEY3
+	movups	160(CTX), KEY4
+	movups	144(CTX), KEY5
+	movups	128(CTX), KEY6
+	movups	96(CTX), KEY8
+	movups	80(CTX), KEY9
+	movups	64(CTX), KEY10
+	movups	48(CTX), KEY11
+	movups	32(CTX), KEY12
+	movups	16(CTX), KEY13
+	movups	(CTX), KEY14
 
 	shr	LENGTH
 	jnc	.Lblock_loop
@@ -95,7 +95,7 @@ PROLOGUE(nettle_aes256_decrypt)
 	aesdec	KEY5, X
 	aesdec	KEY6, X
 	aesdec	KEY0_7, X
-	movups	(CTX), KEY0_7
+	movups	224(CTX), KEY0_7
 	aesdec	KEY8, X
 	aesdec	KEY9, X
 	aesdec	KEY10, X
@@ -130,7 +130,7 @@ PROLOGUE(nettle_aes256_decrypt)
 	aesdec	KEY6, Y
 	aesdec	KEY0_7, X
 	aesdec	KEY0_7, Y
-	movups	(CTX), KEY0_7
+	movups	224(CTX), KEY0_7
 	aesdec	KEY8, X
 	aesdec	KEY8, Y
 	aesdec	KEY9, X
-- 
GitLab