From a7dada790fd758dd2df2d43eff2059960d3397ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Niels=20M=C3=B6ller?= <nisse@lysator.liu.se>
Date: Fri, 29 Mar 2019 07:32:42 +0100
Subject: [PATCH] Redefine struct aes_ctx as a union of key-size specific
 contexts.

---
 ChangeLog             |  9 +++++
 aes-decrypt.c         | 17 +++++++--
 aes-encrypt.c         | 17 +++++++--
 aes-set-decrypt-key.c | 21 +++++++++--
 aes-set-encrypt-key.c | 37 +++++++++----------
 aes.h                 | 84 ++++++++++++++++++++++---------------------
 6 files changed, 115 insertions(+), 70 deletions(-)

diff --git a/ChangeLog b/ChangeLog
index bb140378..ed7bb337 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,12 @@
+2019-03-29  Niels Möller  <nisse@lysator.liu.se>
+
+	* aes.h (struct aes_ctx): Redefine using a union of key-size
+	specific contexts.
+	* aes-decrypt.c (aes_decrypt): Use switch on key_size.
+	* aes-encrypt.c (aes_encrypt): Likewise.
+	* aes-set-decrypt-key.c (aes_invert_key): Likewise.
+	* aes-set-encrypt-key.c (aes_set_encrypt_key): Likewise.
+
 2019-03-27  Niels Möller  <nisse@lysator.liu.se>
 
 	* xts.c (xts_shift): Arrange with a single write to u64[1].
diff --git a/aes-decrypt.c b/aes-decrypt.c
index a0897f51..cb9b0cea 100644
--- a/aes-decrypt.c
+++ b/aes-decrypt.c
@@ -36,6 +36,7 @@
 #endif
 
 #include <assert.h>
+#include <stdlib.h>
 
 #include "aes-internal.h"
 
@@ -349,9 +350,19 @@ aes_decrypt(const struct aes_ctx *ctx,
 	    size_t length, uint8_t *dst,
 	    const uint8_t *src)
 {
-  assert(!(length % AES_BLOCK_SIZE) );
-  _aes_decrypt(ctx->rounds, ctx->keys, &_aes_decrypt_table,
-	       length, dst, src);
+  switch (ctx->key_size)
+    {
+    default: abort();
+    case AES128_KEY_SIZE:
+      aes128_decrypt(&ctx->u.ctx128, length, dst, src);
+      break;
+    case AES192_KEY_SIZE:
+      aes192_decrypt(&ctx->u.ctx192, length, dst, src);
+      break;
+    case AES256_KEY_SIZE:
+      aes256_decrypt(&ctx->u.ctx256, length, dst, src);
+      break;
+    }
 }
 
 void
diff --git a/aes-encrypt.c b/aes-encrypt.c
index f962924a..d1496e54 100644
--- a/aes-encrypt.c
+++ b/aes-encrypt.c
@@ -36,6 +36,7 @@
 #endif
 
 #include <assert.h>
+#include <stdlib.h>
 
 #include "aes-internal.h"
 
@@ -47,9 +48,19 @@ aes_encrypt(const struct aes_ctx *ctx,
 	    size_t length, uint8_t *dst,
 	    const uint8_t *src)
 {
-  assert(!(length % AES_BLOCK_SIZE) );
-  _aes_encrypt(ctx->rounds, ctx->keys, &_aes_encrypt_table,
-	       length, dst, src);
+  switch (ctx->key_size)
+    {
+    default: abort();
+    case AES128_KEY_SIZE:
+      aes128_encrypt(&ctx->u.ctx128, length, dst, src);
+      break;
+    case AES192_KEY_SIZE:
+      aes192_encrypt(&ctx->u.ctx192, length, dst, src);
+      break;
+    case AES256_KEY_SIZE:
+      aes256_encrypt(&ctx->u.ctx256, length, dst, src);
+      break;
+    }
 }
 
 void
diff --git a/aes-set-decrypt-key.c b/aes-set-decrypt-key.c
index 20214eab..257ded6d 100644
--- a/aes-set-decrypt-key.c
+++ b/aes-set-decrypt-key.c
@@ -36,17 +36,32 @@
 # include "config.h"
 #endif
 
+#include <stdlib.h>
+
 /* This file implements and uses deprecated functions */
 #define _NETTLE_ATTRIBUTE_DEPRECATED
 
-#include "aes-internal.h"
+#include "aes.h"
 
 void
 aes_invert_key(struct aes_ctx *dst,
 	       const struct aes_ctx *src)
 {
-  _aes_invert (src->rounds, dst->keys, src->keys);
-  dst->rounds = src->rounds;
+  switch (src->key_size)
+    {
+    default: abort();
+    case AES128_KEY_SIZE:
+      aes128_invert_key(&dst->u.ctx128, &src->u.ctx128);
+      break;
+    case AES192_KEY_SIZE:
+      aes192_invert_key(&dst->u.ctx192, &src->u.ctx192);
+      break;
+    case AES256_KEY_SIZE:
+      aes256_invert_key(&dst->u.ctx256, &src->u.ctx256);
+      break;
+    }
+
+  dst->key_size = src->key_size;
 }
 
 void
diff --git a/aes-set-encrypt-key.c b/aes-set-encrypt-key.c
index 8de474dc..ed417fcd 100644
--- a/aes-set-encrypt-key.c
+++ b/aes-set-encrypt-key.c
@@ -36,32 +36,27 @@
 # include "config.h"
 #endif
 
-#include <assert.h>
 #include <stdlib.h>
 
-#include "aes-internal.h"
+#include "aes.h"
 
 void
 aes_set_encrypt_key(struct aes_ctx *ctx,
-		    size_t keysize, const uint8_t *key)
+		    size_t key_size, const uint8_t *key)
 {
-  unsigned nk, nr;
-
-  assert(keysize >= AES_MIN_KEY_SIZE);
-  assert(keysize <= AES_MAX_KEY_SIZE);
+  switch (key_size)
+    {
+    default: abort();
+    case AES128_KEY_SIZE:
+      aes128_set_encrypt_key(&ctx->u.ctx128, key);
+      break;
+    case AES192_KEY_SIZE:
+      aes192_set_encrypt_key(&ctx->u.ctx192, key);
+      break;
+    case AES256_KEY_SIZE:
+      aes256_set_encrypt_key(&ctx->u.ctx256, key);
+      break;
+    }
   
-  /* Truncate keysizes to the valid key sizes provided by Rijndael */
-  if (keysize == AES256_KEY_SIZE) {
-    nk = 8;
-    nr = _AES256_ROUNDS;
-  } else if (keysize >= AES192_KEY_SIZE) {
-    nk = 6;
-    nr = _AES192_ROUNDS;
-  } else { /* must be 16 or more */
-    nk = 4;
-    nr = _AES128_ROUNDS;
-  }
-
-  ctx->rounds = nr;
-  _aes_set_key (nr, nk, ctx->keys, key);
+  ctx->key_size = key_size;
 }
diff --git a/aes.h b/aes.h
index 333ec52f..25cb4ed0 100644
--- a/aes.h
+++ b/aes.h
@@ -71,46 +71,6 @@ extern "C" {
 #define _AES192_ROUNDS 12
 #define _AES256_ROUNDS 14
 
-/* Variable key size between 128 and 256 bits. But the only valid
- * values are 16 (128 bits), 24 (192 bits) and 32 (256 bits). */
-#define AES_MIN_KEY_SIZE AES128_KEY_SIZE
-#define AES_MAX_KEY_SIZE AES256_KEY_SIZE
-
-/* The older nettle-2.7 AES interface is deprecated, please migrate to
-   the newer interface where each algorithm has a fixed key size. */
-
-#define AES_KEY_SIZE 32
-
-struct aes_ctx
-{
-  unsigned rounds;  /* number of rounds to use for our key size */
-  uint32_t keys[4*(_AES256_ROUNDS + 1)];  /* maximum size of key schedule */
-};
-
-void
-aes_set_encrypt_key(struct aes_ctx *ctx,
-		    size_t length, const uint8_t *key)
-  _NETTLE_ATTRIBUTE_DEPRECATED;
-
-void
-aes_set_decrypt_key(struct aes_ctx *ctx,
-		   size_t length, const uint8_t *key)
-  _NETTLE_ATTRIBUTE_DEPRECATED;
-
-void
-aes_invert_key(struct aes_ctx *dst,
-	       const struct aes_ctx *src)
-  _NETTLE_ATTRIBUTE_DEPRECATED;
-
-void
-aes_encrypt(const struct aes_ctx *ctx,
-	    size_t length, uint8_t *dst,
-	    const uint8_t *src) _NETTLE_ATTRIBUTE_DEPRECATED;
-void
-aes_decrypt(const struct aes_ctx *ctx,
-	    size_t length, uint8_t *dst,
-	    const uint8_t *src) _NETTLE_ATTRIBUTE_DEPRECATED;
-
 struct aes128_ctx
 {
   uint32_t keys[4 * (_AES128_ROUNDS + 1)];
@@ -174,6 +134,50 @@ aes256_decrypt(const struct aes256_ctx *ctx,
 	       size_t length, uint8_t *dst,
 	       const uint8_t *src);
 
+/* The older nettle-2.7 AES interface is deprecated, please migrate to
+   the newer interface where each algorithm has a fixed key size. */
+
+/* Variable key size between 128 and 256 bits. But the only valid
+ * values are 16 (128 bits), 24 (192 bits) and 32 (256 bits). */
+#define AES_MIN_KEY_SIZE AES128_KEY_SIZE
+#define AES_MAX_KEY_SIZE AES256_KEY_SIZE
+
+#define AES_KEY_SIZE 32
+
+struct aes_ctx
+{
+  unsigned key_size;  /* In octets */
+  union {
+    struct aes128_ctx ctx128;
+    struct aes192_ctx ctx192;
+    struct aes256_ctx ctx256;
+  } u;
+};
+
+void
+aes_set_encrypt_key(struct aes_ctx *ctx,
+		    size_t length, const uint8_t *key)
+  _NETTLE_ATTRIBUTE_DEPRECATED;
+
+void
+aes_set_decrypt_key(struct aes_ctx *ctx,
+		   size_t length, const uint8_t *key)
+  _NETTLE_ATTRIBUTE_DEPRECATED;
+
+void
+aes_invert_key(struct aes_ctx *dst,
+	       const struct aes_ctx *src)
+  _NETTLE_ATTRIBUTE_DEPRECATED;
+
+void
+aes_encrypt(const struct aes_ctx *ctx,
+	    size_t length, uint8_t *dst,
+	    const uint8_t *src) _NETTLE_ATTRIBUTE_DEPRECATED;
+void
+aes_decrypt(const struct aes_ctx *ctx,
+	    size_t length, uint8_t *dst,
+	    const uint8_t *src) _NETTLE_ATTRIBUTE_DEPRECATED;
+
 #ifdef __cplusplus
 }
 #endif
-- 
GitLab