[Cryptech-Commits] [sw/libhal] 01/02: Preliminary support for parallel core RSA CRT.

git at cryptech.is git at cryptech.is
Thu Sep 14 00:23:08 UTC 2017


This is an automated email from the git hooks/post-receive script.

sra at hactrn.net pushed a commit to branch systolic_crt
in repository sw/libhal.

commit 410e0cf1d22c67585f0a5346e62f60aa4e90fe05
Author: Rob Austein <sra at hactrn.net>
AuthorDate: Wed Sep 13 20:20:55 2017 -0400

    Preliminary support for parallel core RSA CRT.
---
 Makefile         |   7 +-
 core.c           |  39 +++++-----
 hal.h            |  32 +++++---
 hal_internal.h   |  12 +++
 hal_io_eim.c     |  29 -------
 hal_io_fmc.c     |  31 --------
 hal_io_i2c.c     |  26 -------
 modexp.c         | 227 +++++++++++++++++++++++++++++++++----------------------
 rpc_pkey.c       |   6 +-
 rsa.c            | 138 ++++++++++++++++++++++++++++-----
 tests/test-rsa.c |  17 ++++-
 11 files changed, 329 insertions(+), 235 deletions(-)

diff --git a/Makefile b/Makefile
index ae6888d..59236af 100644
--- a/Makefile
+++ b/Makefile
@@ -109,12 +109,13 @@ CORE_OBJ = core.o csprng.o pbkdf2.o aes_keywrap.o modexp.o mkmif.o ${IO_OBJ}
 #   i2c:	Older I2C bus from Novena
 #   fmc:	FMC bus from dev-bridge and alpha boards
 
+IO_OBJ = hal_io.o
 ifeq "${IO_BUS}" "eim"
-  IO_OBJ = hal_io_eim.o novena-eim.o
+  IO_OBJ += hal_io_eim.o novena-eim.o
 else ifeq "${IO_BUS}" "i2c"
-  IO_OBJ = hal_io_i2c.o
+  IO_OBJ += hal_io_i2c.o
 else ifeq "${IO_BUS}" "fmc"
-  IO_OBJ = hal_io_fmc.o
+  IO_OBJ += hal_io_fmc.o
 endif
 
 # If we're building for STM32, position-independent code leads to some
diff --git a/core.c b/core.c
index 8e9f2b2..32823a6 100644
--- a/core.c
+++ b/core.c
@@ -97,7 +97,7 @@ static int name_matches(const hal_core_t *const core, const char * const name)
 static const struct { const char *name; hal_addr_t extra; } gaps[] = {
   { "csprng",  11 * CORE_SIZE }, /* empty slots after csprng */
   { "modexps6", 3 * CORE_SIZE }, /* ModexpS6 uses four slots */
-  { "modexpa7", 3 * CORE_SIZE }, /* ModexpA7 uses four slots */
+  { "modexpa7", 7 * CORE_SIZE }, /* ModexpA7 uses eight slots */
 };
 
 static hal_core_t *head = NULL;
@@ -203,15 +203,17 @@ hal_core_t *hal_core_find(const char *name, hal_core_t *core)
 
 hal_error_t hal_core_alloc(const char *name, hal_core_t **pcore)
 {
-  hal_core_t *core;
-  hal_error_t err = HAL_ERROR_CORE_NOT_FOUND;
+  /*
+   * This used to allow name == NULL iff *pcore != NULL, but the
+   * semantics were fragile and in practice we always pass a name
+   * anyway, so simplify by requiring name != NULL, always.
+   */
 
-  if (name == NULL && (pcore == NULL || *pcore == NULL))
+  if (name == NULL || pcore == NULL)
     return HAL_ERROR_BAD_ARGUMENTS;
 
-  core = *pcore;
-  if (name == NULL)
-    name = core->info.name;
+  hal_error_t err = HAL_ERROR_CORE_NOT_FOUND;
+  hal_core_t *core = *pcore;
 
   if (core != NULL) {
     /* if we can reallocate the same core, do it now */
@@ -221,24 +223,23 @@ hal_error_t hal_core_alloc(const char *name, hal_core_t **pcore)
       hal_critical_section_end();
       return HAL_OK;
     }
-    /* else fall through to search */
+    /* else forget that core and fall through to search */
+    *pcore = NULL;
   }
 
   while (1) {
     hal_critical_section_start();
     for (core = hal_core_iterate(NULL); core != NULL; core = core->next) {
-      if (name_matches(core, name)) {
-        if (core->busy) {
-          err = HAL_ERROR_CORE_BUSY;
-          continue;
-        }
-        else {
-          err = HAL_OK;
-          *pcore = core;
-          core->busy = 1;
-          break;
-        }
+      if (!name_matches(core, name))
+        continue;
+      if (core->busy) {
+        err = HAL_ERROR_CORE_BUSY;
+        continue;
       }
+      err = HAL_OK;
+      *pcore = core;
+      core->busy = 1;
+      break;
     }
     hal_critical_section_end();
     if (err == HAL_ERROR_CORE_BUSY)
diff --git a/hal.h b/hal.h
index f7a7522..c017b2d 100644
--- a/hal.h
+++ b/hal.h
@@ -201,7 +201,8 @@ typedef struct hal_core hal_core_t;
 extern void hal_io_set_debug(int onoff);
 extern hal_error_t hal_io_write(const hal_core_t *core, hal_addr_t offset, const uint8_t *buf, size_t len);
 extern hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf, size_t len);
-extern hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count);
+extern hal_error_t hal_io_wait(const hal_core_t *core, const uint8_t status, int *count);
+extern hal_error_t hal_io_wait2(const hal_core_t *core1, const hal_core_t *core2, const uint8_t status, int *count);
 
 /*
  * Core management functions.
@@ -368,19 +369,25 @@ extern hal_error_t hal_pbkdf2(hal_core_t *core,
 			      unsigned iterations_desired);
 
 /*
- * Modular exponentiation.
+ * Modular exponentiation.  This takes a ridiculous number of
+ * arguments of very similar types, making it easy to confuse them,
+ * particularly when performing two modexp operations in parallel, so
+ * we encapsulate the arguments in a structure.
  */
 
-extern void hal_modexp_set_debug(const int onoff);
+typedef struct {
+  hal_core_t *core;
+  const uint8_t *msg;    size_t msg_len;        /* Message */
+  const uint8_t *exp;    size_t exp_len;        /* Exponent */
+  const uint8_t *mod;    size_t mod_len;        /* Modulus */
+  uint8_t       *result; size_t result_len;     /* Result of exponentiation */
+  uint8_t       *coeff;  size_t coeff_len;      /* Modulus coefficient (r/w) */
+  uint8_t       *mont;   size_t mont_len;       /* Montgomery factor (r/w)*/
+} hal_modexp_arg_t;
 
-extern hal_error_t hal_modexp(hal_core_t *core,
-                              const int precalc,
-                              const uint8_t * const msg, const size_t msg_len,         /* Message */
-                              const uint8_t * const exp, const size_t exp_len,         /* Exponent */
-                              const uint8_t * const mod, const size_t mod_len,         /* Modulus */
-                              uint8_t       *    result, const size_t result_len,      /* Result of exponentiation */
-                              uint8_t       *     coeff, const size_t coeff_len,       /* Modulus coefficient (r/w) */
-                              uint8_t       *      mont, const size_t mont_len);       /* Montgomery factor (r/w)*/
+extern void hal_modexp_set_debug(const int onoff);
+extern hal_error_t hal_modexp( const int precalc, hal_modexp_arg_t *args);
+extern hal_error_t hal_modexp2(const int precalc, hal_modexp_arg_t *args1, hal_modexp_arg_t *args2);
 
 /*
  * Master Key Memory Interface
@@ -462,7 +469,8 @@ extern hal_error_t hal_rsa_encrypt(hal_core_t *core,
                                    const uint8_t * const input,  const size_t input_len,
                                    uint8_t * output, const size_t output_len);
 
-extern hal_error_t hal_rsa_decrypt(hal_core_t *core,
+extern hal_error_t hal_rsa_decrypt(hal_core_t *core1,
+                                   hal_core_t *core2,
                                    hal_rsa_key_t *key,
                                    const uint8_t * const input,  const size_t input_len,
                                    uint8_t * output, const size_t output_len);
diff --git a/hal_internal.h b/hal_internal.h
index a60d0b5..ac51cfb 100644
--- a/hal_internal.h
+++ b/hal_internal.h
@@ -103,6 +103,18 @@ static inline hal_error_t hal_io_wait_valid(const hal_core_t *core)
   return hal_io_wait(core, STATUS_VALID, &limit);
 }
 
+static inline hal_error_t hal_io_wait_ready2(const hal_core_t *core1, const hal_core_t *core2)
+{
+  int limit = -1;
+  return hal_io_wait2(core1, core2, STATUS_READY, &limit);
+}
+
+static inline hal_error_t hal_io_wait_valid2(const hal_core_t *core1, const hal_core_t *core2)
+{
+  int limit = -1;
+  return hal_io_wait2(core1, core2, STATUS_VALID, &limit);
+}
+
 /*
  * Static memory allocation on start-up.  Don't use this except where
  * really necessary.  By design, there's no way to free this, we don't
diff --git a/hal_io_eim.c b/hal_io_eim.c
index eabc42e..040cb2b 100644
--- a/hal_io_eim.c
+++ b/hal_io_eim.c
@@ -43,10 +43,6 @@
 static int debug = 0;
 static int inited = 0;
 
-#ifndef EIM_IO_TIMEOUT
-#define EIM_IO_TIMEOUT  100000000
-#endif
-
 static inline hal_error_t init(void)
 {
   if (inited)
@@ -134,31 +130,6 @@ hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf,
   return HAL_OK;
 }
 
-hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count)
-{
-  hal_error_t err;
-  uint8_t buf[4];
-  int i;
-
-  if (count && *count == -1)
-    *count = EIM_IO_TIMEOUT;
-
-  for (i = 1; ; ++i) {
-
-    if (count && (*count > 0) && (i >= *count))
-      return HAL_ERROR_IO_TIMEOUT;
-
-    if ((err = hal_io_read(core, ADDR_STATUS, buf, sizeof(buf))) != HAL_OK)
-      return err;
-
-    if ((buf[3] & status) != 0) {
-      if (count)
-        *count = i;
-      return HAL_OK;
-    }
-  }
-}
-
 /*
  * Local variables:
  * indent-tabs-mode: nil
diff --git a/hal_io_fmc.c b/hal_io_fmc.c
index 5ac73c4..0d49f1e 100644
--- a/hal_io_fmc.c
+++ b/hal_io_fmc.c
@@ -47,10 +47,6 @@
 static int debug = 0;
 static int inited = 0;
 
-#ifndef FMC_IO_TIMEOUT
-#define FMC_IO_TIMEOUT  100000000
-#endif
-
 static inline hal_error_t init(void)
 {
   if (!inited) {
@@ -136,33 +132,6 @@ hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf,
   return HAL_OK;
 }
 
-hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count)
-{
-  hal_error_t err;
-  uint8_t buf[4];
-  int i;
-
-  if (count && *count == -1)
-    *count = FMC_IO_TIMEOUT;
-
-  for (i = 1; ; ++i) {
-
-    if (count && (*count > 0) && (i >= *count))
-      return HAL_ERROR_IO_TIMEOUT;
-
-    hal_task_yield();
-
-    if ((err = hal_io_read(core, ADDR_STATUS, buf, sizeof(buf))) != HAL_OK)
-      return err;
-
-    if ((buf[3] & status) != 0) {
-      if (count)
-        *count = i;
-      return HAL_OK;
-    }
-  }
-}
-
 /*
  * Local variables:
  * indent-tabs-mode: nil
diff --git a/hal_io_i2c.c b/hal_io_i2c.c
index 018e264..8596174 100644
--- a/hal_io_i2c.c
+++ b/hal_io_i2c.c
@@ -301,32 +301,6 @@ hal_error_t hal_io_read(const hal_core_t *core, hal_addr_t offset, uint8_t *buf,
   return HAL_OK;
 }
 
-hal_error_t hal_io_wait(const hal_core_t *core, uint8_t status, int *count)
-{
-  hal_error_t err;
-  uint8_t buf[4];
-  int i;
-
-  if (count && *count == -1)
-    *count = 10;
-
-  for (i = 1; ; ++i) {
-
-    if (count && (*count > 0) && (i >= *count))
-      return HAL_ERROR_IO_TIMEOUT;
-
-    if ((err = hal_io_read(core, ADDR_STATUS, buf, 4)) != HAL_OK)
-      return err;
-
-    if (buf[3] & status) {
-      if (count)
-        *count = i;
-      return HAL_OK;
-
-    }
-  }
-}
-
 /*
  * Local variables:
  * indent-tabs-mode: nil
diff --git a/modexp.c b/modexp.c
index 12b5789..7973258 100644
--- a/modexp.c
+++ b/modexp.c
@@ -157,125 +157,174 @@ static inline hal_error_t set_buffer(const hal_core_t *core,
 }
 
 /*
- * Check a result, report on failure if debugging, pass failures up
- * the chain.
- */
-
-#define check(_expr_)                                                                   \
-  do {                                                                                  \
-    hal_error_t _err = (_expr_);                                                        \
-    if (_err != HAL_OK && debug)                                                        \
-      hal_log(HAL_LOG_WARN, "%s failed: %s\n", #_expr_, hal_error_string(_err));        \
-    if (_err != HAL_OK) {                                                               \
-      hal_core_free(core);                                                              \
-      return _err;                                                                      \
-    }                                                                                   \
-  } while (0)
-
-/*
- * Run one modexp operation.
+ * Stuff moved out of modexp so we can run two cores in parallel more
+ * easily.  We have to return to the jacket routine every time we kick
+ * a core into doing something, since only the jacket routines know
+ * how many cores we're running for any particular calculation.
+ *
+ * In theory we could do something clever where we don't wait for both
+ * cores to finish precalc before starting either of them on the main
+ * computation, but that way probably lies madness.
  */
 
-hal_error_t hal_modexp(hal_core_t *core,
-                       const int precalc,
-                       const uint8_t * const msg, const size_t msg_len,         /* Message */
-                       const uint8_t * const exp, const size_t exp_len,         /* Exponent */
-                       const uint8_t * const mod, const size_t mod_len,         /* Modulus */
-                       uint8_t *result,           const size_t result_len,      /* Result of exponentiation */
-                       uint8_t *coeff,            const size_t coeff_len,       /* Modulus coefficient (r/w) */
-                       uint8_t *mont,             const size_t mont_len)        /* Montgomery factor (r/w)*/
+static inline hal_error_t check_args(hal_modexp_arg_t *a)
 {
-  hal_error_t err;
-
   /*
-   * All pointers must be set, exponent may not be longer than
+   * All data pointers must be set, exponent may not be longer than
    * modulus, message may not be longer than twice the modulus (CRT
    * mode), result buffer must not be shorter than modulus, and all
    * input lengths must be a multiple of four bytes (the core is all
    * about 32-bit words).
    */
 
-  if (msg    == NULL || msg_len    > MODEXPA7_OPERAND_BYTES || msg_len    >  mod_len * 2 ||
-      exp    == NULL || exp_len    > MODEXPA7_OPERAND_BYTES || exp_len    >  mod_len     ||
-      mod    == NULL || mod_len    > MODEXPA7_OPERAND_BYTES ||
-      result == NULL || result_len > MODEXPA7_OPERAND_BYTES || result_len <  mod_len     ||
-      coeff  == NULL || coeff_len  > MODEXPA7_OPERAND_BYTES ||
-      mont   == NULL || mont_len   > MODEXPA7_OPERAND_BYTES ||
-      ((msg_len | exp_len | mod_len) & 3) != 0)
+  if (a         == NULL ||
+      a->msg    == NULL || a->msg_len    > MODEXPA7_OPERAND_BYTES || a->msg_len    >  a->mod_len * 2 ||
+      a->exp    == NULL || a->exp_len    > MODEXPA7_OPERAND_BYTES || a->exp_len    >  a->mod_len     ||
+      a->mod    == NULL || a->mod_len    > MODEXPA7_OPERAND_BYTES ||
+      a->result == NULL || a->result_len > MODEXPA7_OPERAND_BYTES || a->result_len <  a->mod_len     ||
+      a->coeff  == NULL || a->coeff_len  > MODEXPA7_OPERAND_BYTES ||
+      a->mont   == NULL || a->mont_len   > MODEXPA7_OPERAND_BYTES ||
+      ((a->msg_len | a->exp_len | a->mod_len) & 3) != 0)
     return HAL_ERROR_BAD_ARGUMENTS;
 
-  /*
-   * Gonna need to think about running two modexpa7 cores in parallel
-   * in CRT mode for full speed signature.
-   */
+  return HAL_OK;
+}
 
-  if (((err = hal_core_alloc(MODEXPA7_NAME, &core)) != HAL_OK))
-    return err;
+static inline hal_error_t setup_precalc(const int precalc, hal_modexp_arg_t *a)
+{
+  hal_error_t err;
 
   /*
-   * Now that we have the core, check operand length against what it
-   * says it can handle.
+   * Check that operand size is compatible with the core.
    */
 
   uint32_t operand_max = 0;
-  check(get_register(core, MODEXPA7_ADDR_BUFFER_BITS, &operand_max));
+
+  if ((err = get_register(a->core, MODEXPA7_ADDR_BUFFER_BITS, &operand_max)) != HAL_OK)
+    return err;
+
   operand_max /= 8;
 
-  if (msg_len   > operand_max ||
-      exp_len   > operand_max ||
-      mod_len   > operand_max ||
-      coeff_len > operand_max ||
-      mont_len  > operand_max) {
-    hal_core_free(core);
+  if (a->msg_len   > operand_max ||
+      a->exp_len   > operand_max ||
+      a->mod_len   > operand_max ||
+      a->coeff_len > operand_max ||
+      a->mont_len  > operand_max)
     return HAL_ERROR_BAD_ARGUMENTS;
-  }
 
-  /* Set modulus */
+  /*
+   * Set the modulus, then initiate calculation of modulus-dependent
+   * speedup factors if necessary, by edge-triggering the "init" bit,
+   * then return to caller so it can wait for precalc.
+   */
+
+  if ((err = set_register(a->core, MODEXPA7_ADDR_MODULUS_BITS, a->mod_len * 8)) != HAL_OK  ||
+      (err = set_buffer(a->core, MODEXPA7_ADDR_MODULUS, a->mod, a->mod_len))    != HAL_OK  ||
+      (precalc && (err = hal_io_zero(a->core))                                  != HAL_OK) ||
+      (precalc && (err = hal_io_init(a->core))                                  != HAL_OK))
+    return err;
+
+  return HAL_OK;
+}
+
+static inline hal_error_t setup_calc(const int precalc, hal_modexp_arg_t *a)
+{
+  hal_error_t err;
+
+  /*
+   * Select CRT mode if and only if message is longer than modulus.
+   */
 
-  check(set_register(core, MODEXPA7_ADDR_MODULUS_BITS, mod_len * 8));
-  check(set_buffer(core, MODEXPA7_ADDR_MODULUS, mod, mod_len));
+  const uint32_t mode = a->msg_len > a->mod_len ? MODEXPA7_MODE_CRT : MODEXPA7_MODE_PLAIN;
 
   /*
-   * Calculate modulus-dependent speedup factors if needed.  Buffer
-   * space is always caller's problem (because caller almost certainly
-   * wants to stash these values in the keystore anyway).  Calculation
-   * is edge-triggered by "init" bit going from zero to one.
+   * Copy out precalc results if necessary, then load everything and
+   * start the calculation by edge-triggering the "next" bit.  If
+   * everything works, return to caller so it can wait for the
+   * calculation to complete.
    */
 
-  if (precalc) {
-    check(hal_io_zero(core));
-    check(hal_io_init(core));
-    check(hal_io_wait_ready(core));
-    check(get_buffer(core, MODEXPA7_ADDR_MODULUS_COEFF_OUT,     coeff, coeff_len));
-    check(get_buffer(core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_OUT, mont,  mont_len));
-  }
-
-  /* Load modulus-dependent speedup factors (even if we just calculated them) */
-  check(set_buffer(core, MODEXPA7_ADDR_MODULUS_COEFF_IN,     coeff, coeff_len));
-  check(set_buffer(core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_IN, mont,  mont_len));
-
-  /* Select CRT mode if and only if message is longer than exponent */
-  check(set_register(core, MODEXPA7_ADDR_MODE,
-                     (msg_len > mod_len
-                      ? MODEXPA7_MODE_CRT
-                      : MODEXPA7_MODE_PLAIN)));
-
-  /* Set message and exponent */
-  check(set_buffer(core, MODEXPA7_ADDR_MESSAGE, msg, msg_len));
-  check(set_buffer(core, MODEXPA7_ADDR_EXPONENT, exp, exp_len));
-  check(set_register(core, MODEXPA7_ADDR_EXPONENT_BITS, exp_len * 8));
-
-  /* Edge-trigger the "next" bit to start calculation, then wait for the result */
-  check(hal_io_zero(core));
-  check(hal_io_next(core));
-  check(hal_io_wait_valid(core));
-
-  /* Extract result, clean up, then done */
-  check(get_buffer(core, MODEXPA7_ADDR_RESULT, result, mod_len));
-  hal_core_free(core);
+  if ((precalc &&
+       (err = get_buffer(a->core, MODEXPA7_ADDR_MODULUS_COEFF_OUT,     a->coeff, a->coeff_len)) != HAL_OK) ||
+      (precalc &&
+        (err = get_buffer(a->core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_OUT, a->mont,  a->mont_len)) != HAL_OK) ||
+      (err = set_buffer(a->core, MODEXPA7_ADDR_MODULUS_COEFF_IN,     a->coeff, a->coeff_len))   != HAL_OK  ||
+      (err = set_buffer(a->core, MODEXPA7_ADDR_MONTGOMERY_FACTOR_IN, a->mont,  a->mont_len))    != HAL_OK  ||
+      (err = set_register(a->core, MODEXPA7_ADDR_MODE, mode))                                   != HAL_OK  ||
+      (err = set_buffer(a->core, MODEXPA7_ADDR_MESSAGE, a->msg, a->msg_len))                    != HAL_OK  ||
+      (err = set_buffer(a->core, MODEXPA7_ADDR_EXPONENT, a->exp, a->exp_len))                   != HAL_OK  ||
+      (err = set_register(a->core, MODEXPA7_ADDR_EXPONENT_BITS, a->exp_len * 8))                != HAL_OK  ||
+      (err = hal_io_zero(a->core))                                                              != HAL_OK  ||
+      (err = hal_io_next(a->core)) != HAL_OK)
+    return err;
+
   return HAL_OK;
 }
 
+static inline hal_error_t extract_result(hal_modexp_arg_t *a)
+{
+  /*
+   * Extract results from the main calculation and we're done.
+   * Hardly seems worth making this a separate function.
+   */
+
+  return get_buffer(a->core, MODEXPA7_ADDR_RESULT, a->result, a->mod_len);
+}
+
+/*
+ * Run one modexp operation.
+ */
+
+hal_error_t hal_modexp(const int precalc, hal_modexp_arg_t *a)
+{
+  hal_error_t err;
+
+  if ((err = check_args(a)) != HAL_OK)
+    return err;
+
+  if ((err = hal_core_alloc(MODEXPA7_NAME, &a->core)) == HAL_OK  &&
+      (err = setup_precalc(precalc, a))               == HAL_OK  &&
+      (!precalc ||
+       (err = hal_io_wait_ready(a->core))             == HAL_OK) &&
+      (err = setup_calc(precalc, a))                  == HAL_OK  &&
+      (err = hal_io_wait_valid(a->core))              == HAL_OK  &&
+      (err = extract_result(a))                       == HAL_OK)
+    err = HAL_OK;
+
+  hal_core_free(a->core);
+  return err;
+}
+
+/*
+ * Run two modexp operations in parallel.
+ */
+
+hal_error_t hal_modexp2(const int precalc, hal_modexp_arg_t *a1, hal_modexp_arg_t *a2)
+{
+  hal_error_t err;
+
+  if ((err = check_args(a1)) != HAL_OK ||
+      (err = check_args(a2)) != HAL_OK)
+    return err;
+
+  if ((err = hal_core_alloc(MODEXPA7_NAME, &a1->core)) == HAL_OK  &&
+      (err = hal_core_alloc(MODEXPA7_NAME, &a2->core)) == HAL_OK  &&
+      (err = setup_precalc(precalc, a1))               == HAL_OK  &&
+      (err = setup_precalc(precalc, a2))               == HAL_OK  &&
+      (!precalc ||
+       (err = hal_io_wait_ready2(a1->core, a2->core))  == HAL_OK) &&
+      (err = setup_calc(precalc, a1))                  == HAL_OK  &&
+      (err = setup_calc(precalc, a2))                  == HAL_OK  &&
+      (err = hal_io_wait_valid2(a1->core, a2->core))   == HAL_OK  &&
+      (err = extract_result(a1))                       == HAL_OK  &&
+      (err = extract_result(a2))                       == HAL_OK)
+    err = HAL_OK;
+
+  hal_core_free(a1->core);
+  hal_core_free(a2->core);
+  return err;
+}
+
 /*
  * Local variables:
  * indent-tabs-mode: nil
diff --git a/rpc_pkey.c b/rpc_pkey.c
index 53d3214..9d8975f 100644
--- a/rpc_pkey.c
+++ b/rpc_pkey.c
@@ -760,8 +760,8 @@ static hal_error_t pkey_local_sign_rsa(hal_pkey_slot_t *slot,
     input = signature;
   }
 
-  if ((err = pkcs1_5_pad(input, input_len, signature, *signature_len, 0x01))                   != HAL_OK ||
-      (err = hal_rsa_decrypt(NULL, key, signature, *signature_len, signature, *signature_len)) != HAL_OK)
+  if ((err = pkcs1_5_pad(input, input_len, signature, *signature_len, 0x01))                         != HAL_OK ||
+      (err = hal_rsa_decrypt(NULL, NULL, key, signature, *signature_len, signature, *signature_len)) != HAL_OK)
     return err;
 
   if (hal_rsa_key_needs_saving(key)) {
@@ -1276,7 +1276,7 @@ static hal_error_t pkey_local_import(const hal_client_handle_t client,
     goto fail;
   }
 
-  if ((err = hal_rsa_decrypt(NULL, rsa, data, data_len, der, data_len)) != HAL_OK)
+  if ((err = hal_rsa_decrypt(NULL, NULL, rsa, data, data_len, der, data_len)) != HAL_OK)
     goto fail;
 
   if ((err = hal_get_random(NULL, kek, sizeof(kek))) != HAL_OK)
diff --git a/rsa.c b/rsa.c
index dace19b..44ad84e 100644
--- a/rsa.c
+++ b/rsa.c
@@ -233,16 +233,20 @@ static hal_error_t modexp(hal_core_t *core,
   uint8_t modbuf[mod_len];
   uint8_t resbuf[mod_len];
 
+  hal_modexp_arg_t args = {
+    .core   = core,
+    .msg    = msgbuf, .msg_len    = sizeof(msgbuf),
+    .exp    = expbuf, .exp_len    = sizeof(expbuf),
+    .mod    = modbuf, .mod_len    = sizeof(modbuf),
+    .result = resbuf, .result_len = sizeof(resbuf),
+    .coeff  = coeff,  .coeff_len  = coeff_len,
+    .mont   = mont,   .mont_len   = mont_len
+  };
+
   if ((err = unpack_fp(msg, msgbuf, sizeof(msgbuf))) != HAL_OK ||
       (err = unpack_fp(exp, expbuf, sizeof(expbuf))) != HAL_OK ||
       (err = unpack_fp(mod, modbuf, sizeof(modbuf))) != HAL_OK ||
-      (err = hal_modexp(core, precalc,
-                        msgbuf, sizeof(msgbuf),
-                        expbuf, sizeof(expbuf),
-                        modbuf, sizeof(modbuf),
-                        resbuf, sizeof(resbuf),
-                        coeff,  coeff_len,
-                        mont,   mont_len))           != HAL_OK)
+      (err = hal_modexp(precalc, &args))             != HAL_OK)
     goto fail;
 
   fp_read_unsigned_bin(res, resbuf, sizeof(resbuf));
@@ -252,6 +256,83 @@ static hal_error_t modexp(hal_core_t *core,
   memset(expbuf, 0, sizeof(expbuf));
   memset(modbuf, 0, sizeof(modbuf));
   memset(resbuf, 0, sizeof(resbuf));
+  memset(&args,  0, sizeof(args));
+  return err;
+}
+
+static hal_error_t modexp2(const int precalc,
+                           const fp_int * const msg,
+                           hal_core_t *core1,
+                           const fp_int * const exp1,
+                           const fp_int * const mod1,
+                           fp_int       *       res1,
+                           uint8_t *coeff1, const size_t coeff1_len,
+                           uint8_t *mont1,  const size_t mont1_len,
+                           hal_core_t *core2,
+                           const fp_int * const exp2,
+                           const fp_int * const mod2,
+                           fp_int       *       res2,
+                           uint8_t *coeff2, const size_t coeff2_len,
+                           uint8_t *mont2,  const size_t mont2_len)
+{
+  hal_error_t err = HAL_OK;
+
+  if (msg  == NULL ||
+      exp1 == NULL || mod1 == NULL || res1 == NULL || coeff1 == NULL || mont1 == NULL ||
+      exp2 == NULL || mod2 == NULL || res2 == NULL || coeff2 == NULL || mont2 == NULL)
+    return HAL_ERROR_IMPOSSIBLE;
+
+  const size_t msg_len  = (fp_unsigned_bin_size(unconst_fp_int(msg))  + 3) & ~3;
+  const size_t exp1_len = (fp_unsigned_bin_size(unconst_fp_int(exp1)) + 3) & ~3;
+  const size_t mod1_len = (fp_unsigned_bin_size(unconst_fp_int(mod1)) + 3) & ~3;
+  const size_t exp2_len = (fp_unsigned_bin_size(unconst_fp_int(exp2)) + 3) & ~3;
+  const size_t mod2_len = (fp_unsigned_bin_size(unconst_fp_int(mod2)) + 3) & ~3;
+
+  uint8_t msgbuf[msg_len];
+  uint8_t expbuf1[exp1_len], modbuf1[mod1_len], resbuf1[mod1_len];
+  uint8_t expbuf2[exp2_len], modbuf2[mod2_len], resbuf2[mod2_len];
+
+  hal_modexp_arg_t args1 = {
+    .core   = core1,
+    .msg    = msgbuf,  .msg_len    = sizeof(msgbuf),
+    .exp    = expbuf1, .exp_len    = sizeof(expbuf1),
+    .mod    = modbuf1, .mod_len    = sizeof(modbuf1),
+    .result = resbuf1, .result_len = sizeof(resbuf1),
+    .coeff  = coeff1,  .coeff_len  = coeff1_len,
+    .mont   = mont1,   .mont_len   = mont1_len
+  };
+
+  hal_modexp_arg_t args2 = {
+    .core   = core2,
+    .msg    = msgbuf,  .msg_len    = sizeof(msgbuf),
+    .exp    = expbuf2, .exp_len    = sizeof(expbuf2),
+    .mod    = modbuf2, .mod_len    = sizeof(modbuf2),
+    .result = resbuf2, .result_len = sizeof(resbuf2),
+    .coeff  = coeff2,  .coeff_len  = coeff2_len,
+    .mont   = mont2,   .mont_len   = mont2_len
+  };
+
+  if ((err = unpack_fp(msg,  msgbuf,  sizeof(msgbuf)))  != HAL_OK ||
+      (err = unpack_fp(exp1, expbuf1, sizeof(expbuf1))) != HAL_OK ||
+      (err = unpack_fp(mod1, modbuf1, sizeof(modbuf1))) != HAL_OK ||
+      (err = unpack_fp(exp2, expbuf2, sizeof(expbuf2))) != HAL_OK ||
+      (err = unpack_fp(mod2, modbuf2, sizeof(modbuf2))) != HAL_OK ||
+      (err = hal_modexp2(precalc, &args1, &args2))      != HAL_OK)
+    goto fail;
+
+  fp_read_unsigned_bin(res1, resbuf1, sizeof(resbuf1));
+  fp_read_unsigned_bin(res2, resbuf2, sizeof(resbuf2));
+
+ fail:
+  memset(msgbuf,  0, sizeof(msgbuf));
+  memset(expbuf1, 0, sizeof(expbuf1));
+  memset(modbuf1, 0, sizeof(modbuf1));
+  memset(resbuf1, 0, sizeof(resbuf1));
+  memset(&args1,  0, sizeof(args1));
+  memset(expbuf2, 0, sizeof(expbuf2));
+  memset(modbuf2, 0, sizeof(modbuf2));
+  memset(resbuf2, 0, sizeof(resbuf2));
+  memset(&args2,  0, sizeof(args2));
   return err;
 }
 
@@ -280,6 +361,28 @@ static hal_error_t modexp(const hal_core_t *core, /* ignored */
   return err;
 }
 
+static hal_error_t modexp2(const int precalc, /* ignored */
+                           const fp_int * const msg,
+                           hal_core_t *core1, /* ignored */
+                           const fp_int * const exp1,
+                           const fp_int * const mod1,
+                           fp_int       *       res1,
+                           uint8_t *coeff1, const size_t coeff1_len, /* ignored */
+                           uint8_t *mont1,  const size_t mont1_len, /* ignored */
+                           hal_core_t *core2, /* ignored */
+                           const fp_int * const exp2,
+                           const fp_int * const mod2,
+                           fp_int       *       res2,
+                           uint8_t *coeff2, const size_t coeff2_len, /* ignored */
+                           uint8_t *mont2,  const size_t mont2_len) /* ignored */
+{
+  hal_error_t err = HAL_OK;
+  FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp1), unconst_fp_int(mod1), res1));
+  FP_CHECK(fp_exptmod(unconst_fp_int(msg), unconst_fp_int(exp2), unconst_fp_int(mod2), res2));
+ fail:
+  return err;
+}
+
 #endif /* HAL_RSA_SIGN_USE_MODEXP */
 
 /*
@@ -351,7 +454,7 @@ static hal_error_t create_blinding_factors(hal_core_t *core, hal_rsa_key_t *key,
  * RSA decryption via Chinese Remainder Theorem (Garner's formula).
  */
 
-static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp_int *sig)
+static hal_error_t rsa_crt(hal_core_t *core1, hal_core_t *core2, hal_rsa_key_t *key, fp_int *msg, fp_int *sig)
 {
   if (key == NULL || msg == NULL || sig == NULL)
     return HAL_ERROR_IMPOSSIBLE;
@@ -368,7 +471,7 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp
    * Handle blinding if requested.
    */
   if (blinding) {
-    if ((err = create_blinding_factors(core, key, bf, ubf)) != HAL_OK)
+    if ((err = create_blinding_factors(core1, key, bf, ubf)) != HAL_OK)
       goto fail;
     FP_CHECK(fp_mulmod(msg, bf, unconst_fp_int(key->n), msg));
   }
@@ -376,14 +479,10 @@ static hal_error_t rsa_crt(hal_core_t *core, hal_rsa_key_t *key, fp_int *msg, fp
   /*
    * m1 = msg ** dP mod p
    * m2 = msg ** dQ mod q
-   *
-   * This is just crying out to be done with parallel cores, but get
-   * the boring version working before jumping off that cliff.
    */
-  if ((err = modexp(core, precalc, msg, key->dP, key->p, m1,
-                    key->pC, sizeof(key->pC), key->pF, sizeof(key->pF))) != HAL_OK ||
-      (err = modexp(core, precalc, msg, key->dQ, key->q, m2,
-                    key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK)
+  if ((err = modexp2(precalc, msg,
+                     core1, key->dP, key->p, m1, key->pC, sizeof(key->pC), key->pF, sizeof(key->pF),
+                     core2, key->dQ, key->q, m2, key->qC, sizeof(key->qC), key->qF, sizeof(key->qF))) != HAL_OK)
     goto fail;
 
   if (precalc)
@@ -462,7 +561,8 @@ hal_error_t hal_rsa_encrypt(hal_core_t *core,
   return err;
 }
 
-hal_error_t hal_rsa_decrypt(hal_core_t *core,
+hal_error_t hal_rsa_decrypt(hal_core_t *core1,
+                            hal_core_t *core2,
                             hal_rsa_key_t *key,
                             const uint8_t * const input,  const size_t input_len,
                             uint8_t * output, const size_t output_len)
@@ -484,11 +584,11 @@ hal_error_t hal_rsa_decrypt(hal_core_t *core,
 
   if (!fp_iszero(key->p) && !fp_iszero(key->q) && !fp_iszero(key->u) &&
       !fp_iszero(key->dP) && !fp_iszero(key->dQ))
-    err = rsa_crt(core, key, i, o);
+    err = rsa_crt(core1, core2, key, i, o);
 
   else {
     const int precalc = !(key->flags & RSA_FLAG_PRECALC_N_DONE);
-    err = modexp(core, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC),
+    err = modexp(core1, precalc, i, key->d, key->n, o, key->nC, sizeof(key->nC),
                  key->nF, sizeof(key->nF));
     if (err == HAL_OK && precalc)
       key->flags |= RSA_FLAG_PRECALC_N_DONE | RSA_FLAG_NEEDS_SAVING;
diff --git a/tests/test-rsa.c b/tests/test-rsa.c
index 9ba9889..e73feea 100644
--- a/tests/test-rsa.c
+++ b/tests/test-rsa.c
@@ -60,8 +60,17 @@ static int test_modexp(hal_core_t *core,
 
   printf("%s test for %lu-bit RSA key\n", kind, (unsigned long) tc->size);
 
-  if (hal_modexp(core, 0, msg->val, msg->len, exp->val, exp->len,
-                 tc->n.val, tc->n.len, result, sizeof(result), C, sizeof(C), F, sizeof(F)) != HAL_OK)
+  hal_modexp_arg_t args = {
+    .core   = core,
+    .msg    = msg->val, .msg_len = msg->len,
+    .exp    = exp->val, .exp_len = exp->len,
+    .mod    = tc->n.val, .mod_len = tc->n.len,
+    .result = result, .result_len = sizeof(result),
+    .coeff  = C, .coeff_len = sizeof(C),
+    .mont   = F, .mont_len = sizeof(F)
+  };
+
+  if (hal_modexp(1, &args) != HAL_OK)
     return printf("ModExp failed\n"), 0;
 
   if (memcmp(result, val->val, val->len))
@@ -98,7 +107,7 @@ static int test_decrypt(hal_core_t *core,
 
   uint8_t result[tc->n.len];
 
-  if ((err = hal_rsa_decrypt(core, key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
+  if ((err = hal_rsa_decrypt(core, NULL, key, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
     printf("RSA CRT failed: %s\n", hal_error_string(err));
 
   const int mismatch = (err == HAL_OK && memcmp(result, tc->s.val, tc->s.len) != 0);
@@ -165,7 +174,7 @@ static int test_gen(hal_core_t *core,
 
   uint8_t result[tc->n.len];
 
-  if ((err = hal_rsa_decrypt(core, key1, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
+  if ((err = hal_rsa_decrypt(core, NULL, key1, tc->m.val, tc->m.len, result, sizeof(result))) != HAL_OK)
     printf("RSA CRT failed: %s\n", hal_error_string(err));
 
   snprintf(fn, sizeof(fn), "test-rsa-sig-%04lu.der", (unsigned long) tc->size);



More information about the Commits mailing list