[Cryptech-Commits] [sw/libhal] 11/13: Reconstruct the hashsig hash tree(s) on device restart.

git at cryptech.is git at cryptech.is
Fri Apr 20 01:06:18 UTC 2018


This is an automated email from the git hooks/post-receive script.

paul at psgd.org pushed a commit to branch hashsig
in repository sw/libhal.

commit e5541de6f5e2831ebfc32c3afcfa35ff32341938
Author: Paul Selkirk <paul at psgd.org>
AuthorDate: Thu Apr 19 18:36:12 2018 -0400

    Reconstruct the hashsig hash tree(s) on device restart.
---
 hal_internal.h           |  10 +-
 hashsig.c                | 334 +++++++++++++++++++++++++++++++++++++++--------
 hashsig.h                |   4 +-
 tests/test-rpc_hashsig.c |  20 ++-
 4 files changed, 296 insertions(+), 72 deletions(-)

diff --git a/hal_internal.h b/hal_internal.h
index 4d812cc..95785ae 100644
--- a/hal_internal.h
+++ b/hal_internal.h
@@ -125,17 +125,17 @@ static inline hal_error_t hal_io_wait_valid2(const hal_core_t *core1, const hal_
 
 /*
  * Static memory allocation on start-up.  Don't use this except where
- * really necessary.  By design, there's no way to free this, we don't
- * want to have to manage a heap.  Intent is just to allow allocation
- * things like the large-ish ks_index arrays used by ks_flash.c from a
- * memory source external to the executable image file (eg, from the
- * secondary SDRAM chip on the Cryptech Alpha board).
+ * really necessary.  Intent is just to allow allocation of things like
+ * the large-ish ks_index arrays used by ks_flash.c from a memory source
+ * external to the executable image file (eg, from the secondary SDRAM
+ * chip on the Cryptech Alpha board).
  *
  * We shouldn't need this except on the HSM, so for now we don't bother
  * with implementing a version of this based on malloc() or sbrk().
  */
 
 extern void *hal_allocate_static_memory(const size_t size);
+extern hal_error_t hal_free_static_memory(const void * const ptr);
 
 /*
  * Longest hash block and digest we support at the moment.
diff --git a/hashsig.c b/hashsig.c
index 5ffbb12..0396ff7 100644
--- a/hashsig.c
+++ b/hashsig.c
@@ -185,7 +185,7 @@ static inline size_t lmots_signature_len(lmots_parameter_t * const lmots)
 
 #if RPC_CLIENT == RPC_CLIENT_LOCAL
 /* Given a key with most fields filled in, generate the lmots private and
- * public key components.
+ * public key components (x and K).
  * Let the caller worry about storage.
  */
 static hal_error_t lmots_generate(lmots_key_t * const key)
@@ -502,8 +502,8 @@ static hal_error_t lmots_private_key_to_der(const lmots_key_t * const key,
     if (key == NULL || key->type != HAL_KEY_TYPE_HASHSIG_LMOTS)
         return HAL_ERROR_BAD_ARGUMENTS;
 
-    // u32str(lmots_type) || I || u32str(q) || x[0] || x[1] || ... || x[p-1]
-    /* we also store K, to speed up restart */
+    // u32str(lmots_type) || I || u32str(q) || K || x[0] || x[1] || ... || x[p-1]
+    /* K is not an integral part of the private key, but we store it to speed up restart */
 
     /*
      * Calculate data length.
@@ -514,10 +514,10 @@ static hal_error_t lmots_private_key_to_der(const lmots_key_t * const key,
     check(hal_asn1_encode_lmots_algorithm(key->lmots->type, NULL, &len, 0)); vlen += len;
     check(hal_asn1_encode_bytestring16(&key->I, NULL, &len, 0));             vlen += len;
     check(hal_asn1_encode_size_t(key->q, NULL, &len, 0));                    vlen += len;
+    check(hal_asn1_encode_bytestring32(&key->K, NULL, &len, 0));             vlen += len;
     for (size_t i = 0; i < key->lmots->p; ++i) {
         check(hal_asn1_encode_bytestring32(&key->x[i], NULL, &len, 0));      vlen += len;
     }
-    check(hal_asn1_encode_bytestring32(&key->K, NULL, &len, 0));             vlen += len;
 
     check(hal_asn1_encode_header(ASN1_SEQUENCE, vlen, NULL, &hlen, 0));
 
@@ -539,10 +539,10 @@ static hal_error_t lmots_private_key_to_der(const lmots_key_t * const key,
     check(hal_asn1_encode_lmots_algorithm(key->lmots->type, d, &len, vlen)); d += len; vlen -= len;
     check(hal_asn1_encode_bytestring16(&key->I, d, &len, vlen));             d += len; vlen -= len;
     check(hal_asn1_encode_size_t(key->q, d, &len, vlen));                    d += len; vlen -= len;
+    check(hal_asn1_encode_bytestring32(&key->K, d, &len, vlen));             d += len; vlen -= len;
     for (size_t i = 0; i < key->lmots->p; ++i) {
         check(hal_asn1_encode_bytestring32(&key->x[i], d, &len, vlen));      d += len; vlen -= len;
     }
-    check(hal_asn1_encode_bytestring32(&key->K, d, &len, vlen));             d += len; vlen -= len;
 
     return hal_asn1_encode_pkcs8_privatekeyinfo(hal_asn1_oid_mts_hashsig, hal_asn1_oid_mts_hashsig_len,
                                                 NULL, 0, der, d - der, der, der_len, der_max);
@@ -580,20 +580,22 @@ static hal_error_t lmots_private_key_from_der(lmots_key_t *key,
     const uint8_t *d = privkey + hlen;
     size_t len;
 
-    // u32str(lmots_type) || I || u32str(q) || x[0] || x[1] || ... || x[p-1]
+    // u32str(lmots_type) || I || u32str(q) || K || x[0] || x[1] || ... || x[p-1]
 
     lmots_algorithm_t lmots_type;
     check(hal_asn1_decode_lmots_algorithm(&lmots_type, d, &len, vlen));  d += len; vlen -= len;
     key->lmots = lmots_select_parameter_set(lmots_type);
     check(hal_asn1_decode_bytestring16(&key->I, d, &len, vlen));         d += len; vlen -= len;
     check(hal_asn1_decode_size_t(&key->q, d, &len, vlen));               d += len; vlen -= len;
-    for (size_t i = 0; i < key->lmots->p; ++i) {
-        check(hal_asn1_decode_bytestring32(&key->x[i], d, &len, vlen));  d += len; vlen -= len;
-    }
     check(hal_asn1_decode_bytestring32(&key->K, d, &len, vlen));         d += len; vlen -= len;
+    if (key->x != NULL) {
+        for (size_t i = 0; i < key->lmots->p; ++i) {
+            check(hal_asn1_decode_bytestring32(&key->x[i], d, &len, vlen));  d += len; vlen -= len;
+        }
 
-    if (d != privkey + privkey_len)
-        return HAL_ERROR_ASN1_PARSE_FAILED;
+        if (d != privkey + privkey_len)
+            return HAL_ERROR_ASN1_PARSE_FAILED;
+    }
 
     return HAL_OK;
 }
@@ -677,7 +679,7 @@ static hal_error_t lms_generate(lms_key_t *key)
     hal_pkey_slot_t slot = {
         .type  = HAL_KEY_TYPE_HASHSIG_LMOTS,
         .curve = HAL_CURVE_NONE,
-        .flags = (key->level == 0) ? HAL_KEY_FLAG_TOKEN: 0
+        .flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE | ((key->level == 0) ? HAL_KEY_FLAG_TOKEN: 0)
     };
     hal_ks_t *ks = (key->level == 0) ? hal_ks_token : hal_ks_volatile;
 
@@ -745,10 +747,7 @@ static hal_error_t lms_generate(lms_key_t *key)
 
 static hal_error_t lms_delete(const lms_key_t * const key)
 {
-    hal_pkey_slot_t slot;
-    memset(&slot, 0, sizeof(slot));
-    slot.flags = (key->level == 0) ? HAL_KEY_FLAG_TOKEN: 0;
-
+    hal_pkey_slot_t slot = {0};
     hal_ks_t *ks = (key->level == 0) ? hal_ks_token : hal_ks_volatile;
 
     /* delete the lmots keys */
@@ -787,7 +786,6 @@ static hal_error_t lms_sign(lms_key_t * const key,
     /* fetch and decode the lmots signing key from the keystore */
     hal_pkey_slot_t slot;
     memset(&slot, 0, sizeof(slot));
-    slot.flags = (key->level == 0) ? HAL_KEY_FLAG_TOKEN : 0;
     memcpy(&slot.name, &key->lmots_keys[key->q], sizeof(slot.name));
 
     lmots_key_t lmots_key;
@@ -823,6 +821,8 @@ static hal_error_t lms_sign(lms_key_t * const key,
     /* update and store q before returning the signature */
     ++key->q;
     check(lms_private_key_to_der(key, der, &der_len, sizeof(der)));
+    slot.type = HAL_KEY_TYPE_HASHSIG_LMS;
+    slot.flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE | ((key->level == 0) ? HAL_KEY_FLAG_TOKEN : 0);
     memcpy(&slot.name, &key->I, sizeof(slot.name));
     check(hal_ks_rewrite_der(ks, &slot, der, der_len));
 
@@ -1069,10 +1069,7 @@ static size_t lms_private_key_to_der_len(const lms_key_t * const key)
     size_t len = 0;
     return lms_private_key_to_der(key, NULL, &len, 0) == HAL_OK ? len : 0;
 }
-#endif
 
-#if 0
-// used in restart - caller will have to allocate and attach storage for lmots_keys[] and T[]
 static hal_error_t lms_private_key_from_der(lms_key_t *key,
                                             const uint8_t *der, const size_t der_len)
 {
@@ -1132,6 +1129,7 @@ typedef struct hal_hashsig_key hss_key_t;
 struct hal_hashsig_key {
     hal_key_type_t type;
     hss_key_t *next;
+    hal_uuid_t name;
     size_t L;
     lms_parameter_t *lms;
     lmots_parameter_t *lmots;
@@ -1191,12 +1189,10 @@ static inline void *gnaw(uint8_t **mem, size_t *len, const size_t size)
     return ret;
 }
 
-/* called from pkey_local_generate_hashsig */
-hal_error_t hal_hashsig_key_gen(hal_core_t *core,
-                                hal_hashsig_key_t **key_,
-                                const size_t L,
-                                const lms_algorithm_t lms_type,
-                                const lmots_algorithm_t lmots_type)
+static hal_error_t hss_alloc(hal_hashsig_key_t **key_,
+                             const size_t L,
+                             const lms_algorithm_t lms_type,
+                             const lmots_algorithm_t lmots_type)
 {
     if (key_ == NULL)
         return HAL_ERROR_BAD_ARGUMENTS;
@@ -1221,19 +1217,11 @@ hal_error_t hal_hashsig_key_gen(hal_core_t *core,
     if (lmots_private_key_len(lmots) > HAL_KS_BLOCK_SIZE)
         return HAL_ERROR_UNSUPPORTED_KEY;
 
-    /* w=2 fails on the Alpha, as does w=4 with L=2, because the signature
-     * exceeds the meagre 4096-byte RPC packet size.
-     */
     if (hss_signature_len(L, lms, lmots) > HAL_RPC_MAX_PKT_SIZE)
         return HAL_ERROR_UNSUPPORTED_KEY;
 
-    /* check flash keystore for space to store the root tree */
-    size_t available;
-    check(hal_ks_available(hal_ks_token, &available));
-    if (available < h2 + 2)
-        return HAL_ERROR_NO_KEY_INDEX_SLOTS;
-
     /* check volatile keystore for space to store the lower-level trees */
+    size_t available;
     check(hal_ks_available(hal_ks_volatile, &available));
     if (available < (L - 1) * (h2 + 1))
         return HAL_ERROR_NO_KEY_INDEX_SLOTS;
@@ -1247,7 +1235,7 @@ hal_error_t hal_hashsig_key_gen(hal_core_t *core,
                   L * lms_sig_len +
                   L * lms_pub_len +
                   L * h2 * sizeof(hal_uuid_t) +
-                  L * (2 * h2 - 1) * sizeof(bytestring32));
+                  L * (2 * h2) * sizeof(bytestring32));
     uint8_t *mem = hal_allocate_static_memory(len);
     if (mem == NULL)
         return HAL_ERROR_ALLOCATION_FAILURE;
@@ -1255,6 +1243,7 @@ hal_error_t hal_hashsig_key_gen(hal_core_t *core,
 
     /* allocate the key that will stay in working memory */
     hss_key_t *key = gnaw(&mem, &len, sizeof(hss_key_t));
+    *key_ = key;
     key->type = HAL_KEY_TYPE_HASHSIG_PRIVATE;
     key->L = L;
     key->lms = lms;
@@ -1266,34 +1255,62 @@ hal_error_t hal_hashsig_key_gen(hal_core_t *core,
 
     /* allocate the list of lms trees */
     key->lms_keys = gnaw(&mem, &len, L * sizeof(lms_key_t));
-
-    /* generate the lms trees */
     for (size_t i = 0; i < L; ++i) {
+        /* XXX some of this is redundant to lms_private_key_from_der */
         lms_key_t * lms_key = &key->lms_keys[i];
         lms_key->type = HAL_KEY_TYPE_HASHSIG_LMS;
         lms_key->lms = lms;
         lms_key->lmots = lmots;
         lms_key->level = i;
         lms_key->lmots_keys = (hal_uuid_t *)gnaw(&mem, &len, h2 * sizeof(hal_uuid_t));
-        lms_key->T = gnaw(&mem, &len, (2 * h2 - 1) * sizeof(bytestring32));
+        lms_key->T = gnaw(&mem, &len, (2 * h2) * sizeof(bytestring32));
         lms_key->signature = gnaw(&mem, &len, lms_sig_len);
         lms_key->signature_len = lms_sig_len;
         lms_key->pubkey = gnaw(&mem, &len, lms_pub_len);
         lms_key->pubkey_len = lms_pub_len;
+    }
+
+    return HAL_OK;
+}
+
+/* called from pkey_local_generate_hashsig */
+hal_error_t hal_hashsig_key_gen(hal_core_t *core,
+                                hal_hashsig_key_t **key_,
+                                const size_t L,
+                                const lms_algorithm_t lms_type,
+                                const lmots_algorithm_t lmots_type)
+{
+    /* hss_alloc does most of the checks */
+
+    /* check flash keystore for space to store the root tree */
+    lms_parameter_t *lms = lms_select_parameter_set(lms_type);
+    if (lms == NULL)
+        return HAL_ERROR_BAD_ARGUMENTS;
+    size_t available;
+    check(hal_ks_available(hal_ks_token, &available));
+    if (available < (1U << lms->h) + 2)
+        return HAL_ERROR_NO_KEY_INDEX_SLOTS;
+
+    check(hss_alloc(key_, L, lms_type, lmots_type));
+    hss_key_t *key = *key_;
+
+    /* generate the lms trees */
+    for (size_t i = 0; i < L; ++i) {
+        lms_key_t * lms_key = &key->lms_keys[i];
 
         check(lms_generate(lms_key));
 
         if (i > 0)
             /* sign this tree with the previous */
             check(lms_sign(&key->lms_keys[i-1],
-                           (const uint8_t * const)lms_key->pubkey, lms_pub_len,
-                           lms_key->signature, NULL, lms_sig_len));
+                           (const uint8_t * const)lms_key->pubkey, lms_public_key_len(key->lms),
+                           lms_key->signature, NULL, lms_signature_len(key->lms, key->lmots)));
 
         /* store the lms key */
         hal_pkey_slot_t slot = {
             .type  = HAL_KEY_TYPE_HASHSIG_LMS,
             .curve = HAL_CURVE_NONE,
-            .flags = (i == 0) ? HAL_KEY_FLAG_TOKEN: 0
+            .flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE | ((i == 0) ? HAL_KEY_FLAG_TOKEN: 0)
         };
         hal_ks_t *ks = (i == 0) ? hal_ks_token : hal_ks_volatile;
         uint8_t der[lms_private_key_to_der_len(lms_key)];
@@ -1307,13 +1324,12 @@ hal_error_t hal_hashsig_key_gen(hal_core_t *core,
     memcpy(&key->I, &key->lms_keys[0].I, sizeof(key->I));
     memcpy(&key->T1, &key->lms_keys[0].T1, sizeof(key->T1));
 
-    *key_ = key;
-
     /* pkey_local_generate_hashsig stores the key */
 
     return HAL_OK;
 }
 
+/* caller will delete the hss key from the keystore */
 hal_error_t hal_hashsig_key_delete(const hal_hashsig_key_t * const key)
 {
     if (key == NULL || key->type != HAL_KEY_TYPE_HASHSIG_PRIVATE)
@@ -1324,6 +1340,7 @@ hal_error_t hal_hashsig_key_delete(const hal_hashsig_key_t * const key)
         check(lms_delete(&key->lms_keys[level]));
 
     /* XXX free memory, if supported */
+    (void)hal_free_static_memory(key);
 
     /* remove from global hss_keys linked list */
     /* XXX or mark it unused, for possible re-use */
@@ -1536,10 +1553,11 @@ hal_error_t hal_hashsig_private_key_to_der(const hal_hashsig_key_t * const key,
 
     size_t len, vlen = 0, hlen;
 
-    check(hal_asn1_encode_size_t(key->L, NULL, &len, 0));                          vlen += len;
-    check(hal_asn1_encode_lms_algorithm(key->lms->type, NULL, &len, 0));           vlen += len;
-    check(hal_asn1_encode_lmots_algorithm(key->lmots->type, NULL, &len, 0));       vlen += len;
-    check(hal_asn1_encode_uuid((hal_uuid_t *)&key->lms_keys[0].I, NULL, &len, 0)); vlen += len;
+    check(hal_asn1_encode_size_t(key->L, NULL, &len, 0));                    vlen += len;
+    check(hal_asn1_encode_lms_algorithm(key->lms->type, NULL, &len, 0));     vlen += len;
+    check(hal_asn1_encode_lmots_algorithm(key->lmots->type, NULL, &len, 0)); vlen += len;
+    check(hal_asn1_encode_bytestring16(&key->I, NULL, &len, 0));             vlen += len;
+    check(hal_asn1_encode_bytestring32(&key->T1, NULL, &len, 0));            vlen += len;
 
     check(hal_asn1_encode_header(ASN1_SEQUENCE, vlen, NULL, &hlen, 0));
 
@@ -1558,10 +1576,11 @@ hal_error_t hal_hashsig_private_key_to_der(const hal_hashsig_key_t * const key,
     uint8_t *d = der + hlen;
     memset(d, 0, vlen);
 
-    check(hal_asn1_encode_size_t(key->L, d, &len, vlen));                          d += len; vlen -= len;
-    check(hal_asn1_encode_lms_algorithm(key->lms->type, d, &len, vlen));           d += len; vlen -= len;
-    check(hal_asn1_encode_lmots_algorithm(key->lmots->type, d, &len, vlen));       d += len; vlen -= len;
-    check(hal_asn1_encode_uuid((hal_uuid_t *)&key->lms_keys[0].I, d, &len, vlen)); d += len; vlen -= len;
+    check(hal_asn1_encode_size_t(key->L, d, &len, vlen));                    d += len; vlen -= len;
+    check(hal_asn1_encode_lms_algorithm(key->lms->type, d, &len, vlen));     d += len; vlen -= len;
+    check(hal_asn1_encode_lmots_algorithm(key->lmots->type, d, &len, vlen)); d += len; vlen -= len;
+    check(hal_asn1_encode_bytestring16(&key->I, d, &len, vlen));             d += len; vlen -= len;
+    check(hal_asn1_encode_bytestring32(&key->T1, d, &len, vlen));            d += len; vlen -= len;
 
     return hal_asn1_encode_pkcs8_privatekeyinfo(hal_asn1_oid_mts_hashsig, hal_asn1_oid_mts_hashsig_len,
                                                 NULL, 0, der, d - der, der, der_len, der_max);
@@ -1582,7 +1601,7 @@ hal_error_t hal_hashsig_private_key_from_der(hal_hashsig_key_t **key_,
 
     memset(keybuf, 0, keybuf_len);
 
-    hss_key_t *key = keybuf;
+    hss_key_t *key = *key_ = keybuf;
 
     key->type = HAL_KEY_TYPE_HASHSIG_PRIVATE;
 
@@ -1614,8 +1633,8 @@ hal_error_t hal_hashsig_private_key_from_der(hal_hashsig_key_t **key_,
     lmots_algorithm_t lmots_type;
     check(hal_asn1_decode_lmots_algorithm(&lmots_type, d, &n, vlen)); d += n; vlen -= n;
     key->lmots = lmots_select_parameter_set(lmots_type);
-    hal_uuid_t I;
-    check(hal_asn1_decode_uuid(&I, d, &n, vlen));                     d += n; vlen -= n;
+    check(hal_asn1_decode_bytestring16(&key->I, d, &n, vlen));        d += n; vlen -= n;
+    check(hal_asn1_decode_bytestring32(&key->T1, d, &n, vlen));       d += n; vlen -= n;
 
     if (d != privkey + privkey_len)
         return HAL_ERROR_ASN1_PARSE_FAILED;
@@ -1626,12 +1645,12 @@ hal_error_t hal_hashsig_private_key_from_der(hal_hashsig_key_t **key_,
      * and not molest ours.)
      */
     for (hss_key_t *hss_key = hss_keys; hss_key != NULL; hss_key = hss_key->next) {
-        if (hal_uuid_cmp(&I, (hal_uuid_t *)&hss_key->lms_keys[0].I) == 0) {
+        if (memcmp(&key->I, &hss_key->lms_keys[0].I, sizeof(key->I)) == 0) {
             *key_ = hss_key;
-            return HAL_OK;
         }
     }
-    return HAL_ERROR_KEY_NOT_FOUND;     // or IMPOSSIBLE?
+
+    return HAL_OK;
 }
 
 hal_error_t hal_hashsig_public_key_to_der(const hal_hashsig_key_t * const key,
@@ -1838,3 +1857,202 @@ hal_error_t hal_hashsig_public_key_der_to_xdr(const uint8_t * const der, const s
 
     return HAL_OK;
 }
+
+#if RPC_CLIENT == RPC_CLIENT_LOCAL
+/* Reinitialize the hashsig key structures after a device restart */
+hal_error_t hal_hashsig_ks_init(void)
+{
+    const hal_client_handle_t  client  = { -1 };
+    const hal_session_handle_t session = { HAL_HANDLE_NONE };
+    hal_uuid_t prev_name = {{0}};
+    unsigned len;
+    hal_pkey_slot_t slot = {0};
+    uint8_t der[HAL_KS_WRAPPED_KEYSIZE];
+    size_t der_len;
+
+    /* Find all hss private keys */
+    while ((hal_ks_match(hal_ks_token, client, session,
+                         HAL_KEY_TYPE_HASHSIG_PRIVATE, HAL_CURVE_NONE, 0, 0, NULL, 0,
+                         &slot.name, &len, 1, &prev_name) == HAL_OK) &&  (len > 0)) {
+        hal_hashsig_key_t keybuf, *key;
+        if (hal_ks_fetch(hal_ks_token, &slot, der, &der_len, sizeof(der)) != HAL_OK ||
+            hal_hashsig_private_key_from_der(&key, (void *)&keybuf, sizeof(keybuf), der, der_len) != HAL_OK) {
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            continue;
+        }
+
+        /* Make sure we have the lms key */
+        hal_pkey_slot_t lms_slot = {0};
+        lms_key_t lms_key;
+        memcpy(&lms_slot.name, &key->I, sizeof(lms_slot.name));
+        if (hal_ks_fetch(hal_ks_token, &lms_slot, der, &der_len, sizeof(der)) != HAL_OK ||
+            lms_private_key_from_der(&lms_key, der, der_len) != HAL_OK ||
+            /* check keys for consistency */
+            lms_key.lms != key->lms ||
+            lms_key.lmots != key->lmots ||
+            memcmp(&lms_key.I, &key->I, sizeof(lms_key.I)) != 0 ||
+            /* optimistically allocate the full hss key structure */
+            hss_alloc(&key, key->L, key->lms->type, key->lmots->type) != HAL_OK) {
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            (void)hal_ks_delete(hal_ks_token, &lms_slot);
+            continue;
+        }
+
+        /* hss_alloc redefines key, so copy fields from the old version of the key */
+        memcpy(&key->I, &keybuf.I, sizeof(key->I));
+        memcpy(&key->T1, &keybuf.T1, sizeof(key->T1));
+        key->name = slot.name;
+
+        /* initialize top-level lms key (beyond what hss_alloc did) */
+        memcpy(&key->lms_keys[0].I, &lms_key.I, sizeof(lms_key.I));
+        key->lms_keys[0].q = lms_key.q;
+
+        prev_name = slot.name;
+    }
+
+    /* Delete orphaned lms keys */
+    memset(&prev_name, 0, sizeof(prev_name));
+    while ((hal_ks_match(hal_ks_token, client, session,
+                         HAL_KEY_TYPE_HASHSIG_LMS, HAL_CURVE_NONE, 0, 0, NULL, 0,
+                         &slot.name, &len, 1, &prev_name) == HAL_OK) && (len > 0)) {
+        hss_key_t *hss_key;
+        for (hss_key = hss_keys; hss_key != NULL; hss_key = hss_key->next) {
+            if (memcmp(&slot.name, &hss_key->I, sizeof(slot.name)) == 0)
+                break;
+        }
+        if (hss_key == NULL) {
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            continue;
+        }
+
+        prev_name = slot.name;
+    }
+
+    /* Find all lmots keys */
+    memset(&prev_name, 0, sizeof(prev_name));
+    while ((hal_ks_match(hal_ks_token, client, session,
+                         HAL_KEY_TYPE_HASHSIG_LMOTS, HAL_CURVE_NONE, 0, 0, NULL, 0,
+                         &slot.name, &len, 1, &prev_name) == HAL_OK) && (len > 0)) {
+        if (hss_keys == NULL) {
+            /* if no hss keys were recovered, all lmots keys are orphaned */
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            continue;
+        }
+
+        lmots_key_t lmots_key = {0};
+        if (hal_ks_fetch(hal_ks_token, &slot, der, &der_len, sizeof(der)) != HAL_OK ||
+            lmots_private_key_from_der(&lmots_key, der, der_len) != HAL_OK) {
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            continue;
+        }
+
+        hss_key_t *hss_key;
+        for (hss_key = hss_keys; hss_key != NULL; hss_key = hss_key->next) {
+            if (memcmp(&hss_key->I, &lmots_key.I, sizeof(lmots_key.I)) == 0)
+                break;
+        }
+        if (hss_key == NULL) {
+            /* delete orphaned key */
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            continue;
+        }
+
+        /* record this lmots key in the top-level lms key */
+        memcpy(&hss_key->lms_keys[0].lmots_keys[lmots_key.q], &slot.name, sizeof(slot.name));
+
+        /* compute T[r] = H(I || u32str(r) || u16str(D_LEAF) || K) */
+        size_t r = (1U << hss_key->lms->h) + lmots_key.q;
+        uint8_t statebuf[512];
+        hal_hash_state_t *state = NULL;
+        hal_hash_initialize(NULL, hal_hash_sha256, &state, statebuf, sizeof(statebuf));
+        hal_hash_update(state, (const uint8_t *)&hss_key->I, sizeof(hss_key->I));
+        uint32_t l = u32str(r); hal_hash_update(state, (const uint8_t *)&l, sizeof(l));
+        uint16_t s = u16str(D_LEAF); hal_hash_update(state, (const uint8_t *)&s, sizeof(s));
+        hal_hash_update(state, (const uint8_t *)&lmots_key.K, sizeof(lmots_key.K));
+        hal_hash_finalize(state, (uint8_t *)&hss_key->lms_keys[0].T[r], sizeof(hss_key->lms_keys[0].T[r]));
+
+        prev_name = slot.name;
+    }
+
+    /* After all keys have been read, scan for completeness. */
+    hal_uuid_t uuid_0 = {{0}};
+    hss_key_t *hss_key, *hss_next = NULL;
+    for (hss_key = hss_keys; hss_key != NULL; hss_key = hss_next) {
+        hss_next = hss_key->next;
+        int fail = 0;
+        for (size_t i = 0; i < (1U << hss_key->lms->h); ++i) {
+            if (hal_uuid_cmp(&hss_key->lms_keys[0].lmots_keys[i], &uuid_0) == 0) {
+                fail = 1;
+                break;
+            }
+        }
+        if (fail) {
+        fail:
+            /* lms key is incomplete, give up on it */
+            /* delete lmots keys */
+            for (size_t i = 0; i < (1U << hss_key->lms->h); ++i) {
+                if (hal_uuid_cmp(&hss_key->lms_keys[0].lmots_keys[i], &uuid_0) != 0) {
+                    memcpy(&slot.name, &hss_key->lms_keys[0].lmots_keys[i], sizeof(slot.name));
+                    (void)hal_ks_delete(hal_ks_token, &slot);
+                }
+            }
+            /* delete lms key */
+            memcpy(&slot.name, &hss_key->I, sizeof(slot.name));
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            /* delete hss key */
+            slot.name = hss_key->name;
+            (void)hal_ks_delete(hal_ks_token, &slot);
+            /* remove the hss key from the key list */
+            if (hss_keys == hss_key) {
+                hss_keys = hss_key->next;
+            }
+            else {
+                for (hss_key_t *prev = hss_keys; prev != NULL; prev = prev->next) {
+                    if (prev->next == hss_key) {
+                        prev->next = hss_key->next;
+                        break;
+                    }
+                }
+            }
+            (void)hal_free_static_memory(hss_key);
+            continue;
+        }
+
+        /* generate the rest of T[] */
+        for (size_t r = (1U << hss_key->lms->h) - 1; r > 0; --r) {
+            uint8_t statebuf[512];
+            hal_hash_state_t *state = NULL;
+            hal_hash_initialize(NULL, hal_hash_sha256, &state, statebuf, sizeof(statebuf));
+            hal_hash_update(state, (const uint8_t *)&hss_key->I, sizeof(hss_key->I));
+            uint32_t l = u32str(r); hal_hash_update(state, (const uint8_t *)&l, sizeof(l));
+            uint16_t s = u16str(D_INTR); check(hal_hash_update(state, (const uint8_t *)&s, sizeof(s)));
+            hal_hash_update(state, (const uint8_t *)&hss_key->lms_keys[0].T[2*r], sizeof(hss_key->lms_keys[0].T[r]));
+            hal_hash_update(state, (const uint8_t *)&hss_key->lms_keys[0].T[2*r+1], sizeof(hss_key->lms_keys[0].T[r]));
+            hal_hash_finalize(state, (uint8_t *)&hss_key->lms_keys[0].T[r], sizeof(hss_key->lms_keys[0].T[r]));
+        }
+        if (memcmp(&hss_key->lms_keys[0].T[1], &hss_key->T1, sizeof(hss_key->lms_keys[0].T[1])) != 0)
+            goto fail;
+
+        /* generate the lower-level lms keys */
+        for (size_t i = 1; i < hss_key->L; ++i) {
+            lms_key_t * lms_key = &hss_key->lms_keys[i];
+            if (lms_generate(lms_key) != HAL_OK)
+                goto fail;
+
+            /* store the lms key */
+            slot.type  = HAL_KEY_TYPE_HASHSIG_LMS;
+            slot.flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE;
+            memcpy(&slot.name, &lms_key->I, sizeof(slot.name));
+            if (lms_private_key_to_der(lms_key, der, &der_len, sizeof(der)) != HAL_OK ||
+                hal_ks_store(hal_ks_volatile, &slot, der, der_len) != HAL_OK ||
+                /* sign this lms key with the previous */
+                lms_sign(&hss_key->lms_keys[i-1],
+                         (const uint8_t * const)lms_key->pubkey, lms_key->pubkey_len,
+                         lms_key->signature, NULL, lms_key->signature_len) != HAL_OK)
+                goto fail;
+        }
+    }
+
+    return HAL_OK;
+}
+#endif
diff --git a/hashsig.h b/hashsig.h
index 7bae86e..3753496 100644
--- a/hashsig.h
+++ b/hashsig.h
@@ -1,5 +1,5 @@
 /*
- * hashsig.c
+ * hashsig.h
  * ---------
  * Implementation of draft-mcgrew-hash-sigs-08.txt
  *
@@ -113,6 +113,6 @@ extern size_t hal_hashsig_lmots_private_key_len(const lmots_algorithm_t lmots_ty
 extern hal_error_t hal_hashsig_public_key_der_to_xdr(const uint8_t * const der, const size_t der_len,
                                                      uint8_t * const xdr, size_t * const xdr_len , const size_t xdr_max);
 
-//extern hal_error_t hal_hashsig_restart(...);
+extern hal_error_t hal_hashsig_ks_init(void);
 
 #endif /* _HAL_HASHSIG_H_ */
diff --git a/tests/test-rpc_hashsig.c b/tests/test-rpc_hashsig.c
index b93f11e..00728c3 100644
--- a/tests/test-rpc_hashsig.c
+++ b/tests/test-rpc_hashsig.c
@@ -264,7 +264,7 @@ static int test_hashsig_sign(const size_t L,
                              const lms_algorithm_t lms_type,
                              const lmots_algorithm_t lmots_type,
                              size_t iterations,
-                             int save)
+                             int save, int keep)
 {
     const hal_client_handle_t client = {HAL_HANDLE_NONE};
     const hal_session_handle_t session = {HAL_HANDLE_NONE};
@@ -287,7 +287,7 @@ static int test_hashsig_sign(const size_t L,
                 lose("Error closing %s: %s\n", save_name, strerror(errno));
         }
 
-        hal_key_flags_t flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE;
+        hal_key_flags_t flags = HAL_KEY_FLAG_USAGE_DIGITALSIGNATURE | HAL_KEY_FLAG_TOKEN;
 
         printf("Starting hashsig key test: L %lu, lms type %u (h=%lu), lmots type %u (w=%lu)\n",
                L, lms_type, lms_type_to_h(lms_type), lmots_type, lmots_type_to_w(lmots_type));
@@ -399,8 +399,10 @@ static int test_hashsig_sign(const size_t L,
             }
         }
 
-        if ((err = hal_rpc_pkey_delete(private_key)) != HAL_OK)
-            lose("Could not delete private key: %s\n", hal_error_string(err));
+        if (!keep) {
+            if ((err = hal_rpc_pkey_delete(private_key)) != HAL_OK)
+                lose("Could not delete private key: %s\n", hal_error_string(err));
+        }
 
         if ((err = hal_rpc_pkey_delete(public_key)) != HAL_OK)
             lose("Could not delete public key: %s\n", hal_error_string(err));
@@ -460,7 +462,7 @@ int main(int argc, char *argv[])
     size_t L_lo = 0, L_hi = 0;
     size_t lms_lo = 5, lms_hi = 0;
     size_t lmots_lo = 3, lmots_hi = 0;
-    int save = 0;
+    int save = 0, keep = 0;
     char *p;
     hal_error_t err;
     int ok = 1;
@@ -476,11 +478,12 @@ Usage: %s [-d] [-i] [-p pin] [-t] [-L n] [-l n] [-o n] [-n n] [-s] [-r file]\n\
        -o: LM-OTS type (1..4)\n\
        -n: number of signatures to generate (0..'max')\n\
        -s: save generated public key and signatures\n\
+       -k: keep (don't delete) the generated keys on the hsm\n\
        -r: read and pretty-print a saved signature file\n\
 Numeric arguments can be a single number or a range, e.g. '1..4'\n";
 
     int opt;
-    while ((opt = getopt(argc, argv, "ditp:L:l:o:n:sr:h?")) != -1) {
+    while ((opt = getopt(argc, argv, "ditp:L:l:o:n:skr:h?")) != -1) {
         switch (opt) {
         case 'd':
             debug = 1;
@@ -526,6 +529,9 @@ Numeric arguments can be a single number or a range, e.g. '1..4'\n";
         case's':
             save = 1;
             break;
+        case 'k':
+            keep = 1;
+            break;
         case 'r':
             ok &= read_sig(optarg);
             do_default = 0;
@@ -572,7 +578,7 @@ Numeric arguments can be a single number or a range, e.g. '1..4'\n";
         for (size_t L = L_lo; L <= L_hi; ++L) {
             for (lms_algorithm_t lms_type = lms_lo; lms_type <= lms_hi; ++lms_type) {
                 for (lmots_algorithm_t lmots_type = lmots_lo; lmots_type <= lmots_hi; ++lmots_type) {
-                    ok &= test_hashsig_sign(L, lms_type, lmots_type, iterations, save);
+                    ok &= test_hashsig_sign(L, lms_type, lmots_type, iterations, save, keep);
                 }
             }
         }



More information about the Commits mailing list