[Cryptech-Commits] [sw/libhal] 37/58: Refactor CRT code into public API.

git at cryptech.is git at cryptech.is
Tue Jul 7 18:25:21 UTC 2015


This is an automated email from the git hooks/post-receive script.

sra at hactrn.net pushed a commit to branch master
in repository sw/libhal.

commit 5e4fc533393e01e16739f450d46f739ca4b24fe8
Author: Rob Austein <sra at hactrn.net>
Date:   Thu Jun 18 14:55:51 2015 -0400

    Refactor CRT code into public API.
---
 cryptech.h       |  10 ++-
 rsa.c            | 183 ++++++++++++++++++++++++++++++++++++-------------------
 tests/test-rsa.c |   8 +--
 3 files changed, 131 insertions(+), 70 deletions(-)

diff --git a/cryptech.h b/cryptech.h
index 6af9ce8..4b8fe17 100644
--- a/cryptech.h
+++ b/cryptech.h
@@ -628,9 +628,13 @@ extern hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type,
 
 extern void hal_rsa_key_clear(hal_rsa_key_t key);
 
-extern hal_error_t hal_rsa_crt(hal_rsa_key_t key,
-                               const uint8_t * const m,  const size_t m_len,
-                               uint8_t * result, const size_t result_len);
+extern hal_error_t hal_rsa_encrypt(hal_rsa_key_t key,
+                                   const uint8_t * const input,  const size_t input_len,
+                                   uint8_t * output, const size_t output_len);
+
+extern hal_error_t hal_rsa_decrypt(hal_rsa_key_t key,
+                                   const uint8_t * const input,  const size_t input_len,
+                                   uint8_t * output, const size_t output_len);
 
 extern hal_error_t hal_rsa_key_gen(hal_rsa_key_t *key,
                                    void *keybuf, const size_t keybuf_len,
diff --git a/rsa.c b/rsa.c
index 9a42563..543becc 100644
--- a/rsa.c
+++ b/rsa.c
@@ -196,6 +196,123 @@ static hal_error_t modexp_fp(fp_int *msg, fp_int *exp, fp_int *mod, fp_int *res)
   return err;
 }
 
+/*
+ * RSA decryption via Chinese Remainder Theorem (Garner's formula).
+ */
+
+static hal_error_t rsa_crt(struct rsa_key *key, fp_int *msg, fp_int *sig)
+{
+  assert(key != NULL && msg != NULL && sig != NULL);
+
+  hal_error_t err = HAL_OK;
+  fp_int t, m1, m2;
+
+  fp_init(&t);
+  fp_init(&m1);
+  fp_init(&m2);
+
+  /*
+   * m1 = msg ** dP mod p
+   * m2 = msg ** dQ mod q
+   */
+  if ((err = modexp_fp(msg, &key->dP, &key->p, &m1)) != HAL_OK ||
+      (err = modexp_fp(msg, &key->dQ, &key->q, &m2)) != HAL_OK)
+    goto fail;
+
+  /*
+   * t = m1 - m2.
+   */
+  fp_sub(&m1, &m2, &t);
+
+  /*
+   * Add zero (mod p) if needed to make t positive.  If doing this
+   * once or twice doesn't help, something is very wrong.
+   */
+  if (fp_cmp_d(&t, 0) == FP_LT)
+    fp_add(&t, &key->p, &t);
+  if (fp_cmp_d(&t, 0) == FP_LT)
+    fp_add(&t, &key->p, &t);
+  if (fp_cmp_d(&t, 0) == FP_LT)
+    lose(HAL_ERROR_IMPOSSIBLE);
+
+  /*
+   * sig = (t * u mod p) * q + m2
+   */
+  FP_CHECK(fp_mulmod(&t, &key->u, &key->p, &t));
+  fp_mul(&t, &key->q, &t);
+  fp_add(&t, &m2, sig);
+
+ fail:
+  fp_zero(&t);
+  fp_zero(&m1);
+  fp_zero(&m2);
+  return err;
+}
+
+/*
+ * Public API for raw RSA encryption and decryption.
+ */
+
+hal_error_t hal_rsa_encrypt(hal_rsa_key_t key_,
+                            const uint8_t * const input,  const size_t input_len,
+                            uint8_t * output, const size_t output_len)
+{
+  struct rsa_key *key = key_.key;
+  hal_error_t err = HAL_OK;
+
+  if (key == NULL || input == NULL || output == NULL || input_len > output_len)
+    return HAL_ERROR_BAD_ARGUMENTS;
+
+  fp_int i, o;
+  fp_init(&i);
+  fp_init(&o);
+
+  fp_read_unsigned_bin(&i, (uint8_t *) input, input_len);
+
+  if ((err = modexp_fp(&i, &key->e, &key->n, &o)) != HAL_OK ||
+      (err = unpack_fp(&o, output, output_len))   != HAL_OK)
+    goto fail;
+
+ fail:
+  fp_zero(&i);
+  fp_zero(&o);
+  return err;
+}
+
+hal_error_t hal_rsa_decrypt(hal_rsa_key_t key_,
+                            const uint8_t * const input,  const size_t input_len,
+                            uint8_t * output, const size_t output_len)
+{
+  struct rsa_key *key = key_.key;
+  hal_error_t err = HAL_OK;
+
+  if (key == NULL || input == NULL || output == NULL || input_len > output_len)
+    return HAL_ERROR_BAD_ARGUMENTS;
+
+  fp_int i, o;
+  fp_init(&i);
+  fp_init(&o);
+
+  fp_read_unsigned_bin(&i, (uint8_t *) input, input_len);
+
+  /*
+   * Do CRT if we have all the necessary key components, otherwise
+   * just do brute force ModExp.
+   */
+
+  if (fp_iszero(&key->p) || fp_iszero(&key->q) || fp_iszero(&key->u) || fp_iszero(&key->dP) || fp_iszero(&key->dQ))
+    err = modexp_fp(&i, &key->d, &key->n, &o);
+  else
+    err = rsa_crt(key, &i, &o);
+  
+  if (err != HAL_OK || (err = unpack_fp(&o, output, output_len)) != HAL_OK)
+    goto fail;
+
+ fail:
+  fp_zero(&i);
+  fp_zero(&o);
+  return err;
+}
 
 /*
  * Clear a key.  We might want to do something a bit more energetic
@@ -255,74 +372,14 @@ hal_error_t hal_rsa_key_load(const hal_rsa_key_type_t type,
   return HAL_ERROR_BAD_ARGUMENTS;
 }
 
-/*
- * RSA decyrption/signature using the Chinese Remainder Theorem
- * (Garner's formula).
- */
-
-hal_error_t hal_rsa_crt(hal_rsa_key_t key_,
-                        const uint8_t * const m,  const size_t m_len,
-                        uint8_t * result, const size_t result_len)
-{
-  hal_error_t err = HAL_OK;
-  struct rsa_key *key = key_.key;
-  struct { fp_int t, msg, m1, m2; } tmp;
-
-  fp_init(&tmp.t);
-  fp_init(&tmp.msg);
-  fp_init(&tmp.m1);
-  fp_init(&tmp.m2);
-
-  fp_read_unsigned_bin(&tmp.msg, (uint8_t *) m, m_len);
-
-  /*
-   * m1 = msg ** dP mod p
-   * m2 = msg ** dQ mod q
-   */
-  if ((err = modexp_fp(&tmp.msg, &key->dP, &key->p, &tmp.m1)) != HAL_OK ||
-      (err = modexp_fp(&tmp.msg, &key->dQ, &key->q, &tmp.m2)) != HAL_OK)
-    goto fail;
-
-  /*
-   * t = m1 - m2.
-   * Add zero (mod p) once or twice if necessary to get positive result.
-   */
-  fp_sub(&tmp.m1, &tmp.m2, &tmp.t);
-  if (fp_cmp_d(&tmp.t, 0) == FP_LT)
-    fp_add(&tmp.t, &key->p, &tmp.t);
-  if (fp_cmp_d(&tmp.t, 0) == FP_LT)
-    fp_add(&tmp.t, &key->p, &tmp.t);
-  if (fp_cmp_d(&tmp.t, 0) == FP_LT)
-    lose(HAL_ERROR_IMPOSSIBLE);
-
-  /*
-   * t = (t * u mod p) * q + m2
-   */
-  FP_CHECK(fp_mulmod(&tmp.t, &key->u, &key->p, &tmp.t));
-  fp_mul(&tmp.t, &key->q, &tmp.t);
-  fp_add(&tmp.t, &tmp.m2, &tmp.t);
-
-  /*
-   * t now holds result, write it back to caller
-   */
-  if ((err = unpack_fp(&tmp.t, result, result_len)) != HAL_OK)
-    goto fail;
-
-  /*
-   * Done, fall through into cleanup.
-   */
-
- fail:
-  memset(&tmp, 0, sizeof(tmp));
-  return err;
-}
-
 static hal_error_t find_prime(unsigned prime_length, fp_int *e, fp_int *result)
 {
   uint8_t buffer[prime_length];
   hal_error_t err;
   fp_int t;
 
+  fp_init(&t);
+
   /*
    * Get random bytes, munge a few bits, and stuff into a bignum.
    * Keep doing this until we find a result that's (probably) prime
@@ -547,7 +604,7 @@ static hal_error_t decode_integer(fp_int *bn,
   if (der_len != NULL)
     *der_len = hlen + vlen;
 
-  if (vlen < 1)
+  if (vlen < 1 || (der[hlen] & 0x80) != 0x00)
     return HAL_ERROR_ASN1_PARSE_FAILED;
 
   fp_init(bn);
diff --git a/tests/test-rsa.c b/tests/test-rsa.c
index 08d22c5..799f1fa 100644
--- a/tests/test-rsa.c
+++ b/tests/test-rsa.c
@@ -82,7 +82,7 @@ static int test_modexp(const char * const kind,
  * Run one RSA CRT test.
  */
 
-static int test_crt(const char * const kind, const rsa_tc_t * const tc)
+static int test_decrypt(const char * const kind, const rsa_tc_t * const tc)
 {
   printf("%s test for %lu-bit RSA key\n", kind, (unsigned long) tc->size);
 
@@ -106,7 +106,7 @@ static int test_crt(const char * const kind, const rsa_tc_t * const tc)
 
   uint8_t result[tc->n.len];
 
-  if ((err = hal_rsa_crt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
+  if ((err = hal_rsa_decrypt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
     printf("RSA CRT failed: %s\n", hal_error_string(err));
 
   const int mismatch = (err == HAL_OK && memcmp(result, tc->s.val, tc->s.len) != 0);
@@ -172,7 +172,7 @@ static int test_gen(const char * const kind, const rsa_tc_t * const tc)
 
   uint8_t result[tc->n.len];
 
-  if ((err = hal_rsa_crt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
+  if ((err = hal_rsa_decrypt(key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
     printf("RSA CRT failed: %s\n", hal_error_string(err));
 
   snprintf(fn, sizeof(fn), "test-rsa-sig-%04lu.der", (unsigned long) tc->size);
@@ -244,7 +244,7 @@ static int test_rsa(const rsa_tc_t * const tc)
   time_check(test_modexp("Signature (ModExp)", tc, &tc->m, &tc->d, &tc->s));
 
   /* RSA decyrption using CRT */
-  time_check(test_crt("Signature (CRT)", tc));
+  time_check(test_decrypt("Signature (CRT)", tc));
 
   /* Key generation and CRT -- not test vector, so writes key and sig to file */
   time_check(test_gen("Generation and CRT", tc));



More information about the Commits mailing list