diff --git a/ChangeLog b/ChangeLog
index eb3b62562af5b931ae0f207738d31724353f4d09..e0b32058780f6cf3df4954a91d0504d015fc6e59 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,12 @@
+2013-05-03  Niels Möller  <nisse@lysator.liu.se>
+
+	* cast128.c: Adapt to new struct cast128_ctx.
+	(cast128_set_key): Rewrite, eliminating lots of conditions and
+	some false warnings.
+
+	* cast128.h (struct cast128_ctx): Separate the small 5-bit
+	rotation subkeys and the larger 32-bit masking subkeys.
+
 2013-05-02  Niels Möller  <nisse@lysator.liu.se>
 
 	* testsuite/testutils.c (mpz_combit): Renamed. Define only if not
diff --git a/cast128.c b/cast128.c
index ca7faac437bf884483260382fbe97f615901dfed..088bba0ca71314fa9db0e50e477234a9427c3ac0 100644
--- a/cast128.c
+++ b/cast128.c
@@ -34,6 +34,8 @@
 #endif
 
 #include <assert.h>
+#include <stdlib.h>
+#include <string.h>
 
 #include "cast128.h"
 #include "cast128_sboxes.h"
@@ -41,30 +43,39 @@
 #include "macros.h"
 
 #define CAST_SMALL_KEY 10
-#define CAST_SMALL_ROUNDS 12
-#define CAST_FULL_ROUNDS 16
+
+#define S1 cast_sbox1
+#define S2 cast_sbox2
+#define S3 cast_sbox3
+#define S4 cast_sbox4
+#define S5 cast_sbox5
+#define S6 cast_sbox6
+#define S7 cast_sbox7
+#define S8 cast_sbox8
 
 /* Macros to access 8-bit bytes out of a 32-bit word */
-#define U8a(x) ( (uint8_t) (x>>24) )
-#define U8b(x) ( (uint8_t) ((x>>16)&0xff) )
-#define U8c(x) ( (uint8_t) ((x>>8)&0xff) )
-#define U8d(x) ( (uint8_t) ((x)&0xff) )
+#define B0(x) ( (uint8_t) (x>>24) )
+#define B1(x) ( (uint8_t) ((x>>16)&0xff) )
+#define B2(x) ( (uint8_t) ((x>>8)&0xff) )
+#define B3(x) ( (uint8_t) ((x)&0xff) )
+
+/* NOTE: Depends on ROTL32 supporting a zero shift count. */
 
 /* CAST-128 uses three different round functions */
-#define F1(l, r, i) do {				\
-    t = ROTL32(ctx->keys[i+16], ctx->keys[i] + r);	\
-    l ^= ((cast_sbox1[U8a(t)] ^ cast_sbox2[U8b(t)])	\
-	  - cast_sbox3[U8c(t)]) + cast_sbox4[U8d(t)];	\
+#define F1(l, r, i) do {					\
+    t = ctx->Km[i] + r;						\
+    t = ROTL32(ctx->Kr[i], t);					\
+    l ^= ((S1[B0(t)] ^ S2[B1(t)]) - S3[B2(t)]) + S4[B3(t)];	\
   } while (0)
-#define F2(l, r, i) do {				\
-    t = ROTL32( ctx->keys[i+16], ctx->keys[i] ^ r);	\
-    l ^= ((cast_sbox1[U8a(t)] - cast_sbox2[U8b(t)])	\
-	  + cast_sbox3[U8c(t)]) ^ cast_sbox4[U8d(t)];	\
+#define F2(l, r, i) do {					\
+    t = ctx->Km[i] ^ r;						\
+    t = ROTL32( ctx->Kr[i], t);					\
+    l ^= ((S1[B0(t)] - S2[B1(t)]) + S3[B2(t)]) ^ S4[B3(t)];	\
   } while (0)
-#define F3(l, r, i) do { \
-    t = ROTL32(ctx->keys[i+16], ctx->keys[i] - r);	\
-    l ^= ((cast_sbox1[U8a(t)] + cast_sbox2[U8b(t)])	\
-	  ^ cast_sbox3[U8c(t)]) - cast_sbox4[U8d(t)];	\
+#define F3(l, r, i) do {					\
+    t = ctx->Km[i] - r;						\
+    t = ROTL32(ctx->Kr[i], t);					\
+    l ^= ((S1[B0(t)] + S2[B1(t)]) ^ S3[B2(t)]) - S4[B3(t)];	\
   } while (0)
 
 
@@ -97,7 +108,7 @@ cast128_encrypt(const struct cast128_ctx *ctx,
       F2(l, r, 10);
       F3(r, l, 11);
       /* Only do full 16 rounds if key length > 80 bits */
-      if (ctx->rounds > 12) {
+      if (ctx->rounds & 16) {
 	F1(l, r, 12);
 	F2(r, l, 13);
 	F3(l, r, 14);
@@ -106,8 +117,6 @@ cast128_encrypt(const struct cast128_ctx *ctx,
       /* Put l,r into outblock */
       WRITE_UINT32(dst, r);
       WRITE_UINT32(dst + 4, l);
-      /* Wipe clean */
-      t = l = r = 0;
     }
 }
 
@@ -129,7 +138,7 @@ cast128_decrypt(const struct cast128_ctx *ctx,
 
       /* Do the work */
       /* Only do full 16 rounds if key length > 80 bits */
-      if (ctx->rounds > 12) {
+      if (ctx->rounds & 16) {
 	F1(r, l, 15);
 	F3(l, r, 14);
 	F2(r, l, 13);
@@ -151,126 +160,112 @@ cast128_decrypt(const struct cast128_ctx *ctx,
       /* Put l,r into outblock */
       WRITE_UINT32(dst, l);
       WRITE_UINT32(dst + 4, r);
-
-      /* Wipe clean */
-      t = l = r = 0;
     }
 }
 
 /***** Key Schedule *****/
 
+#define SET_KM(i, k) ctx->Km[i] = (k)
+#define SET_KR(i, k) ctx->Kr[i] = (k) & 31
+
+#define EXPAND(set, full) do {						\
+    z0 = x0 ^ S5[B1(x3)] ^ S6[B3(x3)] ^ S7[B0(x3)] ^ S8[B2(x3)] ^ S7[B0(x2)]; \
+    z1 = x2 ^ S5[B0(z0)] ^ S6[B2(z0)] ^ S7[B1(z0)] ^ S8[B3(z0)] ^ S8[B2(x2)]; \
+    z2 = x3 ^ S5[B3(z1)] ^ S6[B2(z1)] ^ S7[B1(z1)] ^ S8[B0(z1)] ^ S5[B1(x2)]; \
+    z3 = x1 ^ S5[B2(z2)] ^ S6[B1(z2)] ^ S7[B3(z2)] ^ S8[B0(z2)] ^ S6[B3(x2)]; \
+    									\
+    set(0, S5[B0(z2)] ^ S6[B1(z2)] ^ S7[B3(z1)] ^ S8[B2(z1)] ^ S5[B2(z0)]); \
+    set(1, S5[B2(z2)] ^ S6[B3(z2)] ^ S7[B1(z1)] ^ S8[B0(z1)] ^ S6[B2(z1)]); \
+    set(2, S5[B0(z3)] ^ S6[B1(z3)] ^ S7[B3(z0)] ^ S8[B2(z0)] ^ S7[B1(z2)]); \
+    set(3, S5[B2(z3)] ^ S6[B3(z3)] ^ S7[B1(z0)] ^ S8[B0(z0)] ^ S8[B0(z3)]); \
+    									\
+    x0 = z2 ^ S5[B1(z1)] ^ S6[B3(z1)] ^ S7[B0(z1)] ^ S8[B2(z1)] ^ S7[B0(z0)]; \
+    x1 = z0 ^ S5[B0(x0)] ^ S6[B2(x0)] ^ S7[B1(x0)] ^ S8[B3(x0)] ^ S8[B2(z0)]; \
+    x2 = z1 ^ S5[B3(x1)] ^ S6[B2(x1)] ^ S7[B1(x1)] ^ S8[B0(x1)] ^ S5[B1(z0)]; \
+    x3 = z3 ^ S5[B2(x2)] ^ S6[B1(x2)] ^ S7[B3(x2)] ^ S8[B0(x2)] ^ S6[B3(z0)]; \
+    									\
+    set(4, S5[B3(x0)] ^ S6[B2(x0)] ^ S7[B0(x3)] ^ S8[B1(x3)] ^ S5[B0(x2)]); \
+    set(5, S5[B1(x0)] ^ S6[B0(x0)] ^ S7[B2(x3)] ^ S8[B3(x3)] ^ S6[B1(x3)]); \
+    set(6, S5[B3(x1)] ^ S6[B2(x1)] ^ S7[B0(x2)] ^ S8[B1(x2)] ^ S7[B3(x0)]); \
+    set(7, S5[B1(x1)] ^ S6[B0(x1)] ^ S7[B2(x2)] ^ S8[B3(x2)] ^ S8[B3(x1)]); \
+    									\
+    z0 = x0 ^ S5[B1(x3)] ^ S6[B3(x3)] ^ S7[B0(x3)] ^ S8[B2(x3)] ^ S7[B0(x2)]; \
+    z1 = x2 ^ S5[B0(z0)] ^ S6[B2(z0)] ^ S7[B1(z0)] ^ S8[B3(z0)] ^ S8[B2(x2)]; \
+    z2 = x3 ^ S5[B3(z1)] ^ S6[B2(z1)] ^ S7[B1(z1)] ^ S8[B0(z1)] ^ S5[B1(x2)]; \
+    z3 = x1 ^ S5[B2(z2)] ^ S6[B1(z2)] ^ S7[B3(z2)] ^ S8[B0(z2)] ^ S6[B3(x2)]; \
+    									\
+    set(8,  S5[B3(z0)] ^ S6[B2(z0)] ^ S7[B0(z3)] ^ S8[B1(z3)] ^ S5[B1(z2)]); \
+    set(9,  S5[B1(z0)] ^ S6[B0(z0)] ^ S7[B2(z3)] ^ S8[B3(z3)] ^ S6[B0(z3)]); \
+    set(10, S5[B3(z1)] ^ S6[B2(z1)] ^ S7[B0(z2)] ^ S8[B1(z2)] ^ S7[B2(z0)]); \
+    set(11, S5[B1(z1)] ^ S6[B0(z1)] ^ S7[B2(z2)] ^ S8[B3(z2)] ^ S8[B2(z1)]); \
+									\
+    x0 = z2 ^ S5[B1(z1)] ^ S6[B3(z1)] ^ S7[B0(z1)] ^ S8[B2(z1)] ^ S7[B0(z0)]; \
+    x1 = z0 ^ S5[B0(x0)] ^ S6[B2(x0)] ^ S7[B1(x0)] ^ S8[B3(x0)] ^ S8[B2(z0)]; \
+    x2 = z1 ^ S5[B3(x1)] ^ S6[B2(x1)] ^ S7[B1(x1)] ^ S8[B0(x1)] ^ S5[B1(z0)]; \
+    x3 = z3 ^ S5[B2(x2)] ^ S6[B1(x2)] ^ S7[B3(x2)] ^ S8[B0(x2)] ^ S6[B3(z0)]; \
+    if (full)								\
+      {									\
+	set(12, S5[B0(x2)] ^ S6[B1(x2)] ^ S7[B3(x1)] ^ S8[B2(x1)] ^ S5[B3(x0)]); \
+	set(13, S5[B2(x2)] ^ S6[B3(x2)] ^ S7[B1(x1)] ^ S8[B0(x1)] ^ S6[B3(x1)]); \
+	set(14, S5[B0(x3)] ^ S6[B1(x3)] ^ S7[B3(x0)] ^ S8[B2(x0)] ^ S7[B0(x2)]); \
+	set(15, S5[B2(x3)] ^ S6[B3(x3)] ^ S7[B1(x0)] ^ S8[B0(x0)] ^ S8[B1(x3)]); \
+      }									\
+} while (0)
+
 void
 cast128_set_key(struct cast128_ctx *ctx,
-		size_t keybytes, const uint8_t *rawkey)
+		size_t length, const uint8_t *key)
 {
-  uint32_t t[4], z[4], x[4];
-  unsigned i;
-
-  /* Set number of rounds to 12 or 16, depending on key length */
-  ctx->rounds = (keybytes <= CAST_SMALL_KEY)
-    ? CAST_SMALL_ROUNDS : CAST_FULL_ROUNDS;
-
-  /* Copy key to workspace x */
-  for (i = 0; i < 4; i++) {
-    x[i] = 0;
-    if ((i*4+0) < keybytes) x[i] = (uint32_t)rawkey[i*4+0] << 24;
-    if ((i*4+1) < keybytes) x[i] |= (uint32_t)rawkey[i*4+1] << 16;
-    if ((i*4+2) < keybytes) x[i] |= (uint32_t)rawkey[i*4+2] << 8;
-    if ((i*4+3) < keybytes) x[i] |= (uint32_t)rawkey[i*4+3];
-  }
-  /* FIXME: For the shorter key sizes, the last 4 subkeys are not
-     used, and need not be generated, nor stored. */
-  /* Generate 32 subkeys, four at a time */
-  for (i = 0; i < 32; i+=4) {
-    switch (i & 4) {
+  uint32_t x0, x1, x2, x3, z0, z1, z2, z3;
+  uint32_t w;
+  int full;
+
+  assert (length >= CAST128_MIN_KEY_SIZE);
+  assert (length <= CAST128_MAX_KEY_SIZE);
+
+  full = (length > CAST_SMALL_KEY);
+
+  x0 = READ_UINT32 (key);
+
+  /* Read final work, possibly zero-padded. */
+  switch (length & 3)
+    {
     case 0:
-      t[0] = z[0] = x[0] ^ cast_sbox5[U8b(x[3])]
-	^ cast_sbox6[U8d(x[3])] ^ cast_sbox7[U8a(x[3])]
-	^ cast_sbox8[U8c(x[3])] ^ cast_sbox7[U8a(x[2])];
-      t[1] = z[1] = x[2] ^ cast_sbox5[U8a(z[0])]
-	^ cast_sbox6[U8c(z[0])] ^ cast_sbox7[U8b(z[0])]
-	^ cast_sbox8[U8d(z[0])] ^ cast_sbox8[U8c(x[2])];
-      t[2] = z[2] = x[3] ^ cast_sbox5[U8d(z[1])]
-	^ cast_sbox6[U8c(z[1])] ^ cast_sbox7[U8b(z[1])]
-	^ cast_sbox8[U8a(z[1])] ^ cast_sbox5[U8b(x[2])];
-      t[3] = z[3] = x[1] ^ cast_sbox5[U8c(z[2])] ^
-	cast_sbox6[U8b(z[2])] ^ cast_sbox7[U8d(z[2])]
-	^ cast_sbox8[U8a(z[2])] ^ cast_sbox6[U8d(x[2])];
+      w = READ_UINT32 (key + length - 4);
       break;
-    case 4:
-      t[0] = x[0] = z[2] ^ cast_sbox5[U8b(z[1])]
-	^ cast_sbox6[U8d(z[1])] ^ cast_sbox7[U8a(z[1])]
-	^ cast_sbox8[U8c(z[1])] ^ cast_sbox7[U8a(z[0])];
-      t[1] = x[1] = z[0] ^ cast_sbox5[U8a(x[0])]
-	^ cast_sbox6[U8c(x[0])] ^ cast_sbox7[U8b(x[0])]
-	^ cast_sbox8[U8d(x[0])] ^ cast_sbox8[U8c(z[0])];
-      t[2] = x[2] = z[1] ^ cast_sbox5[U8d(x[1])]
-	^ cast_sbox6[U8c(x[1])] ^ cast_sbox7[U8b(x[1])]
-	^ cast_sbox8[U8a(x[1])] ^ cast_sbox5[U8b(z[0])];
-      t[3] = x[3] = z[3] ^ cast_sbox5[U8c(x[2])]
-	^ cast_sbox6[U8b(x[2])] ^ cast_sbox7[U8d(x[2])]
-	^ cast_sbox8[U8a(x[2])] ^ cast_sbox6[U8d(z[0])];
+    case 3:
+      w = READ_UINT24 (key + length - 3) << 8;
       break;
-    }
-    switch (i & 12) {
-    case 0:
-    case 12:
-      ctx->keys[i+0] = cast_sbox5[U8a(t[2])] ^ cast_sbox6[U8b(t[2])]
-	^ cast_sbox7[U8d(t[1])] ^ cast_sbox8[U8c(t[1])];
-      ctx->keys[i+1] = cast_sbox5[U8c(t[2])] ^ cast_sbox6[U8d(t[2])]
-	^ cast_sbox7[U8b(t[1])] ^ cast_sbox8[U8a(t[1])];
-      ctx->keys[i+2] = cast_sbox5[U8a(t[3])] ^ cast_sbox6[U8b(t[3])]
-	^ cast_sbox7[U8d(t[0])] ^ cast_sbox8[U8c(t[0])];
-      ctx->keys[i+3] = cast_sbox5[U8c(t[3])] ^ cast_sbox6[U8d(t[3])]
-	^ cast_sbox7[U8b(t[0])] ^ cast_sbox8[U8a(t[0])];
+    case 2:
+      w = READ_UINT16 (key + length - 2) << 16;
       break;
-    case 4:
-    case 8:
-      ctx->keys[i+0] = cast_sbox5[U8d(t[0])] ^ cast_sbox6[U8c(t[0])]
-	^ cast_sbox7[U8a(t[3])] ^ cast_sbox8[U8b(t[3])];
-      ctx->keys[i+1] = cast_sbox5[U8b(t[0])] ^ cast_sbox6[U8a(t[0])]
-	^ cast_sbox7[U8c(t[3])] ^ cast_sbox8[U8d(t[3])];
-      ctx->keys[i+2] = cast_sbox5[U8d(t[1])] ^ cast_sbox6[U8c(t[1])]
-	^ cast_sbox7[U8a(t[2])] ^ cast_sbox8[U8b(t[2])];
-      ctx->keys[i+3] = cast_sbox5[U8b(t[1])] ^ cast_sbox6[U8a(t[1])]
-	^ cast_sbox7[U8c(t[2])] ^ cast_sbox8[U8d(t[2])];
+    case 1:
+      w = (uint32_t) key[length - 1] << 24;
       break;
     }
-    switch (i & 12) {
-    case 0:
-      ctx->keys[i+0] ^= cast_sbox5[U8c(z[0])];
-      ctx->keys[i+1] ^= cast_sbox6[U8c(z[1])];
-      ctx->keys[i+2] ^= cast_sbox7[U8b(z[2])];
-      ctx->keys[i+3] ^= cast_sbox8[U8a(z[3])];
-      break;
-    case 4:
-      ctx->keys[i+0] ^= cast_sbox5[U8a(x[2])];
-      ctx->keys[i+1] ^= cast_sbox6[U8b(x[3])];
-      ctx->keys[i+2] ^= cast_sbox7[U8d(x[0])];
-      ctx->keys[i+3] ^= cast_sbox8[U8d(x[1])];
-      break;
-    case 8:
-      ctx->keys[i+0] ^= cast_sbox5[U8b(z[2])];
-      ctx->keys[i+1] ^= cast_sbox6[U8a(z[3])];
-      ctx->keys[i+2] ^= cast_sbox7[U8c(z[0])];
-      ctx->keys[i+3] ^= cast_sbox8[U8c(z[1])];
-      break;
-    case 12:
-      ctx->keys[i+0] ^= cast_sbox5[U8d(x[0])];
-      ctx->keys[i+1] ^= cast_sbox6[U8d(x[1])];
-      ctx->keys[i+2] ^= cast_sbox7[U8a(x[2])];
-      ctx->keys[i+3] ^= cast_sbox8[U8b(x[3])];
-      break;
+
+  if (length <= 8)
+    {
+      x1 = w;
+      x2 = x3 = 0;
     }
-    if (i >= 16) {
-      ctx->keys[i+0] &= 31;
-      ctx->keys[i+1] &= 31;
-      ctx->keys[i+2] &= 31;
-      ctx->keys[i+3] &= 31;
+  else
+    {
+      x1 = READ_UINT32 (key + 4);
+      if (length <= 12)
+	{
+	  x2 = w;
+	  x3 = 0;
+	}
+      else
+	{
+	  x2 = READ_UINT32 (key + 8);
+	  x3 = w;
+	}
     }
-  }
-  /* Wipe clean */
-  for (i = 0; i < 4; i++) {
-    t[i] = x[i] = z[i] = 0;
-  }
+
+  EXPAND(SET_KM, full);
+  EXPAND(SET_KR, full);
+
+  ctx->rounds = full ? 16 : 12;
 }
diff --git a/cast128.h b/cast128.h
index 63636d2191eabaa8df8df2ef6fd5ab7582057714..abdf71dff83550b57ac8d963025a4e746043632d 100644
--- a/cast128.h
+++ b/cast128.h
@@ -53,8 +53,10 @@ extern "C" {
 
 struct cast128_ctx
 {
-  uint32_t keys[32];  /* Key, after expansion */
-  unsigned rounds;    /* Number of rounds to use, 12 or 16 */
+  unsigned rounds;  /* Number of rounds to use, 12 or 16 */
+  /* Expanded key, rotations (5 bits only) and 32-bit masks. */
+  unsigned char Kr[16];
+  uint32_t Km[16];
 };
 
 void