diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 0000000..638f254 --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,28 @@ +name: Python Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.12' + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y make openssh-client openssl libpcsclite-dev + + - name: Run tests + run: | + make test diff --git a/.gitignore b/.gitignore index d4c9d49..cc613bd 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ build/ .DS_Store +*.pickle + _venv .vscode __pycache__ diff --git a/Makefile b/Makefile index 41ef1d8..459da80 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: setup install clean distclean package +.PHONY: setup install clean distclean package test # Configurable paths and settings VENV := _venv @@ -37,6 +37,9 @@ $(VENV): requirements.txt $(PIP) install build @touch $(VENV) +test: $(TARGET_BINS) + ./run-tests.sh + clean: @echo "Cleaning up build and Python file artifacts..." @rm -rf $(VENV) diff --git a/hsm_secrets/config.py b/hsm_secrets/config.py index 5b40e4a..428aab2 100644 --- a/hsm_secrets/config.py +++ b/hsm_secrets/config.py @@ -50,7 +50,7 @@ def get_domain_bitfield(self, names: set['HSMDomainName']) -> int: assert 0 <= res <= 0xFFFF, f"Domain bitfield out of range: {res}" return res - def find_def(self, id_or_label: Union[int, str], enforce_type: Optional[type] = None) -> 'HSMDefBase': + def find_def(self, id_or_label: Union[int, str], enforce_type: Optional[type] = None) -> 'HSMObjBase': return _find_def_by_id_or_label(self, id_or_label, enforce_type) @staticmethod @@ -102,8 +102,8 @@ def algorithm_from_name(algo: Union['AsymmetricAlgorithm', 'SymmetricAlgorithm', # Some type definitions for the models -KeyID = Annotated[int, Field(strict=True, gt=0, lt=0xFFFF)] -KeyLabel = Annotated[str, Field(max_length=40)] +HSMKeyID = Annotated[int, Field(strict=True, gt=0, lt=0xFFFF)] +HSMKeyLabel = Annotated[str, Field(max_length=40)] HSMDomainNum = Annotated[int, Field(strict=True, gt=0, lt=17)] HSMDomainName = Literal["all", "x509", "tls", "nac", "gpg", "codesign", "ssh", "password_derivation", "encryption"] @@ -127,10 +127,10 @@ class General(NoExtraBaseModel): x509_defaults: 'X509Info' -class HSMDefBase(NoExtraBaseModel): +class HSMObjBase(NoExtraBaseModel): model_config = ConfigDict(extra="forbid") - label: KeyLabel - id: KeyID + label: HSMKeyLabel + id: HSMKeyID domains: set[HSMDomainName] @@ -140,14 +140,14 @@ class HSMDefBase(NoExtraBaseModel): "none", "sign-pkcs", "sign-pss", "sign-ecdsa", "sign-eddsa", "decrypt-pkcs", "decrypt-oaep", "derive-ecdh", "exportable-under-wrap", "sign-ssh-certificate", "sign-attestation-certificate" ] -class HSMAsymmetricKey(HSMDefBase): +class HSMAsymmetricKey(HSMObjBase): capabilities: set[AsymmetricCapabilityName] algorithm: AsymmetricAlgorithm # -- Symmetric key models -- SymmetricAlgorithm = Literal["aes128", "aes192", "aes256"] SymmetricCapabilityName = Literal["none", "encrypt-ecb", "decrypt-ecb", "encrypt-cbc", "decrypt-cbc", "exportable-under-wrap"] -class HSMSymmetricKey(HSMDefBase): +class HSMSymmetricKey(HSMObjBase): capabilities: set[SymmetricCapabilityName] algorithm: SymmetricAlgorithm @@ -163,7 +163,7 @@ class HSMSymmetricKey(HSMDefBase): "put-opaque", "put-otp-aead-key", "put-template", "put-wrap-key", "randomize-otp-aead", "reset-device", "rewrap-from-otp-aead-key", "rewrap-to-otp-aead-key", "set-option", "sign-attestation-certificate", "sign-ecdsa", "sign-eddsa", "sign-hmac", "sign-pkcs", "sign-pss", "sign-ssh-certificate", "unwrap-data", "verify-hmac", "wrap-data"] -class HSMWrapKey(HSMDefBase): +class HSMWrapKey(HSMObjBase): capabilities: set[WrapCapabilityName] delegated_capabilities: set[WrapDelegateCapabilityName] algorithm: WrapAlgorithm @@ -171,7 +171,7 @@ class HSMWrapKey(HSMDefBase): # -- HMAC key models -- HmacAlgorithm = Literal["hmac-sha1", "hmac-sha256", "hmac-sha384", "hmac-sha512"] HmacCapabilityName = Literal["none", "sign-hmac", "verify-hmac", "exportable-under-wrap"] -class HSMHmacKey(HSMDefBase): +class HSMHmacKey(HSMObjBase): capabilities: set[HmacCapabilityName] algorithm: HmacAlgorithm @@ -198,15 +198,15 @@ class HSMHmacKey(HSMDefBase): "sign-eddsa", "sign-hmac", "sign-pkcs", "sign-pss", "sign-ssh-certificate", "unwrap-data", "verify-hmac", "wrap-data", "decrypt-ecb", "encrypt-ecb", "decrypt-cbc", "encrypt-cbc", ] -class HSMAuthKey(HSMDefBase): +class HSMAuthKey(HSMObjBase): capabilities: set[AuthKeyCapabilityName] delegated_capabilities: set[AuthKeyDelegatedCapabilityName] # -- Opaque object models -- OpaqueObjectAlgorithm = Literal["opaque-data", "opaque-x509-certificate"] -class HSMOpaqueObject(HSMDefBase): +class HSMOpaqueObject(HSMObjBase): algorithm: OpaqueObjectAlgorithm - sign_by: Optional[KeyID] # ID of the key to sign the object with (if applicable) + sign_by: Optional[HSMKeyID] # ID of the key to sign the object with (if applicable) # -- Helper models -- X509KeyUsage = Literal[ @@ -263,7 +263,7 @@ class X509(NoExtraBaseModel): root_certs: List[X509Cert] class TLS(NoExtraBaseModel): - default_ca_id: KeyID + default_ca_id: HSMKeyID intermediate_certs: List[X509Cert] class NAC(NoExtraBaseModel): @@ -280,7 +280,7 @@ class SSHTemplateSlots(NoExtraBaseModel): max: int class SSH(NoExtraBaseModel): - default_ca: KeyID + default_ca: HSMKeyID root_ca_keys: List[HSMAsymmetricKey] @@ -290,8 +290,8 @@ class PwRotationToken(NoExtraBaseModel): ts: Annotated[int, Field(strict=True, ge=0)] class PasswordDerivationRule(NoExtraBaseModel): - id: KeyLabel - key: KeyID + id: HSMKeyLabel + key: HSMKeyID format: Literal["bip39", "hex"] = Field(default="bip39") separator: str = Field(default=".") bits: Literal[64, 128, 256] = Field(default=64) @@ -299,7 +299,7 @@ class PasswordDerivationRule(NoExtraBaseModel): class PasswordDerivation(NoExtraBaseModel): keys: List[HSMHmacKey] - default_rule: KeyLabel + default_rule: HSMKeyLabel rules: List[PasswordDerivationRule] @@ -362,7 +362,7 @@ def parse_keyid(key_id: str) -> int: return int(key_id.replace('0x',''), 16) -def _find_def_by_id_or_label(conf: HSMConfig, id_or_label: Union[int, str], enforce_type: Optional[type] = None) -> HSMDefBase: +def _find_def_by_id_or_label(conf: HSMConfig, id_or_label: int|str, enforce_type: type|None = None) -> HSMObjBase: """ Find the configuration object for a given key ID or label. :raises KeyError: If the key is not found in the configuration file. diff --git a/hsm_secrets/hsm/__init__.py b/hsm_secrets/hsm/__init__.py index 6c86d83..859635b 100644 --- a/hsm_secrets/hsm/__init__.py +++ b/hsm_secrets/hsm/__init__.py @@ -7,14 +7,19 @@ from hsm_secrets.config import HSMAsymmetricKey, HSMConfig, find_all_config_items_per_type, parse_keyid from hsm_secrets.hsm.secret_sharing_ceremony import cli_reconstruction_ceremony, cli_splitting_ceremony -from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_error, cli_info, cli_result, cli_ui_msg, cli_warn, hsm_generate_asymmetric_key, hsm_generate_hmac_key, hsm_generate_symmetric_key, hsm_obj_exists, hsm_put_derived_auth_key, hsm_put_wrap_key, open_hsm_session, open_hsm_session_with_password, pass_common_args, pretty_fmt_yubihsm_object, prompt_for_secret, pw_check_fromhex +from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_error, cli_info, cli_result, cli_ui_msg, cli_warn, confirm_and_delete_old_yubihsm_object_if_exists, open_hsm_session, open_hsm_session_with_password, pass_common_args, pretty_fmt_yubihsm_object, prompt_for_secret, pw_check_fromhex import yubihsm.defs, yubihsm.exceptions, yubihsm.objects # type: ignore [import] from yubihsm.core import AuthSession # type: ignore [import] +from yubihsm.defs import OBJECT # type: ignore [import] from click import style -def swear_you_are_on_airgapped_computer(): +from hsm_secrets.yubihsm import HSMSession, MockYhsmObject + +def swear_you_are_on_airgapped_computer(quiet: bool): + if quiet: + return cli_ui_msg(style(r""" +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | @@ -84,7 +89,7 @@ def list_objects(ctx: HsmSecretsCtx, alldevs: bool): cli_info(f"YubiHSM Objects on device {serial}:") cli_info("") for o in ses.list_objects(): - cli_result(pretty_fmt_yubihsm_object(o)) + cli_result(pretty_fmt_yubihsm_object(o.get_info())) cli_result("") # --------------- @@ -99,10 +104,10 @@ def default_admin_enable(ctx: HsmSecretsCtx, use_backup_secret: bool, alldevs: b Using either a shared secret or a backup secret, (re-)create the default admin key on the YubiHSM(s). This is a temporary key that should be removed after the management operations are complete. """ - swear_you_are_on_airgapped_computer() + swear_you_are_on_airgapped_computer(ctx.quiet) - def do_it(conf: HSMConfig, ses: AuthSession, serial: str): - obj = hsm_put_derived_auth_key(ses, serial, conf, conf.admin.default_admin_key, conf.admin.default_admin_password) + def do_it(conf: HSMConfig, ses: HSMSession): + obj = ses.auth_key_put_derived(conf.admin.default_admin_key, conf.admin.default_admin_password) cli_ui_msg(f"OK. Default insecure admin key (0x{obj.id:04x}: '{conf.admin.default_admin_password}') added successfully.") cli_ui_msg("!!! DON'T FORGET TO REMOVE IT after you're done with the management operations.") @@ -129,14 +134,14 @@ def do_it(conf: HSMConfig, ses: AuthSession, serial: str): hsm_serials = ctx.conf.general.all_devices.keys() if alldevs else [ctx.hsm_serial] for serial in hsm_serials: try: - if not ctx.forced_auth_method: + if not ctx.forced_auth_method and not ctx.mock_file: assert password is not None shared_key_id = ctx.conf.admin.shared_admin_key.id with open_hsm_session_with_password(ctx, shared_key_id, password, device_serial=serial ) as ses: - do_it(ctx.conf, ses, serial) + do_it(ctx.conf, ses) else: with open_hsm_session(ctx, device_serial=serial) as ses: - do_it(ctx.conf, ses, serial) + do_it(ctx.conf, ses) except yubihsm.exceptions.YubiHsmAuthenticationError as e: raise click.ClickException("Failed to authenticate with the provided password.") @@ -156,26 +161,22 @@ def default_admin_disable(ctx: HsmSecretsCtx, alldevs: bool, force: bool): hsm_serials = ctx.conf.general.all_devices.keys() if alldevs else [ctx.hsm_serial] for serial in hsm_serials: with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN, serial) as ses: - default_key = ses.get_object(ctx.conf.admin.default_admin_key.id, yubihsm.defs.OBJECT.AUTHENTICATION_KEY) - assert isinstance(default_key, yubihsm.objects.AuthenticationKey) + keydef = ctx.conf.admin.default_admin_key - if hsm_obj_exists(default_key): + if ses.object_exists(keydef): # Check that shared admin key exists before removing the default one if not force: - shared_key = ses.get_object(ctx.conf.admin.shared_admin_key.id, yubihsm.defs.OBJECT.AUTHENTICATION_KEY) - assert isinstance(shared_key, yubihsm.objects.AuthenticationKey) - if not hsm_obj_exists(shared_key): + if not ses.object_exists(ctx.conf.admin.shared_admin_key): raise click.ClickException(f"Shared admin key not found on device {serial}. You could lose access to the device, so refusing the operation (use --force to override).") - # Ok, it does, we can proceed - default_key.delete() + ses.delete_object(keydef) cli_info(f"Ok. Default admin key removed on device {serial}.") else: cli_warn(f"Default admin key not found on device {serial}. Skipping.") # Make sure it's really gone try: - if hsm_obj_exists(default_key): + if ses.object_exists(keydef): cli_error(f"ERROR!!! Default admin key still exists on device {serial}. Don't leave the airgapped session before removing it.") click.pause("Press ENTER to continue.", err=True) raise click.Abort() @@ -202,16 +203,19 @@ def make_shared_admin_key(ctx: HsmSecretsCtx, num_shares: int, threshold: int, s This is a very heavy process, and should be only done once, on the master YubiHSM. The resulting key can then be cloned to other devices via key wrapping operations. """ - swear_you_are_on_airgapped_computer() + swear_you_are_on_airgapped_computer(ctx.quiet) with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN) as ses: def apply_password_fn(new_password: str): - hsm_put_derived_auth_key(ses, ctx.hsm_serial, ctx.conf, ctx.conf.admin.shared_admin_key, new_password) + confirm_and_delete_old_yubihsm_object_if_exists(ses, ctx.conf.admin.shared_admin_key.id, yubihsm.defs.OBJECT.AUTHENTICATION_KEY) + info = ses.auth_key_put_derived(ctx.conf.admin.shared_admin_key, new_password) + cli_info(f"Auth key ID '{hex(info.id)}' ({info.label}) stored in YubiHSM device {ses.get_serial()}") + if skip_ceremony: apply_password_fn(prompt_for_secret("Enter the (new) shared admin password to store", confirm=True)) else: secret = ses.get_pseudo_random(256//8) - cli_splitting_ceremony(num_shares, threshold, apply_password_fn, pre_secret=secret) + cli_splitting_ceremony(threshold, num_shares, apply_password_fn, pre_secret=secret) cli_info("OK. Shared admin key added successfully.") @@ -232,7 +236,7 @@ def make_wrap_key(ctx: HsmSecretsCtx): hsm_serials = ctx.conf.general.all_devices.keys() assert len(hsm_serials) > 0, "No devices found in the configuration file." - swear_you_are_on_airgapped_computer() + swear_you_are_on_airgapped_computer(ctx.quiet) with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN) as ses: cli_info("Generating secret on master device...") @@ -242,8 +246,11 @@ def make_wrap_key(ctx: HsmSecretsCtx): cli_info("") for serial in hsm_serials: - with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN) as ses: - hsm_put_wrap_key(ses, serial, ctx.conf, ctx.conf.admin.wrap_key, secret) + with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN, device_serial=serial) as ses: + confirm_and_delete_old_yubihsm_object_if_exists(ses, ctx.conf.admin.wrap_key.id, yubihsm.defs.OBJECT.WRAP_KEY) + res = ses.put_wrap_key(ctx.conf.admin.wrap_key, secret) + cli_info(f"Wrap key ID '{hex(res.id)}' stored in YubiHSM device {ses.get_serial()}") + del secret cli_info(f"OK. Common wrap key added to all devices (serials: {', '.join(hsm_serials)}).") @@ -281,7 +288,7 @@ def delete_object(ctx: HsmSecretsCtx, obj_ids: tuple, alldevs: bool, force: bool not_found.remove(id_or_label) if not force: cli_ui_msg("Object found:") - cli_ui_msg(pretty_fmt_yubihsm_object(o)) + cli_ui_msg(pretty_fmt_yubihsm_object(o.get_info())) click.confirm("Delete this object?", default=False, abort=True, err=True) o.delete() cli_info("Object deleted.") @@ -324,11 +331,12 @@ def compare_config(ctx: HsmSecretsCtx, alldevs: bool, create: bool): n_created, n_skipped = 0, 0 for t, items in config_items_per_type.items(): + items = sorted(items, key=lambda x: x.id) cli_result(f"{t.__name__}") for it in items: - obj: yubihsm.objects.YhsmObject|None = None + obj: yubihsm.objects.YhsmObject|MockYhsmObject|None = None for o in device_objs: - if o.id == it.id and isinstance(o, config_to_hsm_type[t]): + if o.id == it.id and (o.object_type == config_to_hsm_type[t].object_type): obj = o objects_accounted_for[o.id] = True break @@ -346,17 +354,36 @@ def compare_config(ctx: HsmSecretsCtx, alldevs: bool, create: bool): warn_emoji = click.style("⚠️", fg='yellow') cli_result(f" └-> {warn_emoji} Cannot create '{it.__class__.__name__}' objects. Use other commands.") n_skipped += 1 + elif isinstance(it, HSMAsymmetricKey): cli_result(f" └-> {gear_emoji} Creating...") - hsm_generate_asymmetric_key(ses, serial, ctx.conf, it) + confirm_and_delete_old_yubihsm_object_if_exists(ses, it.id, OBJECT.ASYMMETRIC_KEY) + cli_info(f"Generating asymmetric key, type '{it.algorithm}'...") + if 'rsa' in it.algorithm.lower(): + cli_warn(" Note! RSA key generation is very slow. Please wait. The YubiHSM2 should blinking rapidly while it works.") + cli_warn(" If the process aborts / times out, you can rerun this command to resume.") + ses.asym_key_generate(it) + cli_info(f"Symmetric key ID '{hex(it.id)}' ({it.label}) stored in YubiHSM device {ses.get_serial()}") n_created += 1 + elif isinstance(it, HSMSymmetricKey): cli_result(f" └-> {gear_emoji} Creating...") - hsm_generate_symmetric_key(ses, serial, ctx.conf, it) + confirm_and_delete_old_yubihsm_object_if_exists(ses, it.id, OBJECT.SYMMETRIC_KEY) + cli_info(f"Generating symmetric key, type '{it.algorithm}'...") + ses.sym_key_generate(it) + cli_info(f"Symmetric key ID '{hex(it.id)}' ({it.label}) generated in YubiHSM device {ses.get_serial()}") + n_created += 1 elif isinstance(it, HSMHmacKey): cli_result(f" └-> {gear_emoji} Creating...") - hsm_generate_hmac_key(ses, serial, ctx.conf, it) + print("...") + confirm_and_delete_old_yubihsm_object_if_exists(ses, it.id, OBJECT.HMAC_KEY) + cli_info(f"Generating HMAC key, type '{it.algorithm}'...") + print("a") + ses.hmac_key_generate(it) + print("b") + cli_info(f"HMAC key ID '{hex(it.id)}' ({it.label}) stored in YubiHSM device {ses.get_serial()}") + print("c") n_created += 1 else: cli_result(click.style(f" └-> Unsupported object type: {it.__class__.__name__}. This is a bug. SKIPPING.", fg='red')) @@ -367,7 +394,7 @@ def compare_config(ctx: HsmSecretsCtx, alldevs: bool, create: bool): for o in device_objs: if o.id not in objects_accounted_for: info = o.get_info() - cli_result(f" ??? '{str(info.label)}' (0x{o.id:04x}) <{o.__class__.__name__}>") + cli_result(f" ??? '{str(info.label)}' (0x{o.id:04x}) <{o.object_type.name}>") if create: cli_info("") @@ -380,23 +407,18 @@ def compare_config(ctx: HsmSecretsCtx, alldevs: bool, create: bool): @cmd_hsm.command('attest') @pass_common_args -@click.argument('obj_id', required=True, type=str, metavar='') +@click.argument('key_id', required=True, type=str, metavar='') @click.option('--out', '-o', type=click.File('w', encoding='utf8'), help='Output file (default: stdout)', default=click.get_text_stream('stdout')) -def attest_key(ctx: HsmSecretsCtx, obj_id: str, out: click.File): +def attest_key(ctx: HsmSecretsCtx, key_id: str, out: click.File): """Attest an asymmetric key in the YubiHSM Create an a key attestation certificate, signed by the Yubico attestation key, for the given key ID (in hex). """ from cryptography.hazmat.primitives.serialization import Encoding - id = ctx.conf.find_def(obj_id, HSMAsymmetricKey).id - + id = ctx.conf.find_def(key_id, HSMAsymmetricKey).id with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN, ctx.hsm_serial) as ses: - key = ses.get_object(id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY) - assert isinstance(key, yubihsm.objects.AsymmetricKey) - if not hsm_obj_exists(key): - raise click.ClickException(f"Asymmetric key 0x{id:04x} not found in the YubiHSM.") - cert = key.attest() + cert = ses.attest_asym_key(id) pem = cert.public_bytes(Encoding.PEM).decode('UTF-8') out.write(pem) # type: ignore cli_info(f"Key 0x{id:04x} attestation certificate written to '{out.name}'") @@ -430,25 +452,18 @@ def backup_hsm(ctx: HsmSecretsCtx, out: click.File|None): skipped = 0 with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN, ctx.hsm_serial) as ses: - - wrap_key = ses.get_object(ctx.conf.admin.wrap_key.id, yubihsm.defs.OBJECT.WRAP_KEY) - assert isinstance(wrap_key, yubihsm.objects.WrapKey) - if not hsm_obj_exists(wrap_key): - raise click.ClickException("Configured wrap key not found in the YubiHSM.") - device_objs = list(ses.list_objects()) for obj in device_objs: - # Try to export the object try: - key_bytes = wrap_key.export_wrapped(obj) + key_bytes = ses.export_wrapped(ctx.conf.admin.wrap_key, obj.id, obj.object_type) except yubihsm.exceptions.YubiHsmDeviceError as e: skipped += 1 if e.code == yubihsm.defs.ERROR.INSUFFICIENT_PERMISSIONS: - cli_warn(f"- Warning: Skipping 0x{obj.id:04x}: Insufficient permissions to export object.") + cli_error(f"- Warning: Skipping 0x{obj.id:04x}: Insufficient permissions to export object.") continue else: - cli_warn(f"- Error: Failed to export object 0x{obj.id:04x}: {e}") + cli_error(f"- Error: Failed to export object 0x{obj.id:04x}: {e}") continue # Write to tar @@ -464,7 +479,7 @@ def backup_hsm(ctx: HsmSecretsCtx, out: click.File|None): cli_info("") cli_info("Backup complete.") if skipped: - cli_warn(f"Skipped {skipped} objects due to errors or insufficient permissions.") + cli_error(f"Skipped {skipped} objects due to errors or insufficient permissions.") @cmd_hsm.command('restore') @@ -487,12 +502,7 @@ def restore_hsm(ctx: HsmSecretsCtx, backup_file: str, force: bool): click.confirm("This is the configured master device. Are you ABSOLUTELY sure you want to continue?", abort=True, err=True) with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN) as ses: - - wrap_key = ses.get_object(ctx.conf.admin.wrap_key.id, yubihsm.defs.OBJECT.WRAP_KEY) - assert isinstance(wrap_key, yubihsm.objects.WrapKey) - if not hsm_obj_exists(wrap_key): - raise click.ClickException("Configured wrap key not found in the YubiHSM.") - + wrap_key_def = ctx.conf.admin.wrap_key with open(backup_file, 'rb') as fh: tar = tarfile.open(fileobj=fh, mode='r:gz') for tarinfo in tar: @@ -509,15 +519,14 @@ def restore_hsm(ctx: HsmSecretsCtx, backup_file: str, force: bool): cli_info(click.style(f" └-> Skipping unknown object type '{obj_type}' in backup. File: '{name}'", fg='yellow')) continue - if obj_enum == yubihsm.defs.OBJECT.WRAP_KEY and obj_id == wrap_key.id: + if obj_enum == yubihsm.defs.OBJECT.WRAP_KEY and obj_id == wrap_key_def.id: cli_info(click.style(f" └-> Skipping wrap key 0x{obj_id:04x} that we are currently using for restoring.", fg='yellow')) continue - obj = ses.get_object(obj_id, obj_enum) - if hsm_obj_exists(obj): + if ses.object_exists_raw(obj_id, obj_enum): if force or click.confirm(f" └-> Object 0x{obj_id:04x} ({obj_type}) already exists. Overwrite?", default=False, err=True): cli_info(f" └-> Deleting existing {obj_type} 0x{obj_id:04x}'") - obj.delete() + ses.delete_object_raw(obj_id, obj_enum) else: cli_info(click.style(f" └-> Skipping existing {obj_type} 0x{obj_id:04x}'", fg='yellow')) continue @@ -525,8 +534,8 @@ def restore_hsm(ctx: HsmSecretsCtx, backup_file: str, force: bool): tarfh = tar.extractfile(tarinfo) assert tarfh is not None, f"Failed to extract file '{tarinfo.name}' from tar archive." key_bytes = tarfh.read() - obj = wrap_key.import_wrapped(key_bytes) - cli_info(f" └-> Restored: 0x{obj.id:04x}: ({obj.object_type.name}): {str(obj.get_info().label)}") + info = ses.import_wrapped(wrap_key_def, key_bytes) + cli_info(f" └-> Restored: 0x{info.id:04x}: ({info.object_type.name}): {str(info.label)}") cli_info("") cli_info("") diff --git a/hsm_secrets/hsm/secret_sharing_ceremony.py b/hsm_secrets/hsm/secret_sharing_ceremony.py index d1ea9fa..e223080 100644 --- a/hsm_secrets/hsm/secret_sharing_ceremony.py +++ b/hsm_secrets/hsm/secret_sharing_ceremony.py @@ -24,6 +24,7 @@ def cli_splitting_ceremony( :param with_backup_key: Whether to include a backup key in the ceremony. :param pre_secret: A pre-generated secret to use, or None to generate a new one. """ + assert threshold <= num_shares, "Threshold must be less than or equal to the number of shares." click.clear() backup_desc = """ @@ -87,7 +88,7 @@ def cli_splitting_ceremony( secret = create_16char_ascii_password(pre_secret).encode('ASCII') cli_ui_msg(f"Secret created ({len(secret) * 8} bits).") - apply_secret_fn(secret) + apply_secret_fn(secret.decode('ASCII')) cli_ui_msg("Secret applied/loaded into the system.") # Divide the original key into num_shares parts for backup diff --git a/hsm_secrets/key_adapters.py b/hsm_secrets/key_adapters.py index 154a042..65a47c5 100644 --- a/hsm_secrets/key_adapters.py +++ b/hsm_secrets/key_adapters.py @@ -8,8 +8,8 @@ import cryptography.hazmat.primitives.serialization as serialization from cryptography.hazmat.primitives import hashes -import yubihsm.objects -import yubihsm.defs +import yubihsm.objects # type: ignore [import] +import yubihsm.defs # type: ignore [import] """ Classes that wrap YubiHSM-stored keys in the cryptography.hazmat.primitives.asymmetric interfaces. @@ -19,7 +19,7 @@ """ PrivateKeyHSMAdapter = Union['RSAPrivateKeyHSMAdapter', 'Ed25519PrivateKeyHSMAdapter', 'ECPrivateKeyHSMAdapter'] -PrivateKey = Union[rsa.RSAPrivateKey, ed25519.Ed25519PrivateKey, ec.EllipticCurvePrivateKey, 'RSAPrivateKeyHSMAdapter', 'Ed25519PrivateKeyHSMAdapter', 'ECPrivateKeyHSMAdapter'] +PrivateKeyOrAdapter = Union[rsa.RSAPrivateKey, ed25519.Ed25519PrivateKey, ec.EllipticCurvePrivateKey, PrivateKeyHSMAdapter] def make_private_key_adapter(hsm_key: yubihsm.objects.AsymmetricKey) -> PrivateKeyHSMAdapter: diff --git a/hsm_secrets/main.py b/hsm_secrets/main.py index 52b12ac..e81673a 100644 --- a/hsm_secrets/main.py +++ b/hsm_secrets/main.py @@ -22,10 +22,11 @@ @click.option("--auth-yubikey", required=False, is_flag=True, help="Use Yubikey HSM auth key for HSM login") @click.option("--auth-default-admin", required=False, is_flag=True, help="Use default auth key for HSM login") @click.option("--auth-password-id", required=False, type=str, help="Auth key ID (hex) to login with password from env HSM_PASSWORD") +@click.option("--mock", required=False, type=click.Path(dir_okay=False, file_okay=True, exists=False), help="Use mock HSM for testing, data in give file") @click.version_option() @click.pass_context def cli(ctx: click.Context, config: str|None, quiet: bool, yklabel: str|None, hsmserial: str|None, - auth_default_admin: str|None, auth_yubikey: str|None, auth_password_id: str|None): + auth_default_admin: str|None, auth_yubikey: str|None, auth_password_id: str|None, mock: str|None): """Config file driven secret management tool for YubiHSM2 devices. Unless --config is specified, configuration file will be searched first @@ -84,6 +85,7 @@ def cli(ctx: click.Context, config: str|None, quiet: bool, yklabel: str|None, hs 'forced_auth_method': None, 'auth_password_id': conf.find_def(auth_password_id, HSMAuthKey).id if auth_password_id else None, 'auth_password': os.getenv("HSM_PASSWORD", None), + 'mock_file': mock, } # Check for forced auth method diff --git a/hsm_secrets/passwd/__init__.py b/hsm_secrets/passwd/__init__.py index 5972250..be48554 100644 --- a/hsm_secrets/passwd/__init__.py +++ b/hsm_secrets/passwd/__init__.py @@ -6,8 +6,9 @@ import pyescrypt # type: ignore [import] from mnemonic import Mnemonic -from hsm_secrets.config import HSMConfig, PasswordDerivationRule, PwRotationToken, find_config_items_of_class -from hsm_secrets.utils import HsmSecretsCtx, cli_code_info, cli_info, cli_result, group_by_4, hsm_obj_exists, open_hsm_session, open_hsm_session_with_yubikey, pass_common_args, secure_display_secret +from hsm_secrets.config import HSMConfig, HSMHmacKey, PasswordDerivationRule, PwRotationToken, find_config_items_of_class +from hsm_secrets.utils import HsmSecretsCtx, cli_code_info, cli_info, cli_result, group_by_4, open_hsm_session, pass_common_args, secure_display_secret +from hsm_secrets.yubihsm import HSMSession @click.group() @@ -43,23 +44,22 @@ def get_password(ctx: HsmSecretsCtx, name: str, prev: int, rule: str|None): if prev < 0: raise click.ClickException(f"Invalid previous password index: {prev}") - rule_def, key_id = _find_rule_and_key(ctx.conf, rule_id) + rule_def, hmac_key = _find_deriv_rule_and_key(ctx.conf, rule_id) with open_hsm_session(ctx) as ses: - obj = ses.get_object(key_id, yubihsm.defs.OBJECT.HMAC_KEY) - assert isinstance(obj, HmacKey) - if not hsm_obj_exists(obj): - raise click.ClickException(f"HMAC key 0x'{key_id:04x}' not found in HSM.") - - name_hmac = int.from_bytes(obj.sign_hmac(name.encode('utf8')), 'big') + # Find all rotations for the given name, and sort by timestamp + name_hmac = int.from_bytes(ses.sign_hmac(hmac_key, name.encode('utf8')), 'big') rotations = [r for r in rule_def.rotation_tokens if r.name_hmac in (None, name_hmac)] rotations.sort(key=lambda r: r.ts, reverse=True) if prev > len(rotations): raise click.ClickException(f"Password has not been rotated {prev} times yet.") + # Derive the secret from name and latest rotation nonce = rotations[prev-1].nonce if rotations else 0 - derived_secret = _derive_secret(obj, name, nonce)[:rule_def.bits//8] - password = _(derived_secret, rule_def) + nonce_bytes = nonce.to_bytes((nonce.bit_length() + 7) // 8, 'big') if nonce else b'' + derived_secret = ses.sign_hmac(hmac_key, name.encode('utf8') + nonce_bytes) + + password = _secret_to_password(derived_secret, rule_def) if ctx.quiet: print(password) @@ -99,18 +99,13 @@ def rotate_password(ctx: HsmSecretsCtx, name: list[str]|None, rule: str|None, al if not all and not name: raise click.ClickException("Must specify either --all or at least one name.") - _, key_id = _find_rule_and_key(ctx.conf, rule_id) + _, key_def = _find_deriv_rule_and_key(ctx.conf, rule_id) with open_hsm_session(ctx) as ses: - obj = ses.get_object(key_id, yubihsm.defs.OBJECT.HMAC_KEY) - assert isinstance(obj, HmacKey) - if not hsm_obj_exists(obj): - raise click.ClickException(f"HMAC key 0x'{key_id:04x}' not found in HSM.") - nonce = int.from_bytes(ses.get_pseudo_random(8), 'big') def rotate(name: str|None): - name_hmac = int.from_bytes(obj.sign_hmac(name.encode('utf8')), 'big') if name else None + name_hmac = int.from_bytes(ses.sign_hmac(key_def, name.encode('utf8')), 'big') if name else None rotation = PwRotationToken(name_hmac=name_hmac, nonce=nonce, ts=int(datetime.now().timestamp())) name_hmac_str = f"name_hmac: 0x{name_hmac:x}, " if name_hmac else "" rotation_str = f" - {{{name_hmac_str}nonce: 0x{rotation.nonce:x}, ts: {rotation.ts}}}" @@ -128,25 +123,20 @@ def rotate(name: str|None): # --- Helpers --- -def _find_rule_and_key(conf: HSMConfig, rule_id: str) -> tuple[PasswordDerivationRule, int]: +def _find_deriv_rule_and_key(conf: HSMConfig, rule_id: str) -> tuple[PasswordDerivationRule, HSMHmacKey]: rules: list[PasswordDerivationRule] = find_config_items_of_class(conf, PasswordDerivationRule) matches = [r for r in rules if r.id == rule_id] if not matches: raise click.ClickException(f"Derivation rule '{rule_id}' not found in config file.") rule_def = matches[0] - key_id = next((k.id for k in conf.password_derivation.keys if k.id == rule_def.key), None) - if not key_id: + key_def = next((k for k in conf.password_derivation.keys if k.id == rule_def.key), None) + if not key_def: raise click.ClickException(f"Key '{rule_def.key}' not found in config file.") - return rule_def, key_id - + return rule_def, key_def -def _derive_secret(obj: HmacKey, name: str, nonce: int) -> bytes: - name_bytes = name.encode('utf8') - nonce_bytes = nonce.to_bytes((nonce.bit_length() + 7) // 8, 'big') if nonce else b'' - return obj.sign_hmac(name_bytes + nonce_bytes) -def _(derived_secret: bytes, rule_def: PasswordDerivationRule) -> str: +def _secret_to_password(derived_secret: bytes, rule_def: PasswordDerivationRule) -> str: if rule_def.format == "bip39": mnemo = Mnemonic("english") secret_padded = derived_secret + b'\x00' * max(128//8 - len(derived_secret), 0) diff --git a/hsm_secrets/ssh/__init__.py b/hsm_secrets/ssh/__init__.py index bbfa3a8..9431905 100644 --- a/hsm_secrets/ssh/__init__.py +++ b/hsm_secrets/ssh/__init__.py @@ -40,9 +40,7 @@ def get_ca(ctx: HsmSecretsCtx, get_all: bool, cert_ids: Sequence[str]): with open_hsm_session(ctx) as ses: for key in selected_keys: - obj = ses.get_object(key.id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY) - assert isinstance(obj, AsymmetricKey) - pubkey = obj.get_public_key().public_bytes(encoding=_serialization.Encoding.OpenSSH, format=_serialization.PublicFormat.OpenSSH).decode('ascii') + pubkey = ses.get_public_key(key).public_bytes(encoding=_serialization.Encoding.OpenSSH, format=_serialization.PublicFormat.OpenSSH).decode('ascii') cli_result(f"{pubkey} {key.label}") @@ -74,11 +72,12 @@ def sign_ssh_key(ctx: HsmSecretsCtx, out: str, ca: str|None, username: str|None, from hsm_secrets.ssh.openssh.ssh_certificate import cert_for_ssh_pub_id, str_to_extension from hsm_secrets.key_adapters import make_private_key_adapter - ca_key_id = ctx.conf.find_def(ca, HSMAsymmetricKey).id if ca else ctx.conf.ssh.default_ca + ca_key_def = ctx.conf.find_def(ca or ctx.conf.ssh.default_ca, HSMAsymmetricKey) + assert isinstance(ca_key_def, HSMAsymmetricKey) - ca_def = [c for c in ctx.conf.ssh.root_ca_keys if c.id == ca_key_id] + ca_def = [c for c in ctx.conf.ssh.root_ca_keys if c.id == ca_key_def.id] if not ca_def: - raise click.ClickException(f"CA key 0x{ca_key_id:04x} not found in config") + raise click.ClickException(f"CA key 0x{ca_key_def.id:04x} not found in config") if not username and not certid: raise click.ClickException("Either --username or --certid must be specified") @@ -129,13 +128,8 @@ def sign_ssh_key(ctx: HsmSecretsCtx, out: str, ca: str|None, username: str|None, # Sign & write out with open_hsm_session(ctx) as ses: - obj = ses.get_object(ca_key_id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY) - assert isinstance(obj, AsymmetricKey) - - ca_pubkey = obj.get_public_key().public_bytes(encoding=_serialization.Encoding.OpenSSH, format=_serialization.PublicFormat.OpenSSH) - ca_key = make_private_key_adapter(obj) - - sign_ssh_cert(cert, ca_key) + ca_priv_key = ses.get_private_key(ca_key_def) + sign_ssh_cert(cert, ca_priv_key) cert_str = cert.to_string_fmt().replace(certid, f"{certid}{key_comment}").strip() if not out_fp: @@ -143,10 +137,11 @@ def sign_ssh_key(ctx: HsmSecretsCtx, out: str, ca: str|None, username: str|None, out_fp.write(cert_str.strip() + "\n") # type: ignore out_fp.close() if str(path) != '-': + ca_pub_key = ca_priv_key.public_key().public_bytes(encoding=_serialization.Encoding.OpenSSH, format=_serialization.PublicFormat.OpenSSH) cli_code_info(dedent(f""" Certificate written to: {path} - Send it to the user and ask them to put it in `~/.ssh/` along with the private key - To view it, run: `ssh-keygen -L -f {path}` - To allow access (adapt principals as neede), add this to your server authorized_keys file(s): - `cert-authority,principals="{','.join(cert.valid_principals)}" {ca_pubkey.decode()} HSM_{ca_def[0].label}` + `cert-authority,principals="{','.join(cert.valid_principals)}" {ca_pub_key.decode()} HSM_{ca_def[0].label}` """).strip()) diff --git a/hsm_secrets/tls/__init__.py b/hsm_secrets/tls/__init__.py index 30ea6b5..ea32dae 100644 --- a/hsm_secrets/tls/__init__.py +++ b/hsm_secrets/tls/__init__.py @@ -12,8 +12,8 @@ import yubihsm.objects # type: ignore [import] from hsm_secrets.config import HSMOpaqueObject, X509CertAttribs, X509Info -from hsm_secrets.key_adapters import PrivateKey, make_private_key_adapter -from hsm_secrets.utils import HsmSecretsCtx, cli_code_info, cli_info, cli_ui_msg, cli_warn, hsm_obj_exists, open_hsm_session, open_hsm_session_with_yubikey, pass_common_args +from hsm_secrets.key_adapters import PrivateKeyOrAdapter, make_private_key_adapter +from hsm_secrets.utils import HsmSecretsCtx, cli_code_info, cli_info, cli_warn, open_hsm_session, pass_common_args from hsm_secrets.x509.cert_builder import X509CertBuilder from hsm_secrets.x509.def_utils import find_cert_def, merge_x509_info_with_defaults @@ -46,11 +46,11 @@ def server_cert(ctx: HsmSecretsCtx, out: click.Path, common_name: str, san_dns: """ # Find the issuer CA definition issuer_x509_def = None - issuer_cert_id = -1 + issuer_cert_def = None if (sign_crt or '').strip().lower() != 'self': - issuer_cert_id = ctx.conf.find_def(sign_crt, HSMOpaqueObject).id if sign_crt else ctx.conf.tls.default_ca_id - issuer_x509_def = find_cert_def(ctx.conf, issuer_cert_id) - assert issuer_x509_def, f"CA cert ID not found: 0x{issuer_cert_id:04x}" + issuer_cert_def = ctx.conf.find_def(sign_crt or ctx.conf.tls.default_ca_id, HSMOpaqueObject) + issuer_x509_def = find_cert_def(ctx.conf, issuer_cert_def.id) + assert issuer_x509_def, f"CA cert ID not found: 0x{issuer_cert_def.id:04x}" info = X509Info() info.attribs = X509CertAttribs(common_name = common_name) @@ -69,7 +69,7 @@ def server_cert(ctx: HsmSecretsCtx, out: click.Path, common_name: str, san_dns: merged_info.path_len = None merged_info.ca = False - priv_key: PrivateKey + priv_key: PrivateKeyOrAdapter if keyfmt == 'rsa4096': priv_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) elif keyfmt == 'ed25519': @@ -94,30 +94,21 @@ def server_cert(ctx: HsmSecretsCtx, out: click.Path, common_name: str, san_dns: builder = X509CertBuilder(ctx.conf, merged_info, priv_key) issuer_cert = None if issuer_x509_def: - assert issuer_cert_id >= 0 + assert issuer_cert_def with open_hsm_session(ctx) as ses: - - ca_cert_obj = ses.get_object(issuer_cert_id, yubihsm.defs.OBJECT.OPAQUE) - assert isinstance(ca_cert_obj, yubihsm.objects.Opaque) - assert hsm_obj_exists(ca_cert_obj), f"CA cert ID not found on HSM: 0x{issuer_cert_id:04x}" - - issuer_key_obj = ses.get_object(issuer_x509_def.key.id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY) - assert isinstance(issuer_key_obj, yubihsm.objects.AsymmetricKey) - assert hsm_obj_exists(issuer_key_obj), f"CA key ID not found on HSM: 0x{issuer_x509_def.key.id:04x}" - - issuer_cert = ca_cert_obj.get_certificate() - issuer_key = make_private_key_adapter(issuer_key_obj) - - signed_cer = builder.generate_cross_signed_intermediate_cert([issuer_cert], [issuer_key])[0] - cli_info(f"Signed with CA cert 0x{issuer_cert_id:04x}: {issuer_cert.subject}") + assert isinstance(issuer_cert_def, HSMOpaqueObject) + issuer_cert = ses.get_certificate(issuer_cert_def) + issuer_key = ses.get_private_key(issuer_x509_def.key) + signed_cert = builder.generate_cross_signed_intermediate_cert([issuer_cert], [issuer_key])[0] + cli_info(f"Signed with CA cert 0x{issuer_cert_def.id:04x}: {issuer_cert.subject}") else: - signed_cer = builder.generate_self_signed_cert() + signed_cert = builder.generate_self_signed_cert() cli_warn("WARNING: Self-signed certificate, please sign the CSR manually") cli_info("") key_pem = priv_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption()) csr_pem = builder.generate_csr().public_bytes(encoding=serialization.Encoding.PEM) - crt_pem = signed_cer.public_bytes(encoding=serialization.Encoding.PEM) + crt_pem = signed_cert.public_bytes(encoding=serialization.Encoding.PEM) chain_pem = (crt_pem.strip() + b'\n' + issuer_cert.public_bytes(encoding=serialization.Encoding.PEM)) if issuer_cert else None key_file.write_bytes(key_pem) @@ -160,9 +151,9 @@ def sign_csr(ctx: HsmSecretsCtx, csr: click.Path, out: click.Path|None, ca: str| csr_obj = cryptography.x509.load_pem_x509_csr(csr_data) # Find the issuer CA definition - issuer_cert_id = ctx.conf.find_def(ca, HSMOpaqueObject).id if ca else ctx.conf.tls.default_ca_id - issuer_x509_def = find_cert_def(ctx.conf, issuer_cert_id) - assert issuer_x509_def, f"CA cert ID not found: 0x{issuer_cert_id:04x}" + issuer_cert_def = ctx.conf.find_def(ca or ctx.conf.tls.default_ca_id, HSMOpaqueObject) + issuer_x509_def = find_cert_def(ctx.conf, issuer_cert_def.id) + assert issuer_x509_def, f"CA cert ID not found: 0x{issuer_cert_def.id:04x}" if out: out_path = Path(str(out)) @@ -172,16 +163,9 @@ def sign_csr(ctx: HsmSecretsCtx, csr: click.Path, out: click.Path|None, ca: str| click.confirm(f"Output file '{out_path}' already exists. Overwrite?", abort=True, err=True) with open_hsm_session(ctx) as ses: - ca_cert_obj = ses.get_object(issuer_cert_id, yubihsm.defs.OBJECT.OPAQUE) - assert isinstance(ca_cert_obj, yubihsm.objects.Opaque) - assert hsm_obj_exists(ca_cert_obj), f"CA cert ID not found on HSM: 0x{issuer_cert_id:04x}" - - ca_key_obj = ses.get_object(issuer_x509_def.key.id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY) - assert isinstance(ca_key_obj, yubihsm.objects.AsymmetricKey) - assert hsm_obj_exists(ca_key_obj), f"CA key ID not found on HSM: 0x{issuer_x509_def.key.id:04x}" - - issuer_cert = ca_cert_obj.get_certificate() - issuer_key = make_private_key_adapter(ca_key_obj) + assert isinstance(issuer_cert_def, HSMOpaqueObject) + issuer_cert = ses.get_certificate(issuer_cert_def) + issuer_key = ses.get_private_key(issuer_x509_def.key) builder = cryptography.x509.CertificateBuilder( issuer_name = issuer_cert.subject, @@ -200,7 +184,7 @@ def sign_csr(ctx: HsmSecretsCtx, csr: click.Path, out: click.Path|None, ca: str| hash_algo = hashes.SHA256() signed_cer = builder.sign(private_key=issuer_key, algorithm=hash_algo) - cli_info(f"Signed with CA cert 0x{issuer_cert_id:04x}: {issuer_cert.subject}") + cli_info(f"Signed with CA cert 0x{issuer_cert_def.id:04x}: {issuer_cert.subject}") crt_pem = signed_cer.public_bytes(encoding=serialization.Encoding.PEM) out_path.write_bytes(crt_pem) diff --git a/hsm_secrets/user/__init__.py b/hsm_secrets/user/__init__.py index 828930b..3870ef1 100644 --- a/hsm_secrets/user/__init__.py +++ b/hsm_secrets/user/__init__.py @@ -2,7 +2,7 @@ import secrets import click from hsm_secrets.config import HSMAuthKey -from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_info, cli_ui_msg, cli_warn, confirm_and_delete_old_yubihsm_object_if_exists, group_by_4, hsm_put_derived_auth_key, hsm_put_symmetric_auth_key, open_hsm_session, pass_common_args, prompt_for_secret, pw_check_fromhex, secure_display_secret +from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_info, cli_ui_msg, cli_warn, confirm_and_delete_old_yubihsm_object_if_exists, group_by_4, open_hsm_session, pass_common_args, prompt_for_secret, pw_check_fromhex, secure_display_secret import yubikit.hsmauth import ykman.scripting @@ -109,7 +109,10 @@ def add_user_yubikey(ctx: HsmSecretsCtx, label: str, alldevs: bool): hsm_serials = ctx.conf.general.all_devices.keys() if alldevs else [ctx.hsm_serial] for serial in hsm_serials: with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN, serial) as ses: - hsm_put_symmetric_auth_key(ses, serial, ctx.conf, user_key_conf, key_enc, key_mac) + confirm_and_delete_old_yubihsm_object_if_exists(ses, user_key_conf.id, yubihsm.defs.OBJECT.AUTHENTICATION_KEY) + info = ses.auth_key_put(user_key_conf, key_enc=key_enc, key_mac=key_mac) + cli_info(f"Auth key ID '{hex(info.id)}' ({info.label}) stored in YubiHSM device {ses.get_serial()}") + cli_info("OK. User key added" + (f" to all devices (serials: {', '.join(hsm_serials)})" if alldevs else "") + ".") cli_info("") @@ -142,7 +145,7 @@ def add_service(ctx: HsmSecretsCtx, obj_ids: tuple[str], all_accts: bool, askpw: if not all_accts and not obj_ids: raise click.ClickException("No service users specified for addition.") - id_strings = [str(x.id) for x in ctx.conf.service_keys] if all_accts else obj_ids + id_strings = [f'0x{x.id:04x}' for x in ctx.conf.service_keys] if all_accts else obj_ids ids = [ctx.conf.find_def(id, HSMAuthKey).id for id in id_strings] if not ids: raise click.ClickException("No service account ids specified.") @@ -155,9 +158,7 @@ def add_service(ctx: HsmSecretsCtx, obj_ids: tuple[str], all_accts: bool, askpw: for ad in acct_defs: with open_hsm_session(ctx, HSMAuthMethod.DEFAULT_ADMIN) as ses: - obj = ses.get_object(ad.id, yubihsm.defs.OBJECT.AUTHENTICATION_KEY) - assert isinstance(obj, yubihsm.objects.AuthenticationKey) - if not confirm_and_delete_old_yubihsm_object_if_exists(ctx.hsm_serial, obj, abort=False): + if not confirm_and_delete_old_yubihsm_object_if_exists(ses, ad.id, yubihsm.defs.OBJECT.AUTHENTICATION_KEY, abort=False): cli_warn(f"Skipping service user '{ad.label}' (ID: 0x{ad.id:04x})...") continue @@ -167,16 +168,24 @@ def add_service(ctx: HsmSecretsCtx, obj_ids: tuple[str], all_accts: bool, askpw: else: rnd = ses.get_pseudo_random(16) pw = group_by_4(rnd.hex()).replace(' ', '-') + retries = 0 while True: + retries += 1 + if retries > 5: + raise click.Abort("Too many retries. Aborting.") click.pause("Press ENTER to reveal the generated password.", err=True) secure_display_secret(pw) confirm = click.prompt("Enter the password again to confirm", hide_input=True, err=True) if confirm != pw: - cli_warn("Passwords do not match. Try again.") + cli_ui_msg("Passwords do not match. Try again.") continue else: break - hsm_put_derived_auth_key(ses, ctx.hsm_serial, ctx.conf, ad, pw) + + #hsm_put_derived_auth_key(ses, ctx.hsm_serial, ctx.conf, ad, pw) + confirm_and_delete_old_yubihsm_object_if_exists(ses, ad.id, yubihsm.defs.OBJECT.AUTHENTICATION_KEY) + info = ses.auth_key_put_derived(ad, pw) + cli_info(f"Auth key ID '{hex(info.id)}' ({info.label}) stored in YubiHSM device {ses.get_serial()}") # --------------- diff --git a/hsm_secrets/utils.py b/hsm_secrets/utils.py index c357259..52c2512 100644 --- a/hsm_secrets/utils.py +++ b/hsm_secrets/utils.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum import os +from pathlib import Path from textwrap import dedent from typing import Callable, Generator, Optional, Sequence from contextlib import contextmanager @@ -10,7 +11,7 @@ from yubihsm import YubiHsm # type: ignore [import] from yubihsm.core import AuthSession # type: ignore [import] from yubihsm.defs import CAPABILITY, ALGORITHM, ERROR, OBJECT # type: ignore [import] -from yubihsm.objects import AsymmetricKey, HmacKey, SymmetricKey, WrapKey, YhsmObject, AuthenticationKey # type: ignore [import] +from yubihsm.objects import ObjectInfo, AsymmetricKey, HmacKey, SymmetricKey, WrapKey, YhsmObject, AuthenticationKey # type: ignore [import] from yubikit.hsmauth import HsmAuthSession # type: ignore [import] from yubihsm.exceptions import YubiHsmDeviceError # type: ignore [import] @@ -24,6 +25,8 @@ from functools import wraps +from hsm_secrets.yubihsm import HSMSession, MockHSMSession, RealHSMSession, open_mock_hsms, save_mock_hsms + class HSMAuthMethod(Enum): YUBIKEY = 1 DEFAULT_ADMIN = 2 @@ -32,17 +35,28 @@ class HSMAuthMethod(Enum): @dataclass class HsmSecretsCtx: + """ + Context object to pass around common arguments and configuration. + """ click_ctx: click.Context + conf: hscfg.HSMConfig hsm_serial: str yubikey_label: str quiet: bool = False + + # Authentication method overrides forced_auth_method: Optional[HSMAuthMethod] = None auth_password: Optional[str] = None auth_password_id: Optional[int] = None + mock_file: Optional[str] = None # If set, load/save mock HSM objects in this file + def pass_common_args(f): + """ + Decorator to pass common arguments to a command function, and + """ @wraps(f) def wrapper(*args, **kwargs): click_ctx = click.get_current_context() @@ -54,7 +68,8 @@ def wrapper(*args, **kwargs): quiet=click_ctx.obj.get('quiet', False), forced_auth_method = click_ctx.obj.get('forced_auth_method'), auth_password = click_ctx.obj.get('auth_password'), - auth_password_id = click_ctx.obj.get('auth_password_id')) + auth_password_id = click_ctx.obj.get('auth_password_id'), + mock_file = click_ctx.obj.get('mock_file')) try: return f(ctx, *args, **kwargs) @@ -148,7 +163,9 @@ def prompt_for_secret( :return: The user-entered secret string """ check_fn = check_fn or (lambda pw: None) - while True: + retries = 0 + while retries < 5: + retries += 1 pw = click.prompt(prompt, hide_input=True, default=default, err=True) assert isinstance(pw, str) try: @@ -160,11 +177,13 @@ def prompt_for_secret( if confirm: if click.prompt("Type again to confirm", hide_input=True, default=default, err=True) == pw: return pw - cli_warn("Mismatch. Try again.") + cli_ui_msg("Mismatch. Try again.") else: + assert isinstance(pw, str) return pw except UnicodeEncodeError: cli_error(f"Failed to encode into {enc_test.upper()}. Try again.") + raise click.Abort("Too many retries. Aborting.") def group_by_4(s: str) -> str: @@ -251,7 +270,7 @@ def verify_hsm_device_info(device_serial, hsm): def open_hsm_session( ctx: HsmSecretsCtx, default_auth_method: HSMAuthMethod = HSMAuthMethod.YUBIKEY, - device_serial: str | None = None) -> Generator[AuthSession, None, None]: + device_serial: str | None = None) -> Generator[HSMSession, None, None]: """ Open a session to the HSM using forced or given default auth method. This is an auto-selecting wrapper for the specific session context managers. @@ -259,6 +278,17 @@ def open_hsm_session( auth_method = ctx.forced_auth_method or default_auth_method device_serial = device_serial or ctx.hsm_serial + # Mock HSM session for testing + if ctx.mock_file: + cli_warn("~🤡~ !! SIMULATED (mock) HSM session !! Authentication skipped. ~🤡~") + open_mock_hsms(ctx.mock_file, int(device_serial), ctx.conf) + try: + yield MockHSMSession(int(device_serial)) + finally: + save_mock_hsms(ctx.mock_file) + return + + # Real HSM session with the selected auth method if auth_method == HSMAuthMethod.YUBIKEY: ctxman = open_hsm_session_with_yubikey(ctx, device_serial) elif auth_method == HSMAuthMethod.DEFAULT_ADMIN: @@ -269,8 +299,11 @@ def open_hsm_session( ctxman = open_hsm_session_with_password(ctx, ctx.auth_password_id, ctx.auth_password, device_serial) else: raise ValueError(f"Unknown auth method: {auth_method}") - with ctxman as session: - yield session + with ctxman as ses: + if isinstance(ses, HSMSession): + yield ses + else: + yield RealHSMSession(ctx.conf, session=ses, dev_serial=int(device_serial)) @contextmanager @@ -326,7 +359,7 @@ def open_hsm_session_with_default_admin(ctx: HsmSecretsCtx, device_serial: str|N @contextmanager -def open_hsm_session_with_password(ctx: HsmSecretsCtx, auth_key_id: int, password: str, device_serial: str|None = None) -> Generator[AuthSession, None, None]: +def open_hsm_session_with_password(ctx: HsmSecretsCtx, auth_key_id: int, password: str, device_serial: str|None = None) -> Generator[HSMSession, None, None]: """ Open a session to the HSM using a password-derived auth key. """ @@ -343,145 +376,18 @@ def open_hsm_session_with_password(ctx: HsmSecretsCtx, auth_key_id: int, passwor cli_info(f"Using password login with key ID 0x{auth_key_id:04x}") session = hsm.create_session_derived(auth_key_id, password) try: - yield session + yield RealHSMSession(ctx.conf, session=session, dev_serial=int(device_serial)) finally: session.close() hsm.close() -def encode_capabilities(names: Sequence[hscfg.AsymmetricCapabilityName] | set[hscfg.AsymmetricCapabilityName]) -> CAPABILITY: - return hscfg.HSMConfig.capability_from_names(set(names)) - -def encode_algorithm(name_literal: str|hscfg.AsymmetricAlgorithm) -> ALGORITHM: - return hscfg.HSMConfig.algorithm_from_name(name_literal) # type: ignore - - -def hsm_put_wrap_key(ses: AuthSession, hsm_serial: str, conf: hscfg.HSMConfig, key_def: hscfg.HSMWrapKey, key: bytes) -> WrapKey: - """ - Put a (symmetric) wrap key into the HSM. - """ - wrap_key = ses.get_object(key_def.id, OBJECT.WRAP_KEY) - assert isinstance(wrap_key, WrapKey) - confirm_and_delete_old_yubihsm_object_if_exists(hsm_serial, wrap_key) - res = wrap_key.put( - session = ses, - object_id = key_def.id, - label = key_def.label, - algorithm = conf.algorithm_from_name(key_def.algorithm), - domains = conf.get_domain_bitfield(key_def.domains), - capabilities = conf.capability_from_names(set(key_def.capabilities)), - delegated_capabilities = conf.capability_from_names(set(key_def.delegated_capabilities)), - key = key) - cli_info(f"Wrap key ID '{hex(res.id)}' stored in YubiHSM device {hsm_serial}") - return res - - -def hsm_put_derived_auth_key(ses: AuthSession, hsm_serial: str, conf: hscfg.HSMConfig, key_def: hscfg.HSMAuthKey, password: str) -> AuthenticationKey: - """ - Put a password-derived authentication key into the HSM. - """ - auth_key = ses.get_object(key_def.id, OBJECT.AUTHENTICATION_KEY) - assert isinstance(auth_key, AuthenticationKey) - confirm_and_delete_old_yubihsm_object_if_exists(hsm_serial, auth_key) - res = auth_key.put_derived( - session = ses, - object_id = key_def.id, - label = key_def.label, - domains = conf.get_domain_bitfield(key_def.domains), - capabilities = conf.capability_from_names(key_def.capabilities), - delegated_capabilities = conf.capability_from_names(key_def.delegated_capabilities), - password = password) - cli_info(f"Auth key ID '{hex(res.id)}' ({key_def.label}) stored in YubiHSM device {hsm_serial}") - return res - - -def hsm_put_symmetric_auth_key(ses: AuthSession, hsm_serial: str, conf: hscfg.HSMConfig, key_def: hscfg.HSMAuthKey, key_enc: bytes, key_mac: bytes) -> AuthenticationKey: - """ - Put a symmetric authentication key into the HSM. - """ - auth_key = ses.get_object(key_def.id, OBJECT.AUTHENTICATION_KEY) - assert isinstance(auth_key, AuthenticationKey) - confirm_and_delete_old_yubihsm_object_if_exists(hsm_serial, auth_key) - res = auth_key.put( - session = ses, - object_id = key_def.id, - label = key_def.label, - domains = conf.get_domain_bitfield(key_def.domains), - capabilities = conf.capability_from_names(key_def.capabilities), - delegated_capabilities = conf.capability_from_names(key_def.delegated_capabilities), - key_enc = key_enc, - key_mac = key_mac) - cli_info(f"Auth key ID '{hex(res.id)}' ({key_def.label}) stored in YubiHSM device {hsm_serial}") - return res - - -def hsm_generate_symmetric_key(ses: AuthSession, hsm_serial: str, conf: hscfg.HSMConfig, key_def: hscfg.HSMSymmetricKey) -> SymmetricKey: - """ - Generate a symmetric key on the HSM. - """ - sym_key = ses.get_object(key_def.id, OBJECT.SYMMETRIC_KEY) - assert isinstance(sym_key, SymmetricKey) - confirm_and_delete_old_yubihsm_object_if_exists(hsm_serial, sym_key) - cli_info(f"Generating symmetric key, type '{key_def.algorithm}'...") - res = sym_key.generate( - session = ses, - object_id = key_def.id, - label = key_def.label, - domains = conf.get_domain_bitfield(key_def.domains), - capabilities = conf.capability_from_names(set(key_def.capabilities)), - algorithm = conf.algorithm_from_name(key_def.algorithm)) - cli_info(f"Symmetric key ID '{hex(res.id)}' ({key_def.label}) generated in YubiHSM device {hsm_serial}") - return res - - -def hsm_generate_asymmetric_key(ses: AuthSession, hsm_serial: str, conf: hscfg.HSMConfig, key_def: hscfg.HSMAsymmetricKey) -> AsymmetricKey: - """ - Generate an asymmetric key on the HSM. - """ - asym_key = ses.get_object(key_def.id, OBJECT.ASYMMETRIC_KEY) - assert isinstance(asym_key, AsymmetricKey) - confirm_and_delete_old_yubihsm_object_if_exists(hsm_serial, asym_key) - cli_info(f"Generating asymmetric key, type '{key_def.algorithm}'...") - if 'rsa' in key_def.algorithm.lower(): - cli_warn(" Note! RSA key generation is very slow. Please wait. The YubiHSM2 should blinking rapidly while it works.") - cli_warn(" If the process aborts / times out, you can rerun this command to resume.") - res = asym_key.generate( - session = ses, - object_id = key_def.id, - label = key_def.label, - domains = conf.get_domain_bitfield(key_def.domains), - capabilities = conf.capability_from_names(set(key_def.capabilities)), - algorithm = conf.algorithm_from_name(key_def.algorithm)) - cli_info(f"Symmetric key ID '{hex(res.id)}' ({key_def.label}) stored in YubiHSM device {hsm_serial}") - return res - - -def hsm_generate_hmac_key(ses: AuthSession, hsm_serial: str, conf: hscfg.HSMConfig, key_def: hscfg.HSMHmacKey) -> HmacKey: - """ - Generate an HMAC key on the HSM. - """ - hmac_key = ses.get_object(key_def.id, OBJECT.HMAC_KEY) - assert isinstance(hmac_key, HmacKey) - confirm_and_delete_old_yubihsm_object_if_exists(hsm_serial, hmac_key) - cli_info(f"Generating HMAC key, type '{key_def.algorithm}'...") - res = hmac_key.generate( - session = ses, - object_id = key_def.id, - label = key_def.label, - domains = conf.get_domain_bitfield(key_def.domains), - capabilities = conf.capability_from_names(set(key_def.capabilities)), - algorithm = conf.algorithm_from_name(key_def.algorithm)) - cli_info(f"HMAC key ID '{hex(res.id)}' ({key_def.label}) stored in YubiHSM device {hsm_serial}") - return res - - -def pretty_fmt_yubihsm_object(o: YhsmObject): - info = o.get_info() +def pretty_fmt_yubihsm_object(info: ObjectInfo) -> str: domains: set|str = hscfg.HSMConfig.domain_bitfield_to_nums(info.domains) domains = "all" if len(domains) == 16 else domains return dedent(f""" - 0x{o.id:04x} - type: {o.object_type.name} ({o.object_type}) + 0x{info.id:04x} + type: {info.object_type.name} ({info.object_type}) label: {repr(info.label)} algorithm: {info.algorithm.name} ({info.algorithm}) size: {info.size} @@ -492,22 +398,8 @@ def pretty_fmt_yubihsm_object(o: YhsmObject): """).strip() -def hsm_obj_exists(hsm_key_obj: YhsmObject) -> bool: - """ - Check if a YubiHSM object exists. - :param hsm_key_obj: The object to check for - :return: True if the object exists - """ - try: - _ = hsm_key_obj.get_info() # Raises an exception if the key does not exist - return True - except YubiHsmDeviceError as e: - if e.code == ERROR.OBJECT_NOT_FOUND: - return False - raise e - -def confirm_and_delete_old_yubihsm_object_if_exists(serial: str, obj: YhsmObject, abort=True) -> bool: +def confirm_and_delete_old_yubihsm_object_if_exists(ses: HSMSession, obj_id: hscfg.HSMKeyID, object_type: OBJECT, abort=True) -> bool: """ Check if a YubiHSM object exists, and if so, ask the user if they want to replace it. :param serial: The serial number of the YubiHSM device @@ -515,12 +407,12 @@ def confirm_and_delete_old_yubihsm_object_if_exists(serial: str, obj: YhsmObject :param abort: Whether to abort (raise) if the user does not want to delete the object :return: True if the object doesn't exist or was deleted, False if the user chose not to delete it """ - if hsm_obj_exists(obj): - cli_ui_msg(f"Object 0x{obj.id:04x} already exists on YubiHSM device {serial}:", err=True) - cli_ui_msg(pretty_fmt_yubihsm_object(obj)) + if info := ses.object_exists_raw(obj_id, object_type): + cli_ui_msg(f"Object 0x{obj_id:04x} already exists on YubiHSM device:") + cli_ui_msg(pretty_fmt_yubihsm_object(info)) cli_info("") if click.confirm("Replace the existing key?", default=False, abort=abort, err=True): - obj.delete() + ses.delete_object_raw(obj_id, object_type) else: return False return True diff --git a/hsm_secrets/x509/__init__.py b/hsm_secrets/x509/__init__.py index 9a64f5c..dd2448d 100644 --- a/hsm_secrets/x509/__init__.py +++ b/hsm_secrets/x509/__init__.py @@ -8,9 +8,9 @@ import yubihsm.defs # type: ignore [import] from cryptography.hazmat.primitives import serialization -from hsm_secrets.config import HSMConfig, KeyID, HSMOpaqueObject, X509Cert, find_config_items_of_class +from hsm_secrets.config import HSMConfig, HSMKeyID, HSMOpaqueObject, X509Cert, find_config_items_of_class -from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_result, cli_warn, confirm_and_delete_old_yubihsm_object_if_exists, hsm_obj_exists, open_hsm_session, cli_code_info, pass_common_args, cli_info +from hsm_secrets.utils import HSMAuthMethod, HsmSecretsCtx, cli_result, cli_warn, confirm_and_delete_old_yubihsm_object_if_exists, open_hsm_session, cli_code_info, pass_common_args, cli_info from hsm_secrets.x509.cert_builder import X509CertBuilder from hsm_secrets.x509.def_utils import pretty_x509_info, merge_x509_info_with_defaults, topological_sort_x509_cert_defs @@ -18,12 +18,13 @@ import click from hsm_secrets.key_adapters import make_private_key_adapter +from hsm_secrets.yubihsm import HSMSession @click.group() @click.pass_context def cmd_x509(ctx: click.Context): - """Genral X.509 Certificate Management""" + """General X.509 Certificate Management""" ctx.ensure_object(dict) # --------------- @@ -75,9 +76,7 @@ def get_cert_cmd(ctx: HsmSecretsCtx, all_certs: bool, outdir: str|None, bundle: with open_hsm_session(ctx) as ses: for cd in selected_certs: - obj = ses.get_object(cd.id, yubihsm.defs.OBJECT.OPAQUE) - assert isinstance(obj, yubihsm.objects.Opaque) - pem = obj.get_certificate().public_bytes(encoding=serialization.Encoding.PEM).decode() + pem = ses.get_certificate(cd).public_bytes(encoding=serialization.Encoding.PEM).decode() if outdir: pem_file = Path(outdir) / f"{cd.label}.pem" pem_file.write_text(pem.strip() + "\n") @@ -100,8 +99,8 @@ def create_certs_impl(ctx: HsmSecretsCtx, all_certs: bool, dry_run: bool, cert_i Performs a topological sort of the certificates to ensure that any dependencies are created first. """ # Enumerate all certificate definitions in the config - scid_to_opq_def: dict[KeyID, HSMOpaqueObject] = {} - scid_to_x509_def: dict[KeyID, X509Cert] = {} + scid_to_opq_def: dict[HSMKeyID, HSMOpaqueObject] = {} + scid_to_x509_def: dict[HSMKeyID, X509Cert] = {} for x in find_config_items_of_class(ctx.conf, X509Cert): assert isinstance(x, X509Cert) @@ -109,12 +108,12 @@ def create_certs_impl(ctx: HsmSecretsCtx, all_certs: bool, dry_run: bool, cert_i scid_to_opq_def[opq.id] = opq scid_to_x509_def[opq.id] = x - def _do_it(ses: AuthSession|None): + def _do_it(ses: HSMSession|None): selected_defs = list(scid_to_opq_def.values()) if all_certs \ else [cast(HSMOpaqueObject, ctx.conf.find_def(id, HSMOpaqueObject)) for id in cert_ids] creation_order = topological_sort_x509_cert_defs(selected_defs) - id_to_cert_obj: dict[KeyID, x509.Certificate] = {} + id_to_cert_obj: dict[HSMKeyID, x509.Certificate] = {} # Create the certificates in topological order for cd in creation_order: @@ -124,37 +123,28 @@ def _do_it(ses: AuthSession|None): cli_info(indent(pretty_x509_info(x509_info), " ")) if not dry_run: - assert isinstance(ses, AuthSession) - + assert ses x509_def = scid_to_x509_def[cd.id] - key = ses.get_object(x509_def.key.id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY) - assert isinstance(key, yubihsm.objects.AsymmetricKey) # If the certificate is signed by another certificate, get the issuer cert and key - issuer_cert = None - issuer_key = None - + issuer_cert, issuer_key = None, None if cd.sign_by and cd.sign_by != cd.id: issuer_cert = id_to_cert_obj.get(cd.sign_by) if not issuer_cert: # Issuer cert was not created on this run, try to load it from the HSM - issuer_hsm_obj = ses.get_object(cd.sign_by, yubihsm.defs.OBJECT.OPAQUE) - assert isinstance(issuer_hsm_obj, yubihsm.objects.Opaque) - if not hsm_obj_exists(issuer_hsm_obj): + if not ses.object_exists(cd): raise click.ClickException(f"ERROR: Certificate 0x{cd.sign_by:04x} not found in HSM. Create it first, to sign 0x{cd.id:04x}.") - issuer_cert = issuer_hsm_obj.get_certificate() + issuer_cert = ses.get_certificate(cd) - # Get a HSM-backed key (adapter) for the issuer cert - key_id = scid_to_x509_def[cd.sign_by].key.id - key_obj = ses.get_object(key_id, yubihsm.defs.OBJECT.ASYMMETRIC_KEY) - assert isinstance(key_obj, yubihsm.objects.AsymmetricKey) - if not hsm_obj_exists(key_obj): - raise click.ClickException(f"ERROR: Key 0x{key_id:04x} not found in HSM. Create it first, to sign 0x{cd.id:04x}.") - issuer_key = make_private_key_adapter(key_obj) + sign_key_def = scid_to_x509_def[cd.sign_by].key + if not ses.object_exists(sign_key_def): + raise click.ClickException(f"ERROR: Key 0x{sign_key_def.id:04x} not found in HSM. Create it first, to sign 0x{cd.id:04x}.") + issuer_key = ses.get_private_key(sign_key_def) # Create and sign the certificate assert x509_def.x509_info, "X.509 certificate definition is missing x509_info" - builder = X509CertBuilder(ctx.conf, x509_def.x509_info, key) + priv_key = ses.get_private_key(x509_def.key) + builder = X509CertBuilder(ctx.conf, x509_def.x509_info, priv_key) if issuer_cert: assert issuer_key id_to_cert_obj[cd.id] = builder.generate_cross_signed_intermediate_cert([issuer_cert], [issuer_key])[0] @@ -164,17 +154,9 @@ def _do_it(ses: AuthSession|None): # Put the certificates into the HSM for cd in creation_order: if not dry_run: - assert isinstance(ses, AuthSession) - hsm_obj = ses.get_object(cd.id, yubihsm.defs.OBJECT.OPAQUE) - assert isinstance(hsm_obj, yubihsm.objects.Opaque) - if confirm_and_delete_old_yubihsm_object_if_exists(ctx.hsm_serial, hsm_obj, abort=False): - hsm_obj.put_certificate( - session = ses, - object_id = cd.id, - label = cd.label, - domains = ctx.conf.get_domain_bitfield(cd.domains), - capabilities = ctx.conf.capability_from_names({'exportable-under-wrap'}), - certificate = id_to_cert_obj[cd.id]) + assert isinstance(ses, HSMSession) + if confirm_and_delete_old_yubihsm_object_if_exists(ses, cd.id, yubihsm.defs.OBJECT.OPAQUE, abort=False): + ses.put_certificate(cd, id_to_cert_obj[cd.id]) cli_info(f"Certificate 0x{cd.id:04x} created and stored in YubiHSM (serial {ctx.hsm_serial}).") if dry_run: diff --git a/hsm_secrets/x509/cert_builder.py b/hsm_secrets/x509/cert_builder.py index 790fc48..3f3fe70 100644 --- a/hsm_secrets/x509/cert_builder.py +++ b/hsm_secrets/x509/cert_builder.py @@ -1,3 +1,4 @@ +import click from cryptography import x509 from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID from cryptography.hazmat.primitives import hashes @@ -17,16 +18,16 @@ import yubihsm.defs # type: ignore [import] from hsm_secrets.x509.def_utils import merge_x509_info_with_defaults -from hsm_secrets.key_adapters import PrivateKeyHSMAdapter, RSAPrivateKeyHSMAdapter, Ed25519PrivateKeyHSMAdapter, ECPrivateKeyHSMAdapter, PrivateKey +from hsm_secrets.key_adapters import PrivateKeyHSMAdapter, RSAPrivateKeyHSMAdapter, Ed25519PrivateKeyHSMAdapter, ECPrivateKeyHSMAdapter, PrivateKeyOrAdapter class X509CertBuilder: """ Ephemeral class for building and signing X.509 certificates using the YubiHSM as a key store. """ - private_key: PrivateKey + private_key: PrivateKeyOrAdapter - def __init__(self, hsm_config: HSMConfig, cert_def_info: X509Info, priv_key: PrivateKey|yubihsm.objects.AsymmetricKey): + def __init__(self, hsm_config: HSMConfig, cert_def_info: X509Info, priv_key: PrivateKeyOrAdapter|yubihsm.objects.AsymmetricKey): """ Initialize a new X.509 certificate builder. @@ -56,11 +57,11 @@ def generate_self_signed_cert(self) -> x509.Certificate: Build and sign a self-signed X.509 certificate. """ builder = self._build_cert_base() - ed = isinstance(self.private_key, Ed25519PrivateKeyHSMAdapter) + ed = isinstance(self.private_key, (Ed25519PrivateKeyHSMAdapter, ed25519.Ed25519PrivateKey)) return builder.sign(self.private_key, None if ed else hashes.SHA256()) - def generate_cross_signed_intermediate_cert(self, issuer_certs: List[x509.Certificate], issuer_keys: List[PrivateKeyHSMAdapter]) -> List[x509.Certificate]: + def generate_cross_signed_intermediate_cert(self, issuer_certs: List[x509.Certificate], issuer_keys: List[PrivateKeyOrAdapter]) -> List[x509.Certificate]: """ Build and sign an intermediate X.509 certificate with one or more issuer certificates. This is used to cross-sign an intermediate CA certificate with root CAs. @@ -81,7 +82,7 @@ def generate_cross_signed_intermediate_cert(self, issuer_certs: List[x509.Certif subject_key_identifier = x509.SubjectKeyIdentifier.from_public_key(self.private_key.public_key()) builder = builder.add_extension(subject_key_identifier, critical=False) - ed = isinstance(issuer_key, Ed25519PrivateKeyHSMAdapter) + ed = isinstance(issuer_key, (Ed25519PrivateKeyHSMAdapter, ed25519.Ed25519PrivateKey)) cert = builder.sign(issuer_key, None if ed else hashes.SHA256()) cross_signed_certs.append(cert) @@ -105,7 +106,7 @@ def generate_csr(self) -> x509.CertificateSigningRequest: if self.cert_def_info.extended_key_usage: builder = builder.add_extension(self._mkext_extended_key_usage(), critical=False) - ed = isinstance(self.private_key, Ed25519PrivateKeyHSMAdapter) + ed = isinstance(self.private_key, (Ed25519PrivateKeyHSMAdapter, ed25519.Ed25519PrivateKey)) return builder.sign(self.private_key, None if ed else hashes.SHA256()) @@ -172,18 +173,18 @@ def _mk_name_attribs(self) -> x509.Name: def _mkext_alt_name(self) -> x509.SubjectAlternativeName: assert self.cert_def_info.attribs, "X509Info.attribs.subject_alt_names is missing" type_to_cls = { - "dns": x509.DNSName, - "ip": x509.IPAddress, - "rfc822": x509.RFC822Name, - "uri": x509.UniformResourceIdentifier, - "directory": x509.DirectoryName, - "registered_id": x509.RegisteredID, - "other": x509.OtherName + "dns": (x509.DNSName, lambda n: n), + "ip": (x509.IPAddress, lambda n: ipaddress.ip_address(n)), + "rfc822": (x509.RFC822Name, lambda n: n), + "uri": (x509.UniformResourceIdentifier, lambda n: n), + "directory": (x509.DirectoryName, lambda n: n), + "registered_id": (x509.RegisteredID, lambda n: n), + "other": (x509.OtherName, lambda n: n) } san: List[x509.GeneralName] = [] for san_type, names in (self.cert_def_info.attribs.subject_alt_names or {}).items(): - dst_cls = type_to_cls[san_type] - san.extend([dst_cls(name) for name in names]) + dst_cls, dst_conv = type_to_cls[san_type] + san.extend([dst_cls(dst_conv(name)) for name in names]) return x509.SubjectAlternativeName(san) def _mkext_key_usage(self) -> x509.KeyUsage: diff --git a/hsm_secrets/x509/def_utils.py b/hsm_secrets/x509/def_utils.py index b1dab6e..e63c5a3 100644 --- a/hsm_secrets/x509/def_utils.py +++ b/hsm_secrets/x509/def_utils.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Dict, List, Optional -from hsm_secrets.config import HSMConfig, KeyID, HSMOpaqueObject, X509Cert, X509Info, find_config_items_of_class +from hsm_secrets.config import HSMConfig, HSMKeyID, HSMOpaqueObject, X509Cert, X509Info, find_config_items_of_class """ Utility functions for working with certificate definitions from the HSMConfig object. @@ -94,15 +94,15 @@ def topological_sort_x509_cert_defs(cert_defs: List[HSMOpaqueObject]) -> list[HS """ # Step 1: Build a dependency graph id_to_def = {cd.id: cd for cd in cert_defs} - signer_to_signees: Dict[KeyID, List[KeyID]] = defaultdict(list) + signer_to_signees: Dict[HSMKeyID, List[HSMKeyID]] = defaultdict(list) for cd in cert_defs: if cd.sign_by and cd.sign_by != cd.id: # Skip self-signed certs signer_to_signees[cd.sign_by].append(cd.id) # Step 2: Perform a topological sort with loop detection sorted_certs: List[HSMOpaqueObject] = [] - visited: set[KeyID] = set() - in_path: set[KeyID] = set() + visited: set[HSMKeyID] = set() + in_path: set[HSMKeyID] = set() def dfs(c: HSMOpaqueObject): if c.id in in_path: @@ -124,7 +124,7 @@ def dfs(c: HSMOpaqueObject): return sorted_certs -def find_cert_def(conf: HSMConfig, opaque_id: KeyID|int) -> Optional[X509Cert]: +def find_cert_def(conf: HSMConfig, opaque_id: HSMKeyID|int) -> Optional[X509Cert]: """ Find a certificate definition by its opaque ID. """ diff --git a/hsm_secrets/yubihsm.py b/hsm_secrets/yubihsm.py new file mode 100644 index 0000000..a867280 --- /dev/null +++ b/hsm_secrets/yubihsm.py @@ -0,0 +1,763 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Sequence, cast +import pickle +import os + +import click +from yubihsm.defs import CAPABILITY, ALGORITHM, ERROR, OBJECT, ORIGIN # type: ignore [import] +from yubihsm.objects import AsymmetricKey, HmacKey, SymmetricKey, WrapKey, YhsmObject, AuthenticationKey, Opaque # type: ignore [import] +from yubihsm.core import AuthSession # type: ignore [import] +from yubihsm.exceptions import YubiHsmDeviceError # type: ignore [import] +from yubihsm.objects import ObjectInfo + +# Mock YubiHSM2 device with Cryptodome library +import cryptography.hazmat.primitives.ciphers.algorithms as haz_algs +import cryptography.hazmat.primitives.ciphers as haz_ciphers +import cryptography.hazmat.primitives.ciphers.modes as haz_cipher_modes +import cryptography.hazmat.primitives.asymmetric.rsa as haz_rsa +import cryptography.hazmat.primitives.asymmetric.ed25519 as haz_ed25519 +import cryptography.hazmat.primitives.asymmetric.ec as haz_ec +import cryptography.hazmat.primitives.hashes as haz_hashes +import cryptography.hazmat.primitives.hmac as haz_hmac +import cryptography.hazmat.primitives.serialization as haz_ser +from cryptography.hazmat.primitives import _serialization as haz_priv_ser +import cryptography.hazmat.primitives.asymmetric.padding as haz_asym_padding +import cryptography.x509 as haz_x509 + +from hsm_secrets.config import HSMAsymmetricKey, HSMAuthKey, HSMConfig, HSMHmacKey, HSMKeyID, HSMObjBase, HSMOpaqueObject, HSMSymmetricKey, HSMWrapKey, NoExtraBaseModel +from hsm_secrets.key_adapters import PrivateKeyOrAdapter, make_private_key_adapter + +""" +Abstracts the YubiHSM2 interface for the purpose of testing. + +This module provides both real (pass-through) and mock implementations of YubiHSM2 operations +that the rest of the application uses. The mock implementation is used for testing purposes, +and can simulate several YubiHSM2 devices with different objects stored in them. + +The mock devices are stored in pickle files, and can be loaded and saved using the +load_mock_hsms() and save_mock_hsms() functions. +""" + +HSM_KEY_TYPES = (HSMAuthKey, HSMWrapKey, HSMSymmetricKey, HSMAsymmetricKey, HSMHmacKey, HSMOpaqueObject) + +class HSMSession(ABC): + """ + Abstract base class for HSM sessions. + + This class defines the interface for interacting with a Hardware Security Module (HSM), + whether it's a real YubiHSM2 or a mock HSM for testing purposes. + """ + + @abstractmethod + def get_serial(self) -> int: + """ + Get the serial number of the HSM device. + """ + pass + + @abstractmethod + def get_info(self, objdef: HSMObjBase) -> ObjectInfo: + """ + Get information about an object in the HSM. + + :param objdef: Object definition + :return: Object information + """ + pass + + @abstractmethod + def get_info_raw(self, id: HSMKeyID, type: OBJECT) -> ObjectInfo: + """ + Get information about an object in the HSM, given its id and type. + + :param id: Object ID + :param type: Object type + :return: Object information + """ + pass + + @abstractmethod + def object_exists(self, objdef: HSMObjBase) -> ObjectInfo | None: + """ + Check if an object exists in the HSM and return its info. + + :param objdef: Object definition + :return: Object information if it exists, None otherwise + """ + pass + + @abstractmethod + def object_exists_raw(self, id: HSMKeyID, type: OBJECT) -> ObjectInfo | None: + """ + Check if an object exists in the HSM, given its id and type. + + :param id: Object ID + :param type: Object type + :return: Object information if it exists, None otherwise + """ + pass + + @abstractmethod + def put_wrap_key(self, keydef: HSMWrapKey, secret: bytes) -> ObjectInfo: + """ + Put a wrap key on device. + + :param keydef: Wrap key definition + :param secret: Key secret + :return: Object information of the created wrap key + """ + pass + + @abstractmethod + def attest_asym_key(self, key_id: HSMKeyID) -> haz_x509.Certificate: + """ + Attest an asymmetric key. + + :param key_id: Key ID + :return: Attestation certificate + """ + pass + + @abstractmethod + def export_wrapped(self, wrap_key: HSMWrapKey, obj_id: HSMKeyID, obj_type: OBJECT) -> bytes: + """ + Export an object from the HSM, encrypted with a wrap key. + + :param wrap_key: Wrap key to use for encryption + :param obj_id: ID of the object to export + :param obj_type: Type of the object to export + :return: Encrypted object data + """ + pass + + @abstractmethod + def import_wrapped(self, wrap_key: HSMWrapKey, data: bytes) -> ObjectInfo: + """ + Import an object into the HSM, decrypting it with a wrap key. + + :param wrap_key: Wrap key to use for decryption + :param data: Encrypted object data + :return: Object information of the imported object + """ + pass + + @abstractmethod + def auth_key_put_derived(self, keydef: HSMAuthKey, password: str) -> ObjectInfo: + """ + Create a new authentication key from a password. + + :param keydef: Authentication key definition + :param password: Password to derive the key from + :return: Object information of the created authentication key + """ + pass + + @abstractmethod + def auth_key_put(self, keydef: HSMAuthKey, key_enc: bytes, key_mac: bytes) -> ObjectInfo: + """ + Create a new authentication key from raw key material. + + :param keydef: Authentication key definition + :param key_enc: Encryption key + :param key_mac: MAC key + :return: Object information of the created authentication key + """ + pass + + @abstractmethod + def sym_key_generate(self, keydef: HSMSymmetricKey) -> ObjectInfo: + """ + Create a new symmetric key. + + :param keydef: Symmetric key definition + :return: Object information of the created symmetric key + """ + pass + + @abstractmethod + def asym_key_generate(self, keydef: HSMAsymmetricKey) -> ObjectInfo: + """ + Create a new asymmetric key. + + :param keydef: Asymmetric key definition + :return: Object information of the created asymmetric key + """ + pass + + @abstractmethod + def hmac_key_generate(self, keydef: HSMHmacKey) -> ObjectInfo: + """ + Create a new HMAC key. + + :param keydef: HMAC key definition + :return: Object information of the created HMAC key + """ + pass + + @abstractmethod + def get_pseudo_random(self, length: int) -> bytes: + """ + Generate pseudo-random bytes. + + :param length: Number of bytes to generate + :return: Pseudo-random bytes + """ + pass + + @abstractmethod + def list_objects(self) -> Sequence['YhsmObject | MockYhsmObject']: + """ + List all objects in the HSM. + + :return: Sequence of HSM objects + """ + pass + + @abstractmethod + def delete_object(self, objdef: HSMObjBase) -> None: + """ + Delete an object from the HSM, given its definition. + + :param objdef: Object definition + """ + pass + + @abstractmethod + def delete_object_raw(self, id: HSMKeyID, type: OBJECT) -> None: + """ + Delete an object from the HSM, given its key and type. + + :param id: Object ID + :param type: Object type + """ + pass + + @abstractmethod + def sign_hmac(self, keydef: HSMHmacKey, data: bytes) -> bytes: + """ + Sign data with an HMAC key. + + :param keydef: HMAC key definition + :param data: Data to sign + :return: HMAC signature + """ + pass + + @abstractmethod + def get_certificate(self, keydef: HSMOpaqueObject) -> haz_x509.Certificate: + """ + Get a certificate from the HSM. + + :param keydef: Opaque object definition containing the certificate + :return: X.509 Certificate + """ + pass + + @abstractmethod + def put_certificate(self, keydef: HSMOpaqueObject, certificate: haz_x509.Certificate) -> ObjectInfo: + """ + Store a certificate in the HSM. + + :param keydef: Opaque object definition to store the certificate + :param certificate: X.509 Certificate to store + :return: Object information of the stored certificate + """ + pass + + @abstractmethod + def get_private_key(self, keydef: HSMAsymmetricKey) -> PrivateKeyOrAdapter: + """ + Get a private key adapter for an asymmetric key. + + :param keydef: Asymmetric key definition + :return: Private key or adapter + """ + pass + + @abstractmethod + def get_public_key(self, keydef: HSMAsymmetricKey) -> haz_rsa.RSAPublicKey | haz_ec.EllipticCurvePublicKey | haz_ed25519.Ed25519PublicKey: + """ + Get the public key from an asymmetric key. + + :param keydef: Asymmetric key definition + :return: Public key (RSA, EC, or Ed25519) + """ + pass + +# --------- Real YubiHSM2 --------- + +class RealHSMSession(HSMSession): + """ + Implementation of the HSM session for a real YubiHSM2 device. + """ + + def __init__(self, conf: HSMConfig, session: AuthSession, dev_serial: int): + """ + Initialize the real HSM session. + + :param conf: HSM configuration + :param session: Authenticated session with the YubiHSM2 + :param dev_serial: Device serial number + """ + self.dev_serial = dev_serial + self.backend = session + self.conf = conf + + def get_serial(self) -> HSMKeyID: + return self.dev_serial + + def get_info(self, objdef: HSMObjBase) -> ObjectInfo: + res = self.object_exists(objdef) + if not res: + raise YubiHsmDeviceError(ERROR.OBJECT_NOT_FOUND) + return res + + def get_info_raw(self, id: HSMKeyID, type: OBJECT) -> ObjectInfo: + res = self.object_exists_raw(id, type) + if not res: + raise YubiHsmDeviceError(ERROR.OBJECT_NOT_FOUND) + return res + + def object_exists(self, objdef: HSMObjBase) -> ObjectInfo | None: + assert isinstance(objdef, HSM_KEY_TYPES) + obj_type = _conf_class_to_yhs_object_type[objdef.__class__] + return self.object_exists_raw(objdef.id, obj_type) + + def object_exists_raw(self, id: HSMKeyID, type: OBJECT) -> ObjectInfo | None: + try: + return self.backend.get_object(id, type).get_info() + except YubiHsmDeviceError as e: + if e.code == ERROR.OBJECT_NOT_FOUND: + return None + raise e + + def put_wrap_key(self, keydef: HSMWrapKey, secret: bytes) -> ObjectInfo: + wrap_key = self.backend.get_object(keydef.id, OBJECT.WRAP_KEY) + assert isinstance(wrap_key, WrapKey) + res = wrap_key.put( + session=self.backend, + object_id=keydef.id, + label=keydef.label, + algorithm=self.conf.algorithm_from_name(keydef.algorithm), + domains=self.conf.get_domain_bitfield(keydef.domains), + capabilities=self.conf.capability_from_names(set(keydef.capabilities)), + delegated_capabilities=self.conf.capability_from_names(set(keydef.delegated_capabilities)), + key=secret) + return res.get_info() + + def attest_asym_key(self, key_id: HSMKeyID) -> haz_x509.Certificate: + asym_key = self.backend.get_object(key_id, OBJECT.ASYMMETRIC_KEY) + assert isinstance(asym_key, AsymmetricKey) + return asym_key.attest() + + def export_wrapped(self, wrap_key: HSMWrapKey, obj_id: HSMKeyID, obj_type: OBJECT) -> bytes: + wrap_key_obj = self.backend.get_object(wrap_key.id, OBJECT.WRAP_KEY) + assert isinstance(wrap_key_obj, WrapKey) + export_obj = self.backend.get_object(obj_id, obj_type) + return wrap_key_obj.export_wrapped(export_obj) + + def import_wrapped(self, wrap_key: HSMWrapKey, data: bytes) -> ObjectInfo: + wrap_key_obj = self.backend.get_object(wrap_key.id, OBJECT.WRAP_KEY) + assert isinstance(wrap_key_obj, WrapKey) + return wrap_key_obj.import_wrapped(data).get_info() + + def auth_key_put_derived(self, keydef: HSMAuthKey, password: str) -> ObjectInfo: + auth_key = self.backend.get_object(keydef.id, OBJECT.AUTHENTICATION_KEY) + assert isinstance(auth_key, AuthenticationKey) + return auth_key.put_derived( + session=self.backend, + object_id=keydef.id, + label=keydef.label, + domains=self.conf.get_domain_bitfield(keydef.domains), + capabilities=self.conf.capability_from_names(keydef.capabilities), + delegated_capabilities=self.conf.capability_from_names(keydef.delegated_capabilities), + password=password).get_info() + + def auth_key_put(self, keydef: HSMAuthKey, key_enc: bytes, key_mac: bytes) -> ObjectInfo: + auth_key = self.backend.get_object(keydef.id, OBJECT.AUTHENTICATION_KEY) + assert isinstance(auth_key, AuthenticationKey) + return auth_key.put( + session=self.backend, + object_id=keydef.id, + label=keydef.label, + domains=self.conf.get_domain_bitfield(keydef.domains), + capabilities=self.conf.capability_from_names(keydef.capabilities), + delegated_capabilities=self.conf.capability_from_names(keydef.delegated_capabilities), + key_enc=key_enc, + key_mac=key_mac).get_info() + + def sym_key_generate(self, keydef: HSMSymmetricKey) -> ObjectInfo: + sym_key = self.backend.get_object(keydef.id, OBJECT.SYMMETRIC_KEY) + assert isinstance(sym_key, SymmetricKey) + return sym_key.generate( + session=self.backend, + object_id=keydef.id, + label=keydef.label, + domains=self.conf.get_domain_bitfield(keydef.domains), + capabilities=self.conf.capability_from_names(set(keydef.capabilities)), + algorithm=self.conf.algorithm_from_name(keydef.algorithm)).get_info() + + def asym_key_generate(self, keydef: HSMAsymmetricKey) -> ObjectInfo: + asym_key = self.backend.get_object(keydef.id, OBJECT.ASYMMETRIC_KEY) + assert isinstance(asym_key, AsymmetricKey) + return asym_key.generate( + session=self.backend, + object_id=keydef.id, + label=keydef.label, + domains=self.conf.get_domain_bitfield(keydef.domains), + capabilities=self.conf.capability_from_names(set(keydef.capabilities)), + algorithm=self.conf.algorithm_from_name(keydef.algorithm)).get_info() + + def hmac_key_generate(self, keydef: HSMHmacKey) -> ObjectInfo: + hmac_key = self.backend.get_object(keydef.id, OBJECT.HMAC_KEY) + assert isinstance(hmac_key, HmacKey) + return hmac_key.generate( + session=self.backend, + object_id=keydef.id, + label=keydef.label, + domains=self.conf.get_domain_bitfield(keydef.domains), + capabilities=self.conf.capability_from_names(set(keydef.capabilities)), + algorithm=self.conf.algorithm_from_name(keydef.algorithm)).get_info() + + def get_pseudo_random(self, length: int) -> bytes: + return self.backend.get_pseudo_random(length) + + def list_objects(self) -> Sequence[YhsmObject]: + return self.backend.list_objects() + + def delete_object(self, objdef: HSMObjBase) -> None: + assert isinstance(objdef, HSM_KEY_TYPES) + obj_type = _conf_class_to_yhs_object_type[objdef.__class__] + self.delete_object_raw(objdef.id, obj_type) + + def delete_object_raw(self, id: HSMKeyID, type: OBJECT) -> None: + self.backend.get_object(id, type).delete() + + def sign_hmac(self, keydef: HSMHmacKey, data: bytes) -> bytes: + hmac_key = self.backend.get_object(keydef.id, OBJECT.HMAC_KEY) + assert isinstance(hmac_key, HmacKey) + return hmac_key.sign_hmac(data) + + def get_certificate(self, keydef: HSMOpaqueObject) -> haz_x509.Certificate: + obj = self.backend.get_object(keydef.id, OBJECT.OPAQUE) + assert isinstance(obj, Opaque) + return obj.get_certificate() + + def put_certificate(self, keydef: HSMOpaqueObject, certificate: haz_x509.Certificate) -> ObjectInfo: + obj = self.backend.get_object(keydef.id, OBJECT.OPAQUE) + assert isinstance(obj, Opaque) + return obj.put_certificate( + session=self.backend, + object_id=keydef.id, + label=keydef.label, + domains=self.conf.get_domain_bitfield(keydef.domains), + capabilities=self.conf.capability_from_names({'exportable-under-wrap'}), + certificate=certificate).get_info() + + def get_private_key(self, keydef: HSMAsymmetricKey) -> PrivateKeyOrAdapter: + asym_key = self.backend.get_object(keydef.id, OBJECT.ASYMMETRIC_KEY) + assert isinstance(asym_key, AsymmetricKey) + return make_private_key_adapter(asym_key) + + def get_public_key(self, keydef: HSMAsymmetricKey) -> haz_rsa.RSAPublicKey | haz_ec.EllipticCurvePublicKey | haz_ed25519.Ed25519PublicKey: + asym_key = self.backend.get_object(keydef.id, OBJECT.ASYMMETRIC_KEY) + assert isinstance(asym_key, AsymmetricKey) + return asym_key.get_public_key() + +# --------- Mock YubiHSM2 --------- + + +_g_mock_hsms: dict[int, 'MockHSMDevice'] = {} +_g_conf: HSMConfig|None + +def open_mock_hsms(path: str, serial: int, conf: HSMConfig): + """ + Open mock HSM devices from a pickle file and/or + create a new mock HSM device with the given serial. + """ + global _g_mock_hsms, _g_conf + _g_conf = conf + + if os.path.exists(path): + with open(path, 'rb') as f: + _g_mock_hsms = pickle.loads(f.read()) + click.echo(click.style(f"~🤡~ Loaded {len(_g_mock_hsms)} mock YubiHSMs from '{path}' ~🤡~", fg='yellow'), err=True) + + if serial not in _g_mock_hsms: + dev = MockHSMDevice(serial=serial, objects={}) + _g_mock_hsms[serial] = dev + click.echo(click.style(f"~🤡~ Created new mock YubiHSM with serial {serial} ~🤡~", fg='yellow'), err=True) + + # Store the default admin key in the device, like on a fresh YubiHSM2 + ses = MockHSMSession(serial) + ses.auth_key_put_derived( + keydef = _g_conf.admin.default_admin_key, + password = _g_conf.admin.default_admin_password) + + +def save_mock_hsms(path: str): + """ + Save mock HSM devices to a pickle file. + """ + global _g_mock_hsms + with open(path, 'wb') as f: + blob = pickle.dumps(_g_mock_hsms) + f.write(blob) + click.echo(click.style(f"~🤡~ Saved {len(_g_mock_hsms)} mock YubiHSMs to '{path}' ~🤡~", fg='yellow'), err=True) + + +# ---------------------------- + +class MockHSMDevice: + serial: int + objects: dict[tuple[HSMKeyID, OBJECT], 'MockYhsmObject'] = {} + + def __init__(self, serial: int, objects: dict): + self.serial = serial + self.objects = objects + + def get_mock_object(self, key: HSMKeyID, type: OBJECT) -> 'MockYhsmObject': + if (key, type) not in self.objects: + raise YubiHsmDeviceError(ERROR.OBJECT_NOT_FOUND) + return self.objects[(key, type)] + + def put_mock_object(self, obj: 'MockYhsmObject') -> None: + assert isinstance(obj.mock_obj, HSM_KEY_TYPES) + key, type = obj.mock_obj.id, obj.object_type + if (key, type) in self.objects: + raise YubiHsmDeviceError(ERROR.OBJECT_EXISTS) + self.objects[(key, type)] = obj + + def del_mock_object(self, key: HSMKeyID, type: OBJECT) -> None: + if (key, type) not in self.objects: + raise YubiHsmDeviceError(ERROR.OBJECT_NOT_FOUND) + del self.objects[(key, type)] + + +_conf_class_to_yhs_object_type = { + HSMAuthKey: OBJECT.AUTHENTICATION_KEY, + HSMWrapKey: OBJECT.WRAP_KEY, + HSMSymmetricKey: OBJECT.SYMMETRIC_KEY, + HSMAsymmetricKey: OBJECT.ASYMMETRIC_KEY, + HSMHmacKey: OBJECT.HMAC_KEY, + HSMOpaqueObject: OBJECT.OPAQUE +} + +class MockYhsmObject: + """ + Mock version of the YhsmObject class (returned by list_objects() among others). + Implements get_info() and delete() methods only. + """ + def __init__(self, serial: int, mock_obj: HSMObjBase, data: bytes): + self.mock_device_serial = serial + self.mock_obj = mock_obj + self.data = data + + @property + def id(self) -> HSMKeyID: + assert isinstance(self.mock_obj, HSM_KEY_TYPES) + return self.mock_obj.id + + @property + def object_type(self) -> OBJECT: + assert isinstance(self.mock_obj, HSM_KEY_TYPES) + return _conf_class_to_yhs_object_type[self.mock_obj.__class__] + + def get_info(self) -> ObjectInfo: + global _g_conf + assert _g_conf + assert isinstance(self.mock_obj, HSM_KEY_TYPES) + + if algo_name := getattr(self.mock_obj, "algorithm", None): + yhsm_algo = _g_conf.algorithm_from_name(algo_name) + elif isinstance(self.mock_obj, HSMAuthKey): + yhsm_algo = ALGORITHM.AES128_YUBICO_AUTHENTICATION + else: + raise ValueError(f"Don't know how to get algorithm for object: {self.mock_obj}") + + yhsm_caps = CAPABILITY.NONE + if caps := getattr(self.mock_obj, "capabilities", None): + yhsm_caps = _g_conf.capability_from_names(set(caps)) + + yhsm_deleg_caps = CAPABILITY.NONE + if deleg_caps := getattr(self.mock_obj, "delegated_capabilities", None): + yhsm_deleg_caps = _g_conf.capability_from_names(set(deleg_caps)) + + return ObjectInfo( + id = self.mock_obj.id, + object_type = self.object_type, + algorithm = yhsm_algo, + label = self.mock_obj.label, + size = len(self.data), + domains = _g_conf.get_domain_bitfield(self.mock_obj.domains), + sequence = 123456, + origin = ORIGIN.IMPORTED, + capabilities = yhsm_caps, + delegated_capabilities = yhsm_deleg_caps) + + def delete(self) -> None: + global _g_mock_hsms + assert isinstance(self.mock_obj, HSM_KEY_TYPES) + key = (self.mock_obj.id, self.object_type) + assert self.mock_device_serial in _g_mock_hsms, f"Mock device not found: {self.mock_device_serial}" + device = _g_mock_hsms[self.mock_device_serial] + device.del_mock_object(key[0], key[1]) + + def __repr__(self): + return "{0.__class__.__name__}(id={0.id})".format(self) + + +class MockHSMSession(HSMSession): + """ + Implementation of the HSM session for a mock HSM device. + """ + + def __init__(self, dev_serial: int): + global _g_mock_hsms + self.backend = _g_mock_hsms[dev_serial] + self.dev_serial = dev_serial + + def get_serial(self) -> int: + return self.dev_serial + + def get_info(self, objdef: HSMObjBase) -> ObjectInfo: + res = self.object_exists(objdef) + if not res: + raise YubiHsmDeviceError(ERROR.OBJECT_NOT_FOUND) + return res + + def get_info_raw(self, id: HSMKeyID, type: OBJECT) -> ObjectInfo: + res = self.object_exists_raw(id, type) + if not res: + raise YubiHsmDeviceError(ERROR.OBJECT_NOT_FOUND) + return res + + def object_exists(self, objdef: HSMObjBase) -> ObjectInfo | None: + assert isinstance(objdef, HSM_KEY_TYPES) + obj_type = _conf_class_to_yhs_object_type[objdef.__class__] + return self.object_exists_raw(objdef.id, obj_type) + + def object_exists_raw(self, id: HSMKeyID, type: OBJECT) -> ObjectInfo | None: + try: + return self.backend.get_mock_object(id, type).get_info() + except YubiHsmDeviceError as e: + if e.code == ERROR.OBJECT_NOT_FOUND: + return None + raise e + + def put_wrap_key(self, keydef: HSMWrapKey, secret: bytes) -> ObjectInfo: + obj = MockYhsmObject(self.backend.serial, keydef, secret) + self.backend.objects[(keydef.id, OBJECT.WRAP_KEY)] = obj + return obj.get_info() + + def attest_asym_key(self, key_id: HSMKeyID) -> haz_x509.Certificate: + asym_pem = self.backend.get_mock_object(key_id, OBJECT.ASYMMETRIC_KEY).data + asym_key = haz_ser.load_pem_private_key(asym_pem, password=None) + assert isinstance(asym_key, (haz_rsa.RSAPrivateKey, haz_ec.EllipticCurvePrivateKey, haz_ed25519.Ed25519PrivateKey)) + return haz_x509.CertificateBuilder().subject_name(haz_x509.Name([ + haz_x509.NameAttribute(haz_x509.NameOID.COMMON_NAME, "self-signed") + ])).sign(asym_key, haz_hashes.SHA256()) + + def export_wrapped(self, wrap_key: HSMWrapKey, obj_id: HSMKeyID, obj_type: OBJECT) -> bytes: + if not self.object_exists(wrap_key): + raise click.ClickException(f"Wrap key missing. Create it first.") + aes_key = self.backend.objects.get((wrap_key.id, OBJECT.WRAP_KEY)) + if not aes_key: + raise YubiHsmDeviceError(ERROR.OBJECT_NOT_FOUND) + export_blob = pickle.dumps(self.backend.get_mock_object(obj_id, obj_type)) + cipher = haz_ciphers.Cipher(haz_algs.AES(aes_key.data), haz_cipher_modes.GCM(b"\0" * 16)) + encryptor = cipher.encryptor() + enc_blob = encryptor.update(export_blob) + encryptor.finalize() + tag = encryptor.tag + return pickle.dumps((enc_blob, tag)) + + def import_wrapped(self, wrap_key: HSMWrapKey, data: bytes) -> ObjectInfo: + aes_key = self.backend.get_mock_object(wrap_key.id, OBJECT.WRAP_KEY).data + decryptor = haz_ciphers.Cipher(haz_algs.AES(aes_key), haz_cipher_modes.GCM(b"\0" * 16)).decryptor() + enc_blob, tag = pickle.loads(data) + export_blob = decryptor.update(enc_blob) + decryptor.finalize_with_tag(tag) + obj: MockYhsmObject = pickle.loads(export_blob) + assert isinstance(obj.mock_obj, HSM_KEY_TYPES) + self.backend.put_mock_object(obj) + return obj.get_info() + + def auth_key_put_derived(self, keydef: HSMAuthKey, password: str) -> ObjectInfo: + data = f"derived:{password}".encode() + obj = MockYhsmObject(self.backend.serial, keydef, data) + self.backend.objects[(keydef.id, OBJECT.AUTHENTICATION_KEY)] = obj + return obj.get_info() + + def auth_key_put(self, keydef: HSMAuthKey, key_enc: bytes, key_mac: bytes) -> ObjectInfo: + data = f"key_enc:{key_enc.hex()},key_mac:{key_mac.hex()}".encode() + obj = MockYhsmObject(self.backend.serial, keydef, data) + self.backend.objects[(keydef.id, OBJECT.AUTHENTICATION_KEY)] = obj + return obj.get_info() + + def sym_key_generate(self, keydef: HSMSymmetricKey) -> ObjectInfo: + data = {"key_enc": self.get_pseudo_random(256//8), "key_mac": self.get_pseudo_random(256//8)} + obj = MockYhsmObject(self.backend.serial, keydef, pickle.dumps(data)) + self.backend.objects[(keydef.id, OBJECT.SYMMETRIC_KEY)] = obj + return obj.get_info() + + def asym_key_generate(self, keydef: HSMAsymmetricKey) -> ObjectInfo: + priv_key: PrivateKeyOrAdapter + if "rsa" in keydef.algorithm.lower(): + priv_key = haz_rsa.generate_private_key(public_exponent=65537, key_size=2048) + elif "ec" in keydef.algorithm.lower(): + priv_key = haz_ec.generate_private_key(haz_ec.SECP256R1()) + elif "ed25519" in keydef.algorithm.lower(): + priv_key = haz_ed25519.Ed25519PrivateKey.generate() + else: + raise ValueError(f"Unsupported algorithm: {keydef.algorithm}") + + priv_pem = priv_key.private_bytes(haz_ser.Encoding.PEM, haz_ser.PrivateFormat.PKCS8, haz_ser.NoEncryption()) + obj = MockYhsmObject(self.backend.serial, keydef, priv_pem) + self.backend.objects[(keydef.id, OBJECT.ASYMMETRIC_KEY)] = obj + return obj.get_info() + + def hmac_key_generate(self, keydef: HSMHmacKey) -> ObjectInfo: + data = self.get_pseudo_random(256//8) + obj = MockYhsmObject(self.backend.serial, keydef, data) + self.backend.put_mock_object(obj) + return obj.get_info() + + def get_pseudo_random(self, length: int) -> bytes: + return (b'0123' * length)[:length] # Mock: use deterministic data for tests + + def list_objects(self) -> Sequence[MockYhsmObject]: + return list(self.backend.objects.values()) + + def delete_object(self, objdef: HSMObjBase) -> None: + assert isinstance(objdef, HSM_KEY_TYPES) + obj_type = _conf_class_to_yhs_object_type[objdef.__class__] + self.delete_object_raw(objdef.id, obj_type) + + def delete_object_raw(self, id: HSMKeyID, type: OBJECT) -> None: + self.backend.del_mock_object(id, type) + + def sign_hmac(self, keydef: HSMHmacKey, data: bytes) -> bytes: + hmac_key = self.backend.objects[(keydef.id, OBJECT.HMAC_KEY)].data + hmac = haz_hmac.HMAC(hmac_key, haz_hashes.SHA256()) + hmac.update(data) + return hmac.finalize() + + def get_certificate(self, keydef: HSMOpaqueObject) -> haz_x509.Certificate: + return haz_x509.load_pem_x509_certificate(self.backend.objects[(keydef.id, OBJECT.OPAQUE)].data) + + def put_certificate(self, keydef: HSMOpaqueObject, certificate: haz_x509.Certificate) -> ObjectInfo: + obj = MockYhsmObject(self.backend.serial, keydef, certificate.public_bytes(encoding=haz_ser.Encoding.PEM)) + self.backend.objects[(keydef.id, OBJECT.OPAQUE)] = obj + return obj.get_info() + + def get_private_key(self, keydef: HSMAsymmetricKey) -> PrivateKeyOrAdapter: + asym_pem = self.backend.objects[(keydef.id, OBJECT.ASYMMETRIC_KEY)].data + asym_key = haz_ser.load_pem_private_key(asym_pem, password=None) + assert isinstance(asym_key, (haz_rsa.RSAPrivateKey, haz_ec.EllipticCurvePrivateKey, haz_ed25519.Ed25519PrivateKey)) + return asym_key + + def get_public_key(self, keydef: HSMAsymmetricKey) -> haz_rsa.RSAPublicKey | haz_ec.EllipticCurvePublicKey | haz_ed25519.Ed25519PublicKey: + asym_pem = self.backend.objects[(keydef.id, OBJECT.ASYMMETRIC_KEY)].data + asym_key = haz_ser.load_pem_private_key(asym_pem, password=None) + assert isinstance(asym_key, (haz_rsa.RSAPrivateKey, haz_ec.EllipticCurvePrivateKey, haz_ed25519.Ed25519PrivateKey)) + return asym_key.public_key() diff --git a/run-tests.sh b/run-tests.sh new file mode 100755 index 0000000..15fff4d --- /dev/null +++ b/run-tests.sh @@ -0,0 +1,181 @@ +#!/bin/bash +set -e + +TEMPDIR=$(mktemp -d /tmp/hsm-secret-test.XXXXXX) +[[ $TEMPDIR =~ ^/tmp/hsm-secret-test ]] || { echo "Error: Invalid temp directory"; exit 1; } +trap "rm -rf $TEMPDIR" EXIT + +cp hsm-conf.yml $TEMPDIR/ +MOCKDB="$TEMPDIR/mock.pickle" +CMD="./_venv/bin/hsm-secrets -c $TEMPDIR/hsm-conf.yml --mock $MOCKDB" + + +# Helpers for `expect` calls: +# - Preamble sets up an infallible timeout handler. +# - Postamble reads the exit status of the last spawned process and exits with it. +EXPECT_PREAMBLE=' + set timeout 5 + proc handle_timeout {} { puts "Timeout. Aborting."; catch {exec kill -9 [exp_pid]}; exit 1 } ' +EXPECT_POSTAMBLE=' + set wait_result [wait] + if {[llength $wait_result] == 4} { + lassign $wait_result pid spawnid os_error_flag value + if {$os_error_flag == 0} { puts "exit status: $value"; exit $value } + else { puts "errno: $value"; exit 1 } + } else { puts "Unexpected wait result"; exit 1 } ' + + +run_cmd() { + echo "$ $CMD $@" + $CMD "$@" +} + +assert_success() { + if [ $? -ne 0 ]; then + echo "ERROR: Expected success, but command failed" + return 1 + fi +} + +assert_grep() { + if ! grep -q "$1" <<< "$2"; then + echo "ERROR: Expected output to contain '$1'" + return 1 + fi +} + +setup() { + run_cmd -q hsm compare --create + run_cmd x509 create -a + # `add-service` command is interactive => use `expect` to provide input + expect << EOF + $EXPECT_PREAMBLE + spawn sh -c "$CMD user add-service 0x0008 2>&1" + expect { + "Press ENTER" { sleep 0.1; send "\r"; exp_continue } + "3031-3233-3031" { sleep 0.1; send "\r"; exp_continue } + "again to confirm" { sleep 0.1; send "3031-3233-3031-3233-3031-3233-3031-3233"; sleep 0.1; send "\r"; exp_continue } + timeout { handle_timeout } + eof {} + } + $EXPECT_POSTAMBLE +EOF + run_cmd -q hsm make-wrap-key +} + +# ------------------ test cases ------------------------- + +test_fresh_device() { + local count=$(run_cmd -q hsm list-objects | grep -c '^0x') + [ "$count" -eq 1 ] || { echo "Expected 1 object, but found $count"; return 1; } +} + +test_create_all() { + setup + local count=$(run_cmd -q hsm compare | grep -c '\[x\]') + [ "$count" -eq 35 ] || { echo "Expected 35 objects, but found $count"; return 1; } +} + +test_ssh_certificates() { + setup + run_cmd ssh get-ca --all | ssh-keygen -l -f /dev/stdin + assert_success + + ssh-keygen -t ed25519 -f $TEMPDIR/testkey -N '' -C 'testkey' + run_cmd ssh sign -u test.user -p users,admins $TEMPDIR/testkey.pub + assert_success + + local output=$(ssh-keygen -L -f $TEMPDIR/testkey-cert.pub) + assert_success + assert_grep "Public key: ED25519" "$output" + assert_grep "^[[:space:]]*users$" "$output" + assert_grep "^[[:space:]]*admins$" "$output" + assert_grep 'Key ID: "test.user-[0-9]*-users+admins"' "$output" +} + +test_tls_certificates() { + setup + run_cmd -q x509 get --all | openssl x509 -text -noout + assert_success + + run_cmd tls server-cert --out $TEMPDIR/www-example-com.pem --common-name www.example.com --san-dns www.example.org --san-ip 192.168.0.1 --san-ip fd12:123::80 --keyfmt rsa4096 + assert_success + + local output=$(openssl crl2pkcs7 -nocrl -certfile $TEMPDIR/www-example-com.cer.pem | openssl pkcs7 -print_certs | openssl x509 -text -noout) + assert_success + assert_grep 'Subject:.*CN=www.example.com.*L=Duckburg.*ST=Calisota.*C=US' "$output" + assert_grep 'DNS:www.example.org' "$output" + assert_grep 'IP Address:192.168.0.1' "$output" + assert_grep 'IP Address:FD12:123' "$output" + assert_grep 'Public.*4096' "$output" + assert_grep 'Signature.*ecdsa' "$output" + + [ -f $TEMPDIR/www-example-com.key.pem ] || { echo "ERROR: Key not saved"; return 1; } + [ -f $TEMPDIR/www-example-com.csr.pem ] || { echo "ERROR: CSR not saved"; return 1; } + [ -f $TEMPDIR/www-example-com.chain.pem ] || { echo "ERROR: Chain bundle not saved"; return 1; } +} + +test_password_derivation() { + setup + local output=$(run_cmd -q pass get www.example.com) + assert_grep 'dignity.proud.material.upset.elegant.finish' "$output" + + local nonce=$(run_cmd -q pass rotate www.example.com | grep nonce) + sed -E "s|^( *)\-.*name_hmac.*nonce.*ts.*$|\1${nonce}|" < $TEMPDIR/hsm-conf.yml > $TEMPDIR/rotated-conf.yml + mv $TEMPDIR/rotated-conf.yml $TEMPDIR/hsm-conf.yml + + output=$(run_cmd -q pass get www.example.com) + ! grep -q 'dignity.proud.material.upset.elegant.finish' <<< "$output" || { echo "ERROR: password not rotated"; return 1; } +} + +test_wrapped_backup() { + setup + run_cmd -q hsm backup --out $TEMPDIR/backup.tgz + assert_success + + tar tvfz $TEMPDIR/backup.tgz | grep -q 'ASYMMETRIC_KEY' || { echo "ERROR: No asymmetric keys found in backup"; return 1; } + tar tvfz $TEMPDIR/backup.tgz | grep -q 'OPAQUE' || { echo "ERROR: No certificates found in backup"; return 1; } + + run_cmd -q hsm delete --force 0x0210 + run_cmd -q hsm compare | grep -q '[ ].*ca-root-key-rsa' || { echo "ERROR: Key not deleted"; return 1; } + + run_cmd -q hsm restore --force $TEMPDIR/backup.tgz + run_cmd -q hsm compare | grep -q '[x].*ca-root-key-rsa' || { echo "ERROR: Key not restored"; return 1; } +} + +# ------------------------------------------------------ + +function run_test_quiet() { + echo -n " $1 ... " + local output + if output=$($1 2>&1); then + echo "OK" + else + echo "FAILED" + echo "Error output:" + echo "$output" + return 1 + fi + rm -f $MOCKDB +} + +run_test() { + echo -n " $1 ... " + if $1; then + echo "OK" + else + echo "FAILED" + return 1 + fi + rm -f $MOCKDB +} + +echo "Running tests:" +run_test test_fresh_device +run_test test_create_all +run_test test_ssh_certificates +run_test test_tls_certificates +run_test test_password_derivation +run_test test_wrapped_backup + +echo "All tests passed successfully!"