From 8702a4c48ae2381dd2994e5053961044fca89676 Mon Sep 17 00:00:00 2001
From: Simo Sorce <simo@redhat.com>
Date: Tue, 19 Mar 2019 16:30:53 -0400
Subject: [PATCH] Inline ciphertext stealing

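Inline the ciphertext-stealing step into xts_encrypt_message and
xts_decrypt_message and remove the xts_steal helper. This avoids
copying and may be somewhat more readable, without the need for so
much explanation.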

Signed-off-by: Simo Sorce <simo@redhat.com>
---
 xts.c | 106 +++++++++++++++++++++++++---------------------------------
 1 file changed, 46 insertions(+), 60 deletions(-)

diff --git a/xts.c b/xts.c
index f4d2da4b..5c37655d 100644
--- a/xts.c
+++ b/xts.c
@@ -61,42 +61,6 @@ xts_shift(union nettle_block16 *dst,
   dst->b[0] ^= 0x87 & -carry;
 }
 
-/*
- * prev is the block to steal from
- * curr is the input block to the last step
- * length is the partial block length
- * dst is the destination partial block
- * src is the source partial block
- *
- * In the Encryption case:
- *   prev -> the output of the N-1 encryption step
- *   curr -> the input to the Nth step (will be encrypted as Cn-1)
- *   dst  -> the final Cn partial block
- *   src  -> the final Pn partial block
- *
- * In the decryption case:
- *   prev -> the output of the N-1 decryption step
- *   curr -> the input to the Nth step (will be decrypted as Pn-1)
- *   dst  -> the final Pn partial block
- *   src  -> the final Cn partial block
- */
-static void
-xts_steal(uint8_t *prev, uint8_t *curr,
-	  size_t length, uint8_t *dst, const uint8_t *src)
-{
-  /* copy the remaining in the current input block */
-  memcpy(curr, src, length);
-  /* fill the current block with the last blocksize - length
-   * bytes of the previous block */
-  memcpy(&curr[length], &prev[length], XTS_BLOCK_SIZE - length);
-
-  /* This must be last or inplace operations will break
-   * copy 'length' bytes of the previous block in the
-   * destination block, which is the final partial block
-   * returned to the caller */
-  memcpy(dst, prev, length);
-}
-
 static void
 check_length(size_t length, uint8_t *dst)
 {
@@ -125,26 +89,45 @@ xts_encrypt_message(const void *enc_ctx, const void *twk_ctx,
 
   /* the zeroth power of alpha is the initial ciphertext value itself, so we
    * skip shifting and do it at the end of each block operation instead */
-  for (;length >= XTS_BLOCK_SIZE;
+  for (;length >= 2 * XTS_BLOCK_SIZE || length == XTS_BLOCK_SIZE;
        length -= XTS_BLOCK_SIZE, src += XTS_BLOCK_SIZE, dst += XTS_BLOCK_SIZE)
     {
       memxor3(P.b, src, T.b, XTS_BLOCK_SIZE);	/* P -> PP */
       encf(enc_ctx, XTS_BLOCK_SIZE, dst, P.b);  /* CC */
       memxor(dst, T.b, XTS_BLOCK_SIZE);	        /* CC -> C */
 
-      /* shift T for next block */
-      xts_shift(&T, &T);
+      /* shift T for next block if any */
+      if (length > XTS_BLOCK_SIZE)
+          xts_shift(&T, &T);
     }
 
   /* if the last block is partial, handle via stealing */
   if (length)
     {
-      uint8_t *C = dst - XTS_BLOCK_SIZE;
-      /* C points to C(n-1) */
-      xts_steal(C, P.b, length, dst, src);
-      memxor(P.b, T.b, XTS_BLOCK_SIZE);	        /* P -> PP */
-      encf(enc_ctx, XTS_BLOCK_SIZE, C, P.b);    /* CC */
-      memxor(C, T.b, XTS_BLOCK_SIZE);
+      /* S Holds the real C(n-1) (Whole last block to steal from) */
+      union nettle_block16 S;
+
+      memxor3(P.b, src, T.b, XTS_BLOCK_SIZE);	/* P -> PP */
+      encf(enc_ctx, XTS_BLOCK_SIZE, S.b, P.b);  /* CC */
+      memxor(S.b, T.b, XTS_BLOCK_SIZE);	        /* CC -> S */
+
+      /* shift T for next block */
+      xts_shift(&T, &T);
+
+      length -= XTS_BLOCK_SIZE;
+      src += XTS_BLOCK_SIZE;
+
+      memxor3(P.b, src, T.b, length);           /* P |.. */
+      /* steal ciphertext to complete block */
+      memxor3(P.b + length, S.b + length, T.b + length,
+              XTS_BLOCK_SIZE - length);         /* ..| S_2 -> PP */
+
+      encf(enc_ctx, XTS_BLOCK_SIZE, dst, P.b);  /* CC */
+      memxor(dst, T.b, XTS_BLOCK_SIZE);         /* CC -> C(n-1) */
+
+      /* Do this after we read src so inplace operations do not break */
+      dst += XTS_BLOCK_SIZE;
+      memcpy(dst, S.b, length);                 /* S_1 -> C(n) */
     }
 }
 
@@ -161,42 +144,45 @@ xts_decrypt_message(const void *dec_ctx, const void *twk_ctx,
 
   encf(twk_ctx, XTS_BLOCK_SIZE, T.b, tweak);
 
-  for (;length >= XTS_BLOCK_SIZE;
+  for (;length >= 2 * XTS_BLOCK_SIZE || length == XTS_BLOCK_SIZE;
        length -= XTS_BLOCK_SIZE, src += XTS_BLOCK_SIZE, dst += XTS_BLOCK_SIZE)
     {
-      if (length > XTS_BLOCK_SIZE && length < 2 * XTS_BLOCK_SIZE)
-        break;                  /* must ciphersteal on last two blocks */
-
       memxor3(C.b, src, T.b, XTS_BLOCK_SIZE);	/* c -> CC */
       decf(dec_ctx, XTS_BLOCK_SIZE, dst, C.b);  /* PP */
       memxor(dst, T.b, XTS_BLOCK_SIZE);	        /* PP -> P */
 
-      xts_shift(&T, &T);
+      /* shift T for next block if any */
+      if (length > XTS_BLOCK_SIZE)
+          xts_shift(&T, &T);
     }
 
   /* if the last block is partial, handle via stealing */
   if (length)
     {
       union nettle_block16 T1;
-      uint8_t *P;
+      /* S Holds the real P(n) (with part of stolen ciphertext) */
+      union nettle_block16 S;
 
       /* we need the last T(n) and save the T(n-1) for later */
       xts_shift(&T1, &T);
 
-      P = dst;      /* use P(n-1) as temp storage for partial P(n) */
       memxor3(C.b, src, T1.b, XTS_BLOCK_SIZE);	/* C -> CC */
-      decf(dec_ctx, XTS_BLOCK_SIZE, P, C.b);    /* PP */
-      memxor(P, T1.b, XTS_BLOCK_SIZE);	        /* PP -> P */
+      decf(dec_ctx, XTS_BLOCK_SIZE, S.b, C.b);  /* PP */
+      memxor(S.b, T1.b, XTS_BLOCK_SIZE);	/* PP -> S */
 
       /* process next block (Pn-1) */
       length -= XTS_BLOCK_SIZE;
       src += XTS_BLOCK_SIZE;
-      dst += XTS_BLOCK_SIZE;
 
-      /* Fill P(n) and prepare C, P still pointing to P(n-1) */
-      xts_steal(P, C.b, length, dst, src);
-      memxor(C.b, T.b, XTS_BLOCK_SIZE);	        /* C -> CC */
-      decf(dec_ctx, XTS_BLOCK_SIZE, P, C.b);    /* PP */
-      memxor(P, T.b, XTS_BLOCK_SIZE);	        /* PP -> P */
+      /* Prepare C, S holds the real P(n) */
+      memxor3(C.b, src, T.b, length);	        /* C_1 |.. */
+      memxor3(C.b + length, S.b + length, T.b + length,
+              XTS_BLOCK_SIZE - length);         /* ..| S_2 -> CC */
+      decf(dec_ctx, XTS_BLOCK_SIZE, dst, C.b);  /* PP */
+      memxor(dst, T.b, XTS_BLOCK_SIZE);	        /* PP -> P(n-1) */
+
+      /* Do this after we read src so inplace operations do not break */
+      dst += XTS_BLOCK_SIZE;
+      memcpy(dst, S.b, length);                 /* S_1 -> P(n) */
     }
 }
-- 
GitLab