[Cryptech-Commits] [sw/libhal] branch ksng updated: Better enum handling, more readable RPC methods.

git at cryptech.is git at cryptech.is
Fri Oct 21 04:52:08 UTC 2016


This is an automated email from the git hooks/post-receive script.

sra at hactrn.net pushed a commit to branch ksng
in repository sw/libhal.

The following commit(s) were added to refs/heads/ksng by this push:
       new  bbb84e2   Better enum handling, more readable RPC methods.
bbb84e2 is described below

commit bbb84e218f971d9dd134e85557951b36146c017a
Author: Rob Austein <sra at hactrn.net>
AuthorDate: Fri Oct 21 00:44:46 2016 -0400

    Better enum handling, more readable RPC methods.
    
    Using a context manager allows us to write the individual RPC methods
    fairly legibly, while still enforcing xdrlib.Unpacker.done() logic.
    
    Python doesn't really have enums in the sense that C does, and many
    people have put entirely too much skull sweat into trying to invent
    the Most Pythonic reimplementation of the enum concept, but an int
    subclass with a few extra methods is close enough for our purposes.
---
 libhal.py | 306 +++++++++++++++++++++++++++++++++-----------------------------
 1 file changed, 163 insertions(+), 143 deletions(-)

diff --git a/libhal.py b/libhal.py
index 0924863..5e5832b 100644
--- a/libhal.py
+++ b/libhal.py
@@ -44,6 +44,7 @@ import time
 import uuid
 import xdrlib
 import serial
+import contextlib
 
 SLIP_END     = chr(0300)        # indicates end of packet
 SLIP_ESC     = chr(0333)        # indicates byte stuffing
@@ -108,11 +109,32 @@ HALError.define(HAL_ERROR_ATTRIBUTE_NOT_FOUND       = "Attribute not found")
 HALError.define(HAL_ERROR_NO_KEY_INDEX_SLOTS        = "No key index slots available")
 
 
-def def_enum(text):
-    for i, name in enumerate(text.translate(None, ",").split()):
-        globals()[name] = i
+class Enum(int):
 
-def_enum('''
+    def __new__(cls, name, value):
+        self = int.__new__(cls, value)
+        self._name = name
+        setattr(self.__class__, name, self)
+        return self
+
+    def __str__(self):
+        return self._name
+
+    def __repr__(self):
+        return "<Enum:{0.__class__.__name__} {0._name}:{0:d}>".format(self)
+
+    @classmethod
+    def define(cls, names):
+        cls.index = tuple(cls(name, i) for i, name in enumerate(names.translate(None, ",").split()))
+        globals().update((symbol._name, symbol) for symbol in cls.index)
+
+    def xdr_packer(self, packer):
+        packer.pack_uint(self)
+
+
+class RPCFunc(Enum): pass
+
+RPCFunc.define('''
     RPC_FUNC_GET_VERSION,
     RPC_FUNC_GET_RANDOM,
     RPC_FUNC_SET_PIN,
@@ -146,7 +168,9 @@ def_enum('''
     RPC_FUNC_PKEY_DELETE_ATTRIBUTE,
 ''')
 
-def_enum('''
+class HALDigestAlgorithm(Enum): pass
+
+HALDigestAlgorithm.define('''
     hal_digest_algorithm_none,
     hal_digest_algorithm_sha1,
     hal_digest_algorithm_sha224,
@@ -157,7 +181,9 @@ def_enum('''
     hal_digest_algorithm_sha512
 ''')
 
-def_enum('''
+class HALKeyType(Enum): pass
+
+HALKeyType.define('''
     HAL_KEY_TYPE_NONE,
     HAL_KEY_TYPE_RSA_PRIVATE,
     HAL_KEY_TYPE_RSA_PUBLIC,
@@ -165,14 +191,18 @@ def_enum('''
     HAL_KEY_TYPE_EC_PUBLIC
 ''')
 
-def_enum('''
+class HALCurve(Enum): pass
+
+HALCurve.define('''
     HAL_CURVE_NONE,
     HAL_CURVE_P256,
     HAL_CURVE_P384,
     HAL_CURVE_P521
 ''')
 
-def_enum('''
+class HALUser(Enum): pass
+
+HALUser.define('''
     HAL_USER_NONE,
     HAL_USER_NORMAL,
     HAL_USER_SO,
@@ -196,6 +226,12 @@ class Attribute(object):
         packer.pack_bytes(self.value)
 
 
+class UUID(uuid.UUID):
+
+    def xdr_packer(self, packer):
+        packer.pack_bytes(self.bytes)
+
+
 def cached_property(func):
 
     attr_name = "_" + func.__name__
@@ -221,6 +257,9 @@ class Handle(object):
     def __cmp__(self, other):
         return cmp(self.handle, int(other))
 
+    def xdr_packer(self, packer):
+        packer.pack_uint(self.handle)
+
 
 class Digest(Handle):
 
@@ -279,20 +318,22 @@ class PKey(Handle):
     def verify(self, hash = 0, data = "", signature = None):
         self.hsm.pkey_verify(self, hash, data, signature)
 
-    def set_attribute(self, type, value):
-        self.hsm.pkey_set_attribute(self, type, value)
+    def set_attribute(self, attr_type, attr_value = None):
+        self.hsm.pkey_set_attribute(self, attr_type, attr_value)
 
-    def get_attribute(self, type):
-        return self.hsm.pkey_get_attribute(self, type)
+    def get_attribute(self, attr_type):
+        return self.hsm.pkey_get_attribute(self, attr_type)
 
-    def delete_attribute(self, type):
-        self.hsm.pkey_delete_attribute(self, type)
+    def delete_attribute(self, attr_type):
+        self.hsm.pkey_delete_attribute(self, attr_type)
 
 
 class HSM(object):
 
     debug = False
 
+    _send_delay = 0             # 0.1
+
     def _raise_if_error(self, status):
         if status != 0:
             raise HALError.table[status]()
@@ -300,7 +341,7 @@ class HSM(object):
     def __init__(self, device = os.getenv("CRYPTECH_RPC_CLIENT_SERIAL_DEVICE", "/dev/ttyUSB0")):
         while True:
             try:
-                self.tty = serial.Serial(device, 921600, timeout=0.1)
+                self.tty = serial.Serial(device, 921600, timeout = 0.1)
                 break
             except serial.SerialException:
                 time.sleep(0.2)
@@ -309,7 +350,8 @@ class HSM(object):
         if self.debug:
             sys.stdout.write("{:02x}".format(ord(c)))
         self.tty.write(c)
-        time.sleep(0.1)
+        if self._send_delay > 0:
+            time.sleep(self._send_delay)
 
     def _send(self, msg):       # Expects an xdrlib.Packer
         if self.debug:
@@ -365,18 +407,30 @@ class HSM(object):
         for arg in args:
             if hasattr(arg, "xdr_packer"):
                 arg.xdr_packer(packer)
-            elif isinstance(arg, (int, long, Handle)):
-                packer.pack_uint(arg)
-            elif isinstance(arg, str):
-                packer.pack_bytes(arg)
-            elif isinstance(arg, uuid.UUID):
-                packer.pack_bytes(arg.bytes)
-            elif isinstance(arg, (list, tuple)):
-                packer.pack_uint(len(arg))
-                self._pack(packer, arg)
             else:
-                raise RuntimeError("Don't know how to pack {!r} ({!r})".format(arg, type(arg)))
+                try:
+                    func = getattr(self, "_pack_" + type(arg).__name__)
+                except AttributeError:
+                    raise RuntimeError("Don't know how to pack {!r} ({!r})".format(arg, type(arg)))
+                else:
+                    func(packer, arg)
+
+    @staticmethod
+    def _pack_int(packer, arg):
+        packer.pack_uint(arg)
+
+    @staticmethod
+    def _pack_str(packer, arg):
+        packer.pack_bytes(arg)
+
+    def _pack_tuple(self, packer, arg):
+        packer.pack_uint(len(arg))
+        self._pack(packer, arg)
 
+    _pack_long = _pack_int
+    _pack_list = _pack_tuple
+
+    @contextlib.contextmanager
     def rpc(self, code, *args, **kwargs):
         client = kwargs.get("client", 0)
         packer = xdrlib.Packer()
@@ -387,179 +441,144 @@ class HSM(object):
         unpacker = self._recv(code)
         client = unpacker.unpack_uint()
         self._raise_if_error(unpacker.unpack_uint())
-        return unpacker
+        yield unpacker
+        unpacker.done()
 
     def get_version(self):
-        u = self.rpc(RPC_FUNC_GET_VERSION)
-        r = u.unpack_uint()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_GET_VERSION) as r:
+            return r.unpack_uint()
 
     def get_random(self, n):
-        u = self.rpc(RPC_FUNC_GET_RANDOM, n)
-        r = u.unpack_bytes()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_GET_RANDOM, n) as r:
+            return r.unpack_bytes()
 
     def set_pin(self, user, pin):
-        u = self.rpc(RPC_FUNC_SET_PIN, user, pin)
-        u.done()
+        with self.rpc(RPC_FUNC_SET_PIN, user, pin):
+            return
 
     def login(self, user, pin):
-        u = self.rpc(RPC_FUNC_LOGIN, user, pin)
-        u.done()
+        with self.rpc(RPC_FUNC_LOGIN, user, pin):
+            return
 
     def logout(self):
-        u = self.rpc(RPC_FUNC_LOGOUT)
-        u.done()
+        with self.rpc(RPC_FUNC_LOGOUT):
+            return
 
     def logout_all(self):
-        u = self.rpc(RPC_FUNC_LOGOUT_ALL)
-        u.done()
+        with self.rpc(RPC_FUNC_LOGOUT_ALL):
+            return
 
     def is_logged_in(self, user):
-        u = self.rpc(RPC_FUNC_IS_LOGGED_IN, user)
-        r = u.unpack_bool()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_IS_LOGGED_IN, user) as r:
+            return r.unpack_bool()
 
     def hash_get_digest_length(self, alg):
-        u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_LEN, alg)
-        r = u.unpack_uint()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_HASH_GET_DIGEST_LEN, alg) as r:
+            return r.unpack_uint()
 
     def hash_get_digest_algorithm_id(self, alg, max_len = 256):
-        u = self.rpc(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len)
-        r = u.unpack_bytes()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_HASH_GET_DIGEST_ALGORITHM_ID, alg, max_len) as r:
+            return r.unpack_bytes()
 
     def hash_get_algorithm(self, handle):
-        u = self.rpc(RPC_FUNC_HASH_GET_ALGORITHM, handle)
-        r = u.unpack_uint()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_HASH_GET_ALGORITHM, handle) as r:
+            return HALDigestAlgorithm.index[r.unpack_uint()]
 
     def hash_initialize(self, alg, key = "", client = 0, session = 0):
-        u = self.rpc(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client)
-        r = Digest(self, u.unpack_uint(), alg)
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_HASH_INITIALIZE, session, alg, key, client = client) as r:
+            return Digest(self, r.unpack_uint(), alg)
 
     def hash_update(self, handle, data):
-        u = self.rpc(RPC_FUNC_HASH_UPDATE, handle, data)
-        u.done()
+        with self.rpc(RPC_FUNC_HASH_UPDATE, handle, data):
+            return
 
     def hash_finalize(self, handle, length = None):
         if length is None:
             length = self.hash_get_digest_length(self.hash_get_algorithm(handle))
-        u = self.rpc(RPC_FUNC_HASH_FINALIZE, handle, length)
-        r = u.unpack_bytes()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_HASH_FINALIZE, handle, length) as r:
+            return r.unpack_bytes()
 
     def pkey_load(self, type, curve, der, flags = 0, client = 0, session = 0):
-        u = self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client)
-        r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()))
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_LOAD, session, type, curve, der, flags, client = client) as r:
+            return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes()))
 
     def pkey_find(self, uuid, flags = 0, client = 0, session = 0):
-        u = self.rpc(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client)
-        r = PKey(self, u.unpack_uint(), uuid)
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_FIND, session, uuid, flags, client = client) as r:
+            return PKey(self, r.unpack_uint(), uuid)
 
     def pkey_generate_rsa(self, keylen, exponent, flags = 0, client = 0, session = 0):
-        u = self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client)
-        r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()))
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_GENERATE_RSA, session, keylen, exponent, flags, client = client) as r:
+            return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes()))
 
     def pkey_generate_ec(self, curve, flags = 0, client = 0, session = 0):
-        u = self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client)
-        r = PKey(self, u.unpack_uint(), uuid.UUID(bytes = u.unpack_bytes()))
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_GENERATE_EC, session, curve, flags, client = client) as r:
+            return PKey(self, r.unpack_uint(), UUID(bytes = r.unpack_bytes()))
 
     def pkey_close(self, pkey):
-        u = self.rpc(RPC_FUNC_PKEY_CLOSE, pkey)
-        u.done()
+        with self.rpc(RPC_FUNC_PKEY_CLOSE, pkey):
+            return
 
     def pkey_delete(self, pkey):
-        u = self.rpc(RPC_FUNC_PKEY_DELETE, pkey)
-        u.done()
+        with self.rpc(RPC_FUNC_PKEY_DELETE, pkey):
+            return
 
     def pkey_get_key_type(self, pkey):
-        u = self.rpc(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey)
-        r = u.unpack_uint()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_GET_KEY_TYPE, pkey) as r:
+            return HALKeyType.index[r.unpack_uint()]
 
     def pkey_get_key_flags(self, pkey):
-        u = self.rpc(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey)
-        r = u.unpack_uint()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_GET_KEY_FLAGS, pkey) as r:
+            return r.unpack_uint()
 
     def pkey_get_public_key_len(self, pkey):
-        u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey)
-        r = u.unpack_uint()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY_LEN, pkey) as r:
+            return r.unpack_uint()
 
     def pkey_get_public_key(self, pkey, length = None):
         if length is None:
             length = self.pkey_get_public_key_len(pkey)
-        u = self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length)
-        r = u.unpack_bytes()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_GET_PUBLIC_KEY, pkey, length) as r:
+            return r.unpack_bytes()
 
     def pkey_sign(self, pkey, hash = 0, data = "", length = 1024):
-        u = self.rpc(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length)
-        r = u.unpack_bytes()
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_SIGN, pkey, hash, data, length) as r:
+            return r.unpack_bytes()
 
     def pkey_verify(self, pkey, hash = 0, data = "", signature = None):
-        u = self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature)
-        u.done()
+        with self.rpc(RPC_FUNC_PKEY_VERIFY, pkey, hash, data, signature):
+            return
 
     def pkey_list(self, flags = 0, client = 0, session = 0, length = 512):
-        u = self.rpc(RPC_FUNC_PKEY_LIST, session, length, flags, client = client)
-        r = tuple((u.unpack_uint(), u.unpack_uint(), u.unpack_uint(),
-                   uuid.UUID(bytes = u.unpack_bytes()))
-                  for i in xrange(u.unpack_uint()))
-        u.done()
-        return r
+        with self.rpc(RPC_FUNC_PKEY_LIST, session, length, flags, client = client) as r:
+            return tuple((HALKeyType.index[r.unpack_uint()],
+                          HALCurve.index[r.unpack_uint()],
+                          r.unpack_uint(),
+                          UUID(bytes = r.unpack_bytes()))
+                         for i in xrange(r.unpack_uint()))
 
     def pkey_match(self, type = 0, curve = 0, flags = 0, attributes = (),
-                   previous_uuid = uuid.UUID(int = 0), length = 512, client = 0, session = 0):
-        u = self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags,
-                     attributes, length, previous_uuid, client = client)
-        r = tuple(uuid.UUID(bytes = u.unpack_bytes())
-                  for i in xrange(u.unpack_uint()))
-        x = uuid.UUID(bytes = u.unpack_bytes())
-        u.done()
-        assert len(r) == 0 or x == r[-1]
-        return r
-
-    def pkey_set_attribute(self, pkey, type, value):
-        u = self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, type, value)
-        u.done()
-
-    def pkey_get_attribute(self, pkey, type):
-        u = self.rpc(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, type)
-        r = u.unpack_bytes()
-        u.done()
-        return r
-
-    def pkey_delete_attribute(self, pkey, type):
-        u = self.rpc(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, type)
-        u.done()
-
+                   previous_uuid = UUID(int = 0), length = 512, client = 0, session = 0):
+        with self.rpc(RPC_FUNC_PKEY_MATCH, session, type, curve, flags,
+                      attributes, length, previous_uuid, client = client) as r:
+            x = tuple(UUID(bytes = r.unpack_bytes())
+                      for i in xrange(r.unpack_uint()))
+            y = UUID(bytes = r.unpack_bytes())
+            assert len(x) == 0 or y == x[-1]
+            return x
+
+    def pkey_set_attribute(self, pkey, attr_type, attr_value = None):
+        if attr_value is None and isinstance(attr_type, Attribute):
+            attr_type, attr_value = attr_type.type, attr_type.attr_value
+        with self.rpc(RPC_FUNC_PKEY_SET_ATTRIBUTE, pkey, attr_type, attr_value):
+            return
+
+    def pkey_get_attribute(self, pkey, attr_type):
+        with self.rpc(RPC_FUNC_PKEY_GET_ATTRIBUTE, pkey, attr_type) as r:
+            return Attribute(attr_type, r.unpack_bytes())
+
+    def pkey_delete_attribute(self, pkey, attr_type):
+        with self.rpc(RPC_FUNC_PKEY_DELETE_ATTRIBUTE, pkey, attr_type):
+            return
 
 if __name__ == "__main__":
 
@@ -583,8 +602,6 @@ if __name__ == "__main__":
     k = hsm.pkey_generate_ec(HAL_CURVE_P256)
     print "{0.uuid} {0.key_type} {0.key_flags} {1}".format(k, hexstr(k.public_key))
     hsm.pkey_close(k)
-    k = hsm.pkey_find(k.uuid)
-    hsm.pkey_delete(k)
 
     for flags in (0, HAL_KEY_FLAG_TOKEN):
         for t, c, f, u in hsm.pkey_list(flags = flags):
@@ -593,3 +610,6 @@ if __name__ == "__main__":
     for f in (HAL_KEY_FLAG_TOKEN, 0):
         for u in hsm.pkey_match(flags = f):
             print u
+
+    k = hsm.pkey_find(k.uuid)
+    hsm.pkey_delete(k)

-- 
To stop receiving notification emails like this one, please contact
the administrator of this repository.


More information about the Commits mailing list