Skip to content
Snippets Groups Projects
aesdata.c 5.68 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include <assert.h>
    #include <inttypes.h>
    #include <stdlib.h>
    #include <stdio.h>
    #include <string.h>
    
    /* Output layout for the byte tables.  The enabled branch prints each
     * entry as two hex digits, 8 entries per line; switching the condition
     * to 0 selects decimal output with 16 entries per line instead. */
    #if 1
    # define BYTE_FORMAT "0x%02x"
    # define BYTE_COLUMNS 8
    #else
    # define BYTE_FORMAT "%3d"
    # define BYTE_COLUMNS 0x10
    #endif
    
    /* Output layout for the 32-bit word tables: eight hex digits per entry,
     * 4 entries per line. */
    #define WORD_FORMAT "0x%08x"
    #define WORD_COLUMNS 4
    
    /* Substitution box and its inverse, filled in by compute_sbox. */
    uint8_t sbox[0x100];
    uint8_t isbox[0x100];
    
    /* Logarithm table and exponentiation ("inverse log") table, filled in
     * by compute_log.
     * NOTE(review): log is also the name of a <math.h> function, which the
     * C standard reserves for identifiers with external linkage; renaming
     * this table (or making it static) looks advisable -- confirm against
     * the compilers in use. */
    uint8_t log[0x100];
    uint8_t ilog[0x100];
    
    /* Word tables filled in by compute_dtable and compute_itable.  The
     * first index selects one of four byte rotations of the same word, the
     * second index is the input byte. */
    uint32_t dtable[4][0x100];
    uint32_t itable[4][0x100];
    
    /* Doubles a byte-sized field element: shifts X left by one bit and, when
     * the result spills into bit 8, reduces it by xoring with 0x11b.  The
     * argument must be below 0x100; so is the result. */
    static unsigned
    xtime(unsigned x)
    {
      unsigned doubled;

      assert (x < 0x100);

      doubled = x << 1;
      doubled = (doubled & 0x100) ? (doubled ^ 0x11b) : doubled;

      assert (doubled < 0x100);

      return doubled;
    }
    
    /* Computes the exponentiation and logarithm tables for GF(2^8), to the
     * base x+1 (0x03). The unit element is 1 (0x01).
     *
     * After the call, ilog[i] holds the i:th power of 0x03 and log[] maps
     * each of those powers back to its exponent. */
    static void
    compute_log(void)
    {
      unsigned power = 1;
      unsigned exponent;

      memset(log, 0, 0x100);

      /* Multiplying by 0x03 amounts to xoring an element with its double. */
      for (exponent = 0; exponent < 0x100; exponent++)
        {
          ilog[exponent] = power;
          log[power] = exponent;
          power ^= xtime(power);
        }

      /* Zero has no logarithm; store 0 as a placeholder. */
      log[0] = 0;
      /* The last loop iteration leaves log[1] = 0xff, which is correct,
       * but log[1] = 0 is nicer. */
      log[1] = 0;
    }
    
    /* Multiplies two field elements using the tables built by compute_log:
     * for non-zero operands the product is the antilog of the sum of the
     * logarithms, reduced mod 255.  A zero operand yields zero.  Both
     * arguments index log[], so they must be below 0x100. */
    static unsigned
    mult(unsigned a, unsigned b)
    {
      if (a == 0 || b == 0)
        return 0;

      return ilog[(log[a] + log[b]) % 255];
    }
    
    /* Returns the multiplicative inverse of X, looked up as the antilog of
     * 0xff minus the logarithm of X.  Zero has no inverse and is mapped to
     * zero.  X indexes log[], so it must be below 0x100. */
    static unsigned
    invert(unsigned x)
    {
      if (!x)
        return 0;

      return ilog[0xff - log[x]];
    }
    
    /* Applies the affine step used to build the sbox: X is xored with four
     * shifted copies of itself and the constant 0x63, and the result is
     * truncated to eight bits.  For X below 0x100 each pair
     * (x << k) ^ (x >> (8 - k)), once truncated, is X rotated left by k
     * bits within a byte. */
    static unsigned
    affine(unsigned x)
    {
      unsigned acc = 0x63 ^ x;
      unsigned k;

      for (k = 1; k <= 4; k++)
        acc ^= (x << k) ^ (x >> (8 - k));

      return acc & 0xff;
    }
         
    /* Fills in sbox[] and isbox[].  Each sbox entry is the affine transform
     * of the field inverse of its index; isbox[] records the reverse
     * mapping.  Relies on the log tables, so compute_log must have run. */
    static void
    compute_sbox(void)
    {
      unsigned in;

      for (in = 0; in < 0x100; in++)
        {
          /* affine() already truncates to eight bits, so OUT equals the
           * value stored in sbox[in]. */
          unsigned out = affine(invert(in));

          sbox[in] = out;
          isbox[out] = in;
        }
    }
    
    /* Generate little endian tables, i.e. the first row of the AES state
     * arrays occupies the least significant byte of the words.
     *
     * Each sbox value is multiplied with the column of coefficients of the
     * polynomial 03 x^3 + x^2 + x + 02; the 02 product lands in the least
     * significant byte and the 03 product in the most significant one.
     * dtable[j] holds the same words rotated left by 8*j bits.  Reads
     * sbox[], so compute_sbox must have run. */
    static void
    compute_dtable(void)
    {
      unsigned i;

      for (i = 0; i < 0x100; i++)
        {
          unsigned s = sbox[i];
          unsigned s2 = xtime(s);       /* 02 * s */
          unsigned s3 = s ^ s2;         /* 03 * s */
          uint32_t word = (s3 << 24) | (s << 16) | (s << 8) | s2;
          unsigned rot;

          for (rot = 0; rot < 4; rot++)
            {
              dtable[rot][i] = word;
              word = (word << 8) | (word >> 24);
            }
        }
    }
    
    /* The inverse sbox values are multiplied with the column of coefficients
     * of the inverse polynomial 0b x^3 + 0d x^2 + 09 x + 0e; the 0e product
     * lands in the least significant byte and the 0b product in the most
     * significant one.  itable[j] holds the same words rotated left by 8*j
     * bits.  Reads isbox[] and the log tables, so compute_log and
     * compute_sbox must have run. */
    static void
    compute_itable(void)
    {
      unsigned i;

      for (i = 0; i < 0x100; i++)
        {
          unsigned s = isbox[i];
          uint32_t word = mult(s, 0xe);
          unsigned rot;

          word |= mult(s, 0x9) << 8;
          word |= mult(s, 0xd) << 16;
          word |= mult(s, 0xb) << 24;

          for (rot = 0; rot < 4; rot++)
            {
              itable[rot][i] = word;
              word = (word << 8) | (word >> 24);
            }
        }
    }
    
    /* Prints TABLE (0x100 bytes) to stdout as a C initializer for an array
     * called NAME, with BYTE_COLUMNS entries per line, each formatted with
     * BYTE_FORMAT and followed by a comma. */
    static void
    display_byte_table(const char *name, uint8_t *table)
    {
      unsigned row;

      printf("uint8_t %s[0x100] =\n{", name);

      for (row = 0; row < 0x100; row += BYTE_COLUMNS)
        {
          const uint8_t *entry = table + row;
          const uint8_t *row_end = entry + BYTE_COLUMNS;

          fputs("\n  ", stdout);
          while (entry < row_end)
            printf(BYTE_FORMAT ",", *entry++);
        }

      fputs("\n};\n\n", stdout);
    }
    
    /* Prints the four 0x100-entry word tables in TABLE to stdout as a C
     * initializer for a two-dimensional array called NAME, with
     * WORD_COLUMNS entries per line, each formatted with WORD_FORMAT and
     * followed by a comma. */
    static void
    display_table(const char *name, uint32_t table[][0x100])
    {
      unsigned plane;

      printf("uint32_t %s[4][0x100] =\n{\n  ", name);

      for (plane = 0; plane < 4; plane++)
        {
          const uint32_t *words = table[plane];
          unsigned row;

          fputs("{ ", stdout);
          for (row = 0; row < 0x100; row += WORD_COLUMNS)
            {
              unsigned col;

              fputs("\n    ", stdout);
              for (col = 0; col < WORD_COLUMNS; col++)
                printf(WORD_FORMAT ",", words[row + col]);
            }
          fputs("\n  },", stdout);
        }

      fputs("\n};\n\n", stdout);
    }
    
    /* Prints the four coefficients in P to stdout as a cubic polynomial in
     * hex, highest degree first.  P[0] is the constant term and P[3] the
     * x^3 coefficient. */
    static void
    display_polynomial(const unsigned *p)
    {
      const unsigned c0 = p[0];
      const unsigned c1 = p[1];
      const unsigned c2 = p[2];
      const unsigned c3 = p[3];

      printf("(%x x^3 + %x x^2 + %x x + %x)", c3, c2, c1, c0);
    }
    
    /* Parses STR as a hexadecimal field element.  On success stores the
     * value in *VALUE and returns 1.  Returns 0 if STR is empty, has
     * trailing characters, or is not below 0x100; larger values would index
     * outside the 0x100-entry log tables used by mult and invert. */
    static int
    parse_element(const char *str, unsigned *value)
    {
      char *end;
      unsigned long parsed = strtoul(str, &end, 16);

      if (end == str || *end != '\0' || parsed >= 0x100)
        return 0;

      *value = (unsigned) parsed;
      return 1;
    }

    /* Entry point.  The mode is selected by the number of arguments:
     *
     *   none            print the log, ilog, sbox, isbox, dtable and itable
     *                   tables as C initializers.
     *   one (any)       self-test of mult and invert over all non-zero
     *                   pairs; only inconsistencies are printed.
     *   A OP B          field arithmetic on two hex bytes, where OP is
     *                   '+', '*' (or 'x'), or '/'.
     *   A0..A3 B0..B3   product of two polynomials with hex byte
     *                   coefficients (constant term first), reduced
     *                   modulo x^4 + 1.
     *
     * Returns 0 on success and 1 on invalid arguments. */
    int
    main(int argc, char **argv)
    {
      compute_log();
      if (argc == 1)
        {
          display_byte_table("log", log);
          display_byte_table("ilog", ilog);

          compute_sbox();
          display_byte_table("sbox", sbox);
          display_byte_table("isbox", isbox);

          compute_dtable();
          display_table("dtable", dtable);

          compute_itable();
          display_table("itable", itable);

          return 0;
        }
      else if (argc == 2)
        {
          unsigned a;
          for (a = 1; a<0x100; a++)
            {
              unsigned a1 = invert(a);
              unsigned b;
              unsigned u;
              if (a1 == 0)
                printf("invert(%x) = 0 !\n", a);

              u = mult(a, a1);
              if (u != 1)
                printf("invert(%x) = %x; product = %x\n",
                       a, a1, u);

              for (b = 1; b<0x100; b++)
                {
                  unsigned b1 = invert(b);
                  unsigned c = mult(a, b);

                  if (c == 0)
                    printf("%x x %x = 0\n", a, b);

                  /* Dividing the product by either factor must give back
                   * the other one. */
                  u = mult(c, a1);
                  if (u != b)
                    printf("%x x %x = %x, invert(%x) = %x, %x x %x = %x\n",
                           a, b, c, a, a1, c, a1, u);

                  u = mult(c, b1);
                  if (u != a)
                    printf("%x x %x = %x, invert(%x) = %x, %x x %x = %x\n",
                           a, b, c, b, b1, c, b1, u);
                }
            }
          return 0;
        }
      else if (argc == 4)
        {
          unsigned a, b, c;
          int op = argv[2][0];

          if (!parse_element(argv[1], &a) || !parse_element(argv[3], &b))
            {
              fprintf(stderr, "%s: operands must be hex values in the range 0..ff\n",
                      argv[0]);
              return 1;
            }

          switch (op)
            {
            case '+':
              c = a ^ b;
              break;
            case '*':
            case 'x':
              c = mult(a,b);
              break;
            case '/':
              /* invert(0) is 0, so dividing by zero prints 0. */
              c = mult(a, invert(b));
              break;
            default:
              return 1;
            }
          printf("%x %c %x = %x\n", a, op, b, c);
          return 0;
        }
      else if (argc == 9)
        {
          unsigned a[4];
          unsigned b[4];
          unsigned c[4];
          unsigned i;
          for (i = 0; i<4; i++)
            {
              if (!parse_element(argv[1+i], &a[i])
                  || !parse_element(argv[5+i], &b[i]))
                {
                  fprintf(stderr, "%s: coefficients must be hex values in the range 0..ff\n",
                          argv[0]);
                  return 1;
                }
            }

          /* Polynomial product modulo x^4 + 1: exponents wrap around, so
           * c[k] collects every a[i] * b[j] with i + j = k (mod 4). */
          c[0] = mult(a[0],b[0])^mult(a[3],b[1])^mult(a[2],b[2])^mult(a[1],b[3]);
          c[1] = mult(a[1],b[0])^mult(a[0],b[1])^mult(a[3],b[2])^mult(a[2],b[3]);
          c[2] = mult(a[2],b[0])^mult(a[1],b[1])^mult(a[0],b[2])^mult(a[3],b[3]);
          c[3] = mult(a[3],b[0])^mult(a[2],b[1])^mult(a[1],b[2])^mult(a[0],b[3]);

          display_polynomial(a); printf(" * "); display_polynomial(b);
          printf(" = "); display_polynomial(c); printf("\n");
          return 0;
        }

      fprintf(stderr,
              "Usage: %s                          print all tables\n"
              "       %s test                     self-test of mult/invert\n"
              "       %s A OP B                   OP is +, * (or x), or /\n"
              "       %s A0 A1 A2 A3 B0 B1 B2 B3  polynomial product mod x^4+1\n",
              argv[0], argv[0], argv[0], argv[0]);
      return 1;
    }