Skip to content

Commit

Permalink
Refactor KeyID handling. Allow most commands to take either ID or label.
Browse files Browse the repository at this point in the history
  • Loading branch information
elonen committed Jul 24, 2024
1 parent 58fe4fc commit 204b74b
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 85 deletions.
62 changes: 52 additions & 10 deletions hsm_secrets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from datetime import datetime
import os
import re
from pydantic import BaseModel, ConfigDict, HttpUrl, Field, StringConstraints
from typing_extensions import Annotated
from typing import List, Literal, NewType, Optional, Sequence, Union
Expand Down Expand Up @@ -49,6 +50,9 @@ def get_domain_bitfield(self, names: set['HSMDomainName']) -> int:
assert 0 <= res <= 0xFFFF, f"Domain bitfield out of range: {res}"
return res

def find_def(self, id_or_label: Union[int, str], enforce_type: Optional[type] = None) -> 'HSMDefBase':
return _find_def_by_id_or_label(self, id_or_label, enforce_type)

@staticmethod
def domain_bitfield_to_nums(bitfield: int) -> set['HSMDomainNum']:
return {i+1 for i in range(16) if bitfield & (1 << i)}
Expand Down Expand Up @@ -123,7 +127,7 @@ class General(NoExtraBaseModel):
x509_defaults: 'X509Info'


class HSMKeyBase(NoExtraBaseModel):
class HSMDefBase(NoExtraBaseModel):
model_config = ConfigDict(extra="forbid")
label: KeyLabel
id: KeyID
Expand All @@ -136,14 +140,14 @@ class HSMKeyBase(NoExtraBaseModel):
"none", "sign-pkcs", "sign-pss", "sign-ecdsa", "sign-eddsa", "decrypt-pkcs", "decrypt-oaep", "derive-ecdh",
"exportable-under-wrap", "sign-ssh-certificate", "sign-attestation-certificate"
]
class HSMAsymmetricKey(HSMKeyBase):
class HSMAsymmetricKey(HSMDefBase):
capabilities: set[AsymmetricCapabilityName]
algorithm: AsymmetricAlgorithm

# -- Symmetric key models --
SymmetricAlgorithm = Literal["aes128", "aes192", "aes256"]
SymmetricCapabilityName = Literal["none", "encrypt-ecb", "decrypt-ecb", "encrypt-cbc", "decrypt-cbc", "exportable-under-wrap"]
class HSMSymmetricKey(HSMKeyBase):
class HSMSymmetricKey(HSMDefBase):
capabilities: set[SymmetricCapabilityName]
algorithm: SymmetricAlgorithm

Expand All @@ -159,15 +163,15 @@ class HSMSymmetricKey(HSMKeyBase):
"put-opaque", "put-otp-aead-key", "put-template", "put-wrap-key", "randomize-otp-aead", "reset-device",
"rewrap-from-otp-aead-key", "rewrap-to-otp-aead-key", "set-option", "sign-attestation-certificate", "sign-ecdsa",
"sign-eddsa", "sign-hmac", "sign-pkcs", "sign-pss", "sign-ssh-certificate", "unwrap-data", "verify-hmac", "wrap-data"]
class HSMWrapKey(HSMKeyBase):
class HSMWrapKey(HSMDefBase):
capabilities: set[WrapCapabilityName]
delegated_capabilities: set[WrapDelegateCapabilityName]
algorithm: WrapAlgorithm

# -- HMAC key models --
HmacAlgorithm = Literal["hmac-sha1", "hmac-sha256", "hmac-sha384", "hmac-sha512"]
HmacCapabilityName = Literal["none", "sign-hmac", "verify-hmac", "exportable-under-wrap"]
class HSMHmacKey(HSMKeyBase):
class HSMHmacKey(HSMDefBase):
capabilities: set[HmacCapabilityName]
algorithm: HmacAlgorithm

Expand All @@ -194,13 +198,13 @@ class HSMHmacKey(HSMKeyBase):
"sign-eddsa", "sign-hmac", "sign-pkcs", "sign-pss", "sign-ssh-certificate", "unwrap-data", "verify-hmac", "wrap-data",
"decrypt-ecb", "encrypt-ecb", "decrypt-cbc", "encrypt-cbc",
]
class HSMAuthKey(HSMKeyBase):
class HSMAuthKey(HSMDefBase):
capabilities: set[AuthKeyCapabilityName]
delegated_capabilities: set[AuthKeyDelegatedCapabilityName]

# -- Opaque object models --
OpaqueObjectAlgorithm = Literal["opaque-data", "opaque-x509-certificate"]
class OpaqueObject(HSMKeyBase):
class HSMOpaqueObject(HSMDefBase):
algorithm: OpaqueObjectAlgorithm
sign_by: Optional[KeyID] # ID of the key to sign the object with (if applicable)

Expand Down Expand Up @@ -244,7 +248,7 @@ class X509Info(NoExtraBaseModel):
class X509Cert(NoExtraBaseModel):
key: HSMAsymmetricKey
x509_info: Optional[X509Info] = Field(default=None) # If None, use the default values from the global configuration (applies to sub-fields, too)
signed_certs: List[OpaqueObject] = Field(default_factory=list) # Storage for signed certificates
signed_certs: List[HSMOpaqueObject] = Field(default_factory=list) # Storage for signed certificates


# ----- Subsystem models -----
Expand Down Expand Up @@ -348,6 +352,44 @@ def find_instances(obj: Any, target_type: Type[T]) -> Generator[T, None, None]:
return list(find_instances(conf, cls))


def parse_keyid(key_id: str) -> int:
"""
Parse a key ID from a string in the format '0x1234'.
:raises ValueError: If the key ID is not a hexadecimal number with the '0x' prefix.
"""
if not key_id.startswith('0x'):
raise ValueError(f"Key ID '{key_id}' must be a hexadecimal number with the '0x' prefix.")
return int(key_id.replace('0x',''), 16)


def _find_def_by_id_or_label(conf: HSMConfig, id_or_label: Union[int, str], enforce_type: Optional[type] = None) -> HSMDefBase:
"""
Find the configuration object for a given key ID or label.
:raises KeyError: If the key is not found in the configuration file.
"""
# Check and parse the id/label
id = None
if isinstance(id_or_label, str):
if re.match(r'^0x[0-9a-fA-F]+$', id_or_label.strip()):
id = parse_keyid(id_or_label)
elif id_or_label.isdigit():
raise ValueError(f"Key ID ('{id_or_label}') must be a hexadecimal number with the '0x' prefix.")
elif isinstance(id_or_label, int):
id = id_or_label
if id <= 0 or id >= 0xFFFF:
raise ValueError(f"Key ID '{id}' is out of range (16 bit unsigned integer).")

# Search by ID or label
for t in [HSMAsymmetricKey, HSMSymmetricKey, HSMWrapKey, HSMHmacKey, HSMAuthKey, HSMOpaqueObject]:
for key in find_config_items_of_class(conf, t):
if (id and key.id == id) or key.label == id_or_label:
if enforce_type and not isinstance(key, enforce_type):
raise ValueError(f"Key '{id_or_label}' is not of the expected type '{enforce_type.__name__}'.")
return key

raise KeyError(f"Key with ID or label '{id_or_label}' not found in the configuration file.")



def find_all_config_items_per_type(conf: HSMConfig) -> tuple[dict, dict]:
"""
Expand All @@ -356,14 +398,14 @@ def find_all_config_items_per_type(conf: HSMConfig) -> tuple[dict, dict]:
"""
import yubihsm.objects # type: ignore [import]

from hsm_secrets.config import HSMAsymmetricKey, HSMSymmetricKey, HSMWrapKey, OpaqueObject, HSMHmacKey, HSMAuthKey
from hsm_secrets.config import HSMAsymmetricKey, HSMSymmetricKey, HSMWrapKey, HSMOpaqueObject, HSMHmacKey, HSMAuthKey
config_to_hsm_type = {
HSMAuthKey: yubihsm.objects.AuthenticationKey,
HSMWrapKey: yubihsm.objects.WrapKey,
HSMHmacKey: yubihsm.objects.HmacKey,
HSMSymmetricKey: yubihsm.objects.SymmetricKey,
HSMAsymmetricKey: yubihsm.objects.AsymmetricKey,
OpaqueObject: yubihsm.objects.Opaque,
HSMOpaqueObject: yubihsm.objects.Opaque,
}
config_items_per_type: dict = {t: find_config_items_of_class(conf, t) for t in config_to_hsm_type.keys()} # type: ignore
return config_items_per_type, config_to_hsm_type
38 changes: 21 additions & 17 deletions hsm_secrets/hsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tarfile
import click

from hsm_secrets.config import HSMConfig, find_all_config_items_per_type
from hsm_secrets.config import HSMAsymmetricKey, HSMConfig, find_all_config_items_per_type, parse_keyid
from hsm_secrets.hsm.secret_sharing_ceremony import cli_reconstruction_ceremony, cli_splitting_ceremony
from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_error, cli_info, cli_result, cli_ui_msg, cli_warn, hsm_generate_asymmetric_key, hsm_generate_hmac_key, hsm_generate_symmetric_key, hsm_obj_exists, hsm_put_derived_auth_key, hsm_put_wrap_key, open_hsm_session, open_hsm_session_with_password, pass_common_args, pretty_fmt_yubihsm_object, prompt_for_secret, pw_check_fromhex

Expand Down Expand Up @@ -249,15 +249,15 @@ def make_wrap_key(ctx: HsmSecretsCtx):
# ---------------

@cmd_hsm.command('delete-object')
@click.argument('cert_ids', nargs=-1, type=str, metavar='<id>...')
@click.argument('obj_ids', nargs=-1, type=str, metavar='<id|label> ...')
@click.option('--alldevs', is_flag=True, help="Delete on all devices")
@click.option('--force', is_flag=True, help="Force deletion without confirmation (use with caution)")
@pass_common_args
def delete_object(ctx: HsmSecretsCtx, cert_ids: tuple, alldevs: bool, force: bool):
"""Delete an object from the YubiHSM
def delete_object(ctx: HsmSecretsCtx, obj_ids: tuple, alldevs: bool, force: bool):
"""Delete object(s) from the YubiHSM
Deletes an object with the given ID from the YubiHSM.
YubiHSM2 identifies objects by type in addition to ID, so the command
Deletes an object(s) with the given ID or label from the YubiHSM.
YubiHSM2 can have the same id for different types of objects, so this command
asks you to confirm the type of the object before deleting it.
With `--force` ALL objects with the given ID will be deleted
Expand All @@ -266,13 +266,17 @@ def delete_object(ctx: HsmSecretsCtx, cert_ids: tuple, alldevs: bool, force: boo
hsm_serials = ctx.conf.general.all_devices.keys() if alldevs else [ctx.hsm_serial]
for serial in hsm_serials:
with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN, serial) as ses:
not_found = set(cert_ids)
for id in cert_ids:
id_int = int(id.replace('0x', ''), 16)
not_found = set(obj_ids)
for id_or_label in obj_ids:
try:
id_int = ctx.conf.find_def(id_or_label).id
except KeyError:
cli_warn(f"Object '{id_or_label}' not found in the configuration file. Assuming it's raw ID on the device.")
id_int = parse_keyid(id_or_label)
objects = ses.list_objects()
for o in objects:
if o.id == id_int:
not_found.remove(id)
not_found.remove(id_or_label)
if not force:
cli_ui_msg("Object found:")
cli_ui_msg(pretty_fmt_yubihsm_object(o))
Expand Down Expand Up @@ -331,8 +335,8 @@ def compare_config(ctx: HsmSecretsCtx, alldevs: bool, create: bool):
if create:
need_create = obj is None
if need_create:
from hsm_secrets.config import HSMAsymmetricKey, HSMSymmetricKey, HSMWrapKey, OpaqueObject, HSMHmacKey, HSMAuthKey
unsupported_types = (HSMWrapKey, HSMAuthKey, OpaqueObject)
from hsm_secrets.config import HSMAsymmetricKey, HSMSymmetricKey, HSMWrapKey, HSMOpaqueObject, HSMHmacKey, HSMAuthKey
unsupported_types = (HSMWrapKey, HSMAuthKey, HSMOpaqueObject)

gear_emoji = click.style("⚙️", fg='cyan')

Expand Down Expand Up @@ -374,22 +378,22 @@ def compare_config(ctx: HsmSecretsCtx, alldevs: bool, create: bool):

@cmd_hsm.command('attest-key')
@pass_common_args
@click.argument('cert_id', required=True, type=str, metavar='<id>')
@click.argument('obj_id', required=True, type=str, metavar='<id|label>')
@click.option('--out', '-o', type=click.File('w', encoding='utf8'), help='Output file (default: stdout)', default=click.get_text_stream('stdout'))
def attest_key(ctx: HsmSecretsCtx, cert_id: str, out: click.File):
def attest_key(ctx: HsmSecretsCtx, obj_id: str, out: click.File):
"""Attest an asymmetric key in the YubiHSM
Create an a key attestation certificate, signed by the
Yubico attestation key, for the given key ID (in hex).
"""
from cryptography.hazmat.primitives.serialization import Encoding
id = ctx.conf.find_def(obj_id, HSMAsymmetricKey).id

id = int(cert_id.replace('0x', ''), 16)
with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN, ctx.hsm_serial) as ses:
key = ses.get_object(id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY)
assert isinstance(key, yubihsm.objects.AsymmetricKey)
if not hsm_obj_exists(key):
raise click.ClickException(f"Key with ID 0x{id:04x} not found in the YubiHSM.")
raise click.ClickException(f"Asymmetric key 0x{id:04x} not found in the YubiHSM.")
cert = key.attest()
pem = cert.public_bytes(Encoding.PEM).decode('UTF-8')
out.write(pem) # type: ignore
Expand Down Expand Up @@ -493,7 +497,7 @@ def restore_hsm(ctx: HsmSecretsCtx, backup_file: str, force: bool):
name = tarinfo.name
assert name.endswith('.bin'), f"Unexpected file extension in tar archive: '{name}'"
assert name.count('--') == 2, f"Unexpected file name format in tar archive: '{name}'"
obj_id = int(name.split('--')[1].replace('0x', ''), 16)
obj_id = parse_keyid(name.split('--')[1])
obj_type = name.split('--')[0]

cli_info(f"- Importing object from '{tarinfo.name}'...")
Expand Down
8 changes: 4 additions & 4 deletions hsm_secrets/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from hsm_secrets.ssh import cmd_ssh
from hsm_secrets.tls import cmd_tls
from hsm_secrets.passwd import cmd_pass
from hsm_secrets.config import HSMConfig, load_hsm_config
from hsm_secrets.config import HSMAuthKey, load_hsm_config
from hsm_secrets.user import cmd_user
from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_warn, list_yubikey_hsm_creds, pass_common_args, cli_info
from hsm_secrets.x509 import cmd_x509
Expand Down Expand Up @@ -40,8 +40,8 @@ def cli(ctx: click.Context, config: str|None, quiet: bool, yklabel: str|None, hs
--auth-default-admin: Use insecure default auth key (see config).
--auth-password-id <ID>: Use password from environment variable
HSM_PASSWORD with the specified auth key ID (hex).
--auth-password-id <id|label>: Use password from environment variable
HSM_PASSWORD with the specified auth key ID (hex) or label.
"""
ctx.obj = {'quiet': quiet} # early setup for cli_info and other utils to work

Expand Down Expand Up @@ -82,7 +82,7 @@ def cli(ctx: click.Context, config: str|None, quiet: bool, yklabel: str|None, hs
'quiet': quiet,
'hsmserial': hsmserial or conf.general.master_device,
'forced_auth_method': None,
'auth_password_id': int(auth_password_id.replace('0x', ''), 16) if auth_password_id else None,
'auth_password_id': conf.find_def(auth_password_id, HSMAuthKey).id if auth_password_id else None,
'auth_password': os.getenv("HSM_PASSWORD", None),
}

Expand Down
6 changes: 3 additions & 3 deletions hsm_secrets/ssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Sequence
import click

from hsm_secrets.config import HSMConfig
from hsm_secrets.config import HSMAsymmetricKey, HSMConfig
from hsm_secrets.utils import HsmSecretsCtx, cli_code_info, cli_result, cli_warn, open_hsm_session, pass_common_args
from cryptography.hazmat.primitives import _serialization

Expand Down Expand Up @@ -48,7 +48,7 @@ def get_ca(ctx: HsmSecretsCtx, get_all: bool, cert_ids: Sequence[str]):

@cmd_ssh.command('sign-key')
@click.option('--out', '-o', type=click.Path(exists=False, dir_okay=False, resolve_path=True, allow_dash=True), help="Output file (default: deduce from input)", default=None)
@click.option('--ca', '-c', required=False, help="CA key ID (hex) to sign with. Default: read from config", default=None)
@click.option('--ca', '-c', required=False, help="CA key ID (hex) or label to sign with. Default: read from config", default=None)
@click.option('--username', '-u', required=False, help="Key owner's name (for auditing)", default=None)
@click.option('--certid', '-n', required=False, help="Explicit certificate ID (default: auto-generated)", default=None)
@click.option('--validity', '-t', required=False, default=365*24*60*60, help="Validity period in seconds (default: 1 year)")
Expand All @@ -74,7 +74,7 @@ def sign_key(ctx: HsmSecretsCtx, out: str, ca: str|None, username: str|None, cer
from hsm_secrets.ssh.openssh.ssh_certificate import cert_for_ssh_pub_id, str_to_extension
from hsm_secrets.key_adapters import make_private_key_adapter

ca_key_id = int(ca.replace('0x',''), 16) if ca else ctx.conf.ssh.default_ca
ca_key_id = ctx.conf.find_def(ca, HSMAsymmetricKey).id if ca else ctx.conf.ssh.default_ca

ca_def = [c for c in ctx.conf.ssh.root_ca_keys if c.id == ca_key_id]
if not ca_def:
Expand Down
10 changes: 5 additions & 5 deletions hsm_secrets/tls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import yubihsm.defs # type: ignore [import]
import yubihsm.objects # type: ignore [import]

from hsm_secrets.config import X509CertAttribs, X509Info
from hsm_secrets.config import HSMOpaqueObject, X509CertAttribs, X509Info
from hsm_secrets.key_adapters import PrivateKey, make_private_key_adapter
from hsm_secrets.utils import HsmSecretsCtx, cli_code_info, cli_info, cli_ui_msg, cli_warn, hsm_obj_exists, open_hsm_session, open_hsm_session_with_yubikey, pass_common_args
from hsm_secrets.x509.cert_builder import X509CertBuilder
Expand All @@ -31,7 +31,7 @@ def cmd_tls(ctx: click.Context):
@click.option('--san-ip', '-i', multiple=True, help="IP SAN (Subject Alternative Name)")
@click.option('--validity', '-v', default=365, help="Validity period in days")
@click.option('--keyfmt', '-f', type=click.Choice(['rsa4096', 'ed25519', 'ecp256', 'ecp384']), default='ecp384', help="Key format")
@click.option('--sign-crt', '-s', type=str, required=False, help="CA ID (hex) to sign with, or 'self'. Default: use config", default=None)
@click.option('--sign-crt', '-s', type=str, required=False, help="CA ID (hex) or label to sign with, or 'self'. Default: use config", default=None)
def new_http_server_cert(ctx: HsmSecretsCtx, out: click.Path, common_name: str, san_dns: list[str], san_ip: list[str], validity: int, keyfmt: str, sign_crt: str):
"""Create a TLS server certificate + key
Expand All @@ -48,7 +48,7 @@ def new_http_server_cert(ctx: HsmSecretsCtx, out: click.Path, common_name: str,
issuer_x509_def = None
issuer_cert_id = -1
if (sign_crt or '').strip().lower() != 'self':
issuer_cert_id = int(sign_crt.replace('0x',''), 16) if sign_crt else ctx.conf.tls.default_ca_id
issuer_cert_id = ctx.conf.find_def(sign_crt, HSMOpaqueObject).id if sign_crt else ctx.conf.tls.default_ca_id
issuer_x509_def = find_cert_def(ctx.conf, issuer_cert_id)
assert issuer_x509_def, f"CA cert ID not found: 0x{issuer_cert_id:04x}"

Expand Down Expand Up @@ -141,7 +141,7 @@ def new_http_server_cert(ctx: HsmSecretsCtx, out: click.Path, common_name: str,
@pass_common_args
@click.argument('csr', type=click.Path(exists=False, dir_okay=False, resolve_path=True, allow_dash=True), default='-', required=True, metavar='<csr-file>')
@click.option('--out', '-o', required=False, type=click.Path(exists=False, dir_okay=False, resolve_path=True), help="Output filename (default: deduce from input)", default=None)
@click.option('--ca', '-c', type=str, required=False, help="CA ID (hex) to sign with. Default: use config", default=None)
@click.option('--ca', '-c', type=str, required=False, help="CA ID (hex) or label to sign with. Default: use config", default=None)
@click.option('--validity', '-v', default=365, help="Validity period in days")
def sign_csr(ctx: HsmSecretsCtx, csr: click.Path, out: click.Path|None, ca: str|None, validity: int):
"""Sign a CSR with a CA key
Expand All @@ -160,7 +160,7 @@ def sign_csr(ctx: HsmSecretsCtx, csr: click.Path, out: click.Path|None, ca: str|
csr_obj = cryptography.x509.load_pem_x509_csr(csr_data)

# Find the issuer CA definition
issuer_cert_id = int(ca.replace('0x',''), 16) if ca else ctx.conf.tls.default_ca_id
issuer_cert_id = ctx.conf.find_def(ca, HSMOpaqueObject).id if ca else ctx.conf.tls.default_ca_id
issuer_x509_def = find_cert_def(ctx.conf, issuer_cert_id)
assert issuer_x509_def, f"CA cert ID not found: 0x{issuer_cert_id:04x}"

Expand Down
Loading

0 comments on commit 204b74b

Please sign in to comment.