diff --git a/pss-mgf.c b/pss-mgf.c index 3d7352ea70b59fff1411c4ee0c611d6f8eb23c4c..e9efa7b037b3df328b2eb1060da791711d52ac11 100644 --- a/pss-mgf.c +++ b/pss-mgf.c @@ -41,6 +41,7 @@ #include "pss.h" #include "gmp-glue.h" +#include "memxor.h" void _increment_count(uint8_t *c) { c[3]++; @@ -59,9 +60,9 @@ void _increment_count(uint8_t *c) { } int -_mgf1(void* hash_ctx, const struct nettle_hash *hash_func, +_nettle_mgf1xor(void* hash_ctx, const struct nettle_hash *hash_func, size_t m_hash_length, const uint8_t *h_seed, - size_t dbmask_length, uint8_t *dbmask) + size_t db_length, uint8_t *db) { int ret; size_t i; @@ -77,14 +78,12 @@ _mgf1(void* hash_ctx, const struct nettle_hash *hash_func, mpz_t mask_length; mpz_t hash_length; mpz_t max_mask_length; - mpz_t temp; mpz_init(mask_length); mpz_init(hash_length); mpz_init(max_mask_length); - mpz_init(temp); - mpz_set_ui(mask_length, dbmask_length); + mpz_set_ui(mask_length, db_length); mpz_set_ui(hash_length, m_hash_length); mpz_ui_pow_ui(max_mask_length, 2, 32); mpz_mul(max_mask_length, max_mask_length, hash_length); @@ -120,14 +119,14 @@ _mgf1(void* hash_ctx, const struct nettle_hash *hash_func, // hash_input = mgfseed || C memcpy(hash_input, h_seed, m_hash_length); - for (i = 0; i < (dbmask_length / m_hash_length) -1; i++) + for (i = 0; i < (db_length / m_hash_length) -1; i++) { memcpy(hash_input + m_hash_length, octetcounter, 4); hash_func->init(hash_ctx); hash_func->update(hash_ctx, m_hash_length + 4, hash_input); hash_func->digest(hash_ctx, m_hash_length, t); - memcpy(dbmask + (m_hash_length * i), t , m_hash_length); + memxor(db + (m_hash_length * i), t , m_hash_length); _increment_count(octetcounter); } @@ -140,7 +139,6 @@ cleanup: mpz_clear(mask_length); mpz_clear(hash_length); mpz_clear(max_mask_length); - mpz_clear(temp); TMP_GMP_FREE(hash_input); TMP_GMP_FREE(t); diff --git a/pss.h b/pss.h index 4bca51800852c884b86f9ae4639fe813eb715db8..7b4754e694879e5b26d37cb53ca2e517d64f76e8 100644 --- a/pss.h +++ b/pss.h @@ -44,14 +44,12 @@ extern "C" { /* Name mangling */ #define emsapss_rsa_encode nettle_emsapss_rsa_encode #define emsapss_rsa_verify nettle_emsapss_rsa_verify -#define _mgf1 nettle_mgf1 #define _increment_count nettle_increment_count /* PSS primitive functions */ int emsapss_rsa_encode(void* hash_ctx, const struct nettle_hash *hash_func, - void *random_ctx, nettle_random_func *random, - size_t salt_length, + size_t salt_length, const uint8_t *salt, size_t m_hash_length, const uint8_t *m_hash, size_t embits, mpz_t em); @@ -62,9 +60,9 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, size_t embits, const mpz_t em); // Internal mask generation functions -int _mgf1(void* hash_ctx, const struct nettle_hash *hash_func, +int _nettle_mgf1xor(void* hash_ctx, const struct nettle_hash *hash_func, size_t m_hash_length, const uint8_t *h, - size_t dbmask_length, uint8_t *dbmask); + size_t db_length, uint8_t *db); void _increment_count(uint8_t *c); /* Mask generation function constants */ diff --git a/rsa-emsapss-encode.c b/rsa-emsapss-encode.c index c70fdec39a402d78095cbb43429c6ca7ffa8d423..d8737ae4f695065e2754b60362ff35898810edbf 100644 --- a/rsa-emsapss-encode.c +++ b/rsa-emsapss-encode.c @@ -44,8 +44,7 @@ int emsapss_rsa_encode(void* hash_ctx, const struct nettle_hash *hash_func, - void *random_ctx, nettle_random_func *random, - size_t salt_length, + size_t salt_length, const uint8_t *salt, size_t m_hash_length, const uint8_t *m_hash, size_t embits, mpz_t em) { @@ -60,14 +59,12 @@ emsapss_rsa_encode(void* hash_ctx, const struct nettle_hash *hash_func, const size_t emlen = (embits+7)/8; const size_t db_length = emlen - m_hash_length -1; - TMP_GMP_DECL(salt, uint8_t); TMP_GMP_DECL(h, uint8_t); TMP_GMP_DECL(m_prime, uint8_t); TMP_GMP_DECL(db, uint8_t); TMP_GMP_DECL(dbmask, uint8_t); TMP_GMP_DECL(em_prime, uint8_t); - TMP_GMP_ALLOC(salt, salt_length); TMP_GMP_ALLOC(h, m_hash_length); TMP_GMP_ALLOC(m_prime, 8 + m_hash_length + salt_length); TMP_GMP_ALLOC(db, db_length); @@ -83,16 +80,13 @@ emsapss_rsa_encode(void* hash_ctx, const struct nettle_hash *hash_func, /* 4. Generate a random octet string salt of length sLen; if sLen = 0, then salt is the empty string. */ - random(random_ctx, salt_length, salt); + // We pass the salt as input to this function. /* 5. Let M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; M' is an octet string of length 8 + hLen + sLen with eight initial zero octets. */ // Set the first 8 octets of M' to 0 - for (i = 0; i < 8; i++) - { - m_prime[i] = 0; - } + memset(m_prime, 0, 8); // Copy mHash into M' memcpy(m_prime + 8, m_hash, m_hash_length); @@ -111,10 +105,7 @@ emsapss_rsa_encode(void* hash_ctx, const struct nettle_hash *hash_func, emLen - hLen - 1. */ // Set the PS octets in DB to be 0 - for (i = 0; i < emlen - salt_length - m_hash_length - 2; i++) - { - db[i] = 0; - } + memset(db, 0, emlen - salt_length - m_hash_length - 2); // Set the next octet to be 0x01 db[emlen - salt_length - m_hash_length -2] = 0x01; @@ -125,17 +116,12 @@ emsapss_rsa_encode(void* hash_ctx, const struct nettle_hash *hash_func, /* 9. Let dbMask = MGF(H, emLen - hLen - 1). */ /* 10. Let maskedDB = DB \xor dbMask. */ - if (!_mgf1(hash_ctx, hash_func, m_hash_length, h, db_length, dbmask)) + if (!_nettle_mgf1xor(hash_ctx, hash_func, m_hash_length, h, db_length, dbmask)) { ret = 0; goto cleanup; } - for (i=0; i < db_length; i++) - { - db[i] ^= dbmask[i]; - } - /* 11. Set the leftmost 8emLen - emBits bits of the leftmost octet in maskedDB to zero. */ db[0] &= (0xFF >> (unsigned int)(8*emlen-embits)); @@ -153,7 +139,6 @@ emsapss_rsa_encode(void* hash_ctx, const struct nettle_hash *hash_func, // Clean up. Free temporary variables and return. cleanup: - TMP_GMP_FREE(salt); TMP_GMP_FREE(h); TMP_GMP_FREE(m_prime); TMP_GMP_FREE(db); diff --git a/rsa-emsapss-verify.c b/rsa-emsapss-verify.c index ec2c6588838704102891aded8d55fc5ada5dc10d..c59870e6a7016eebcf622ecd9332f3564e005cc5 100644 --- a/rsa-emsapss-verify.c +++ b/rsa-emsapss-verify.c @@ -59,13 +59,12 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, int ret; size_t i; - const size_t emlen = (embits+7)/8; - const size_t db_length = emlen - m_hash_length -1; + const size_t emlen = (embits + 7) / 8; + const size_t db_length = emlen - m_hash_length - 1; // Convert m to an octet string em of length emlen TMP_GMP_DECL(em, uint8_t); TMP_GMP_DECL(db, uint8_t); - TMP_GMP_DECL(dbmask, uint8_t); TMP_GMP_DECL(h, uint8_t); TMP_GMP_DECL(h_prime, uint8_t); TMP_GMP_DECL(salt, uint8_t); @@ -73,7 +72,6 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, TMP_GMP_ALLOC(em, emlen); TMP_GMP_ALLOC(db, db_length); - TMP_GMP_ALLOC(dbmask, db_length); TMP_GMP_ALLOC(h, m_hash_length); TMP_GMP_ALLOC(h_prime, m_hash_length); TMP_GMP_ALLOC(salt, salt_length); @@ -90,7 +88,7 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, /* 4. If the rightmost octet of EM does not have hexadecimal value 0xbc, output "inconsistent" and stop. */ - if (em[emlen-1] != 0xbc) + if (em[emlen - 1] != 0xbc) { ret = 0; goto cleanup; @@ -104,7 +102,7 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, /* 6. If the leftmost 8emLen - emBits bits of the leftmost octet in maskedDB are not all equal to zero, output "inconsistent" and stop. */ - if ((db[0] & (0xff << (unsigned int)(8-(8*emlen-embits)))) != 0) + if ((db[0] & (0xff << (unsigned int)(8 - (8 * emlen-embits)))) != 0) { ret = 0; goto cleanup; @@ -112,17 +110,12 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, /* 7. Let dbMask = MGF(H, emLen - hLen - 1). 8. Let DB = maskedDB \xor dbMask. */ - if (!_mgf1(hash_ctx, hash_func, m_hash_length, h, db_length, dbmask)) + if (!_nettle_mgf1xor(hash_ctx, hash_func, m_hash_length, h, db_length, db)) { ret = 0; goto cleanup; } - for (i = 0; i < db_length; i++) - { - db[i] ^= dbmask[i]; - } - /* 9. Set the leftmost 8emLen - emBits bits of the leftmost octet in DB to zero. */ db[0] &= (0xFF >> (unsigned int)(8*emlen-embits)); @@ -140,7 +133,7 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, } } - if (db[db_length -2] != 0x01) + if (db[db_length - 2] != 0x01) { ret = 0; goto cleanup; @@ -153,10 +146,7 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, M' is an octet string of length 8 + hLen + sLen with eight initial zero octets. */ // Set the inital octets to be 0 - for (i=0; i < 8; i++) - { - m_prime[i] = 0; - } + memset(m_prime, 0, 8); // copy m_hash into m_prime memcpy(m_prime + 8, m_hash, m_hash_length); @@ -183,7 +173,6 @@ emsapss_rsa_verify(void* hash_ctx, const struct nettle_hash *hash_func, cleanup: TMP_GMP_FREE(em); TMP_GMP_FREE(db); - TMP_GMP_FREE(dbmask); TMP_GMP_FREE(h); TMP_GMP_FREE(salt); TMP_GMP_FREE(m_prime); diff --git a/rsa-pss-sign-tr.c b/rsa-pss-sign-tr.c index 82fbad99aa305eedd0c2282ebc1eb8328a69bd1b..d77997786ea0ceabb315a889ed6a1b969d3e2d62 100644 --- a/rsa-pss-sign-tr.c +++ b/rsa-pss-sign-tr.c @@ -36,8 +36,8 @@ #endif #include "rsa.h" - #include "pss.h" +#include "gmp-glue.h" int rsa_pss_sign_tr(const struct rsa_public_key *pub, @@ -52,13 +52,18 @@ rsa_pss_sign_tr(const struct rsa_public_key *pub, int ret; mpz_init(em); + TMP_GMP_DECL(salt, uint8_t); + TMP_GMP_ALLOC(salt, salt_length); + + random(random_ctx, salt_length, salt); + ret = emsapss_rsa_encode(hash_ctx, hash_func, - random_ctx, random, - salt_length, + salt_length, salt, m_hash_length, m_hash, - key->size * 8, em) && - rsa_compute_root_tr (pub, key, random_ctx, random, + key->size * 8, em) + && rsa_compute_root_tr (pub, key, random_ctx, random, signature, em); mpz_clear(em); + TMP_GMP_FREE(salt); return ret; }