diff --git a/src/notsotuf/repo/cli.py b/src/notsotuf/repo/cli.py index da2137f..d92ca83 100644 --- a/src/notsotuf/repo/cli.py +++ b/src/notsotuf/repo/cli.py @@ -4,7 +4,9 @@ import packaging.version from tuf.api.metadata import TOP_LEVEL_ROLE_NAMES -from notsotuf.utils import input_bool, input_numeric, input_text, input_list +from notsotuf.utils import ( + log_print, input_bool, input_numeric, input_text, input_list +) from notsotuf.repo import Repository logger = logging.getLogger(__name__) @@ -32,11 +34,17 @@ ) +def _print_info(message: str): + return log_print(message=message, level=logging.INFO, logger=logger) + + def _get_repo(): try: return Repository.from_config() except TypeError: - print('Failed to load config. Did you initialize the repository?') + _print_info( + 'Failed to load config. Did you initialize the repository?' + ) def _add_key_dirs_argument(parser: argparse.ArgumentParser): @@ -207,17 +215,18 @@ def _cmd_init(options: argparse.Namespace): message = 'Modifying existing configuration.' else: message = 'Using existing configuration.' - logger.info(message) + _print_info(message) if modify: config_dict = _get_config_from_user(**config_dict) # create repository instance repository = Repository(**config_dict) # save new or updated configuration + _print_info('Saving config...') repository.save_config() - logger.info('Config saved.') # create directories, keys, and root metadata file + _print_info('Initializing repository...') repository.initialize() - logger.info('Repository initialized.') + _print_info('Done.') def _cmd_keys(options: argparse.Namespace): @@ -230,15 +239,15 @@ def _cmd_keys(options: argparse.Namespace): key_name=options.new_key_name ) if options.create: - logger.info(f'Creating key pair for {options.new_key_name}...') + _print_info(f'Creating key pair for {options.new_key_name}...') repository.keys.create_key_pair( private_key_path=private_key_path, encrypted=options.encrypted ) - logger.info(f'Key pair created.') + _print_info(f'Key pair created.') replace = hasattr(options, 'old_key_name') add = hasattr(options, 'role_name') if replace: - logger.info( + _print_info( f'Replacing key {options.old_key_name} by {options.new_key_name}...' ) repository.replace_key( @@ -246,37 +255,33 @@ def _cmd_keys(options: argparse.Namespace): new_public_key_path=public_key_path, new_private_key_encrypted=options.encrypted, ) - logger.info('Key replaced.') elif add: - logger.info(f'Adding key {options.new_key_name}...') + _print_info(f'Adding key {options.new_key_name}...') repository.add_key( role_name=options.role_name, public_key_path=public_key_path, encrypted=options.encrypted, ) - logger.info('Key added.') if replace or add: - logger.info('Publishing changes...') + _print_info('Publishing changes...') repository.publish_changes(private_key_dirs=options.key_dirs) - logger.info('Changes published.') + _print_info('Done.') def _cmd_targets(options: argparse.Namespace): logger.debug(f'command targets: {vars(options)}') repository = _get_repo() if hasattr(options, 'app_version') and hasattr(options, 'bundle_dir'): - logger.info('Adding bundle...') + _print_info('Adding bundle...') repository.add_bundle( new_version=options.app_version, new_bundle_dir=options.bundle_dir ) - logger.info('Bundle added.') else: - logger.debug('Removing latest bundle...') + _print_info('Removing latest bundle...') repository.remove_latest_bundle() - logger.info('Latest bundle removed.') - logger.info('Publishing changes...') + _print_info('Publishing changes...') repository.publish_changes(private_key_dirs=options.key_dirs) - logger.info('Changes published.') + _print_info('Done.') def _cmd_sign(options: argparse.Namespace): @@ -288,15 +293,19 @@ def _cmd_sign(options: argparse.Namespace): if options.expiration_days.isnumeric(): days = int(options.expiration_days) # change expiration date in signed metadata + _print_info(f'Setting expiration date {days} days from now...') repository.refresh_expiration_date( role_name=options.role_name, days=days ) # also update version and expiration date for dependent roles, and sign # modified roles + _print_info('Publishing changes...') repository.publish_changes(private_key_dirs=options.key_dirs) else: # sign without changing the signed metadata (for threshold signing) + _print_info('Adding signature...') repository.threshold_sign( role_name=options.role_name, private_key_dirs=options.key_dirs, ) + _print_info('Done.') \ No newline at end of file diff --git a/src/notsotuf/utils/__init__.py b/src/notsotuf/utils/__init__.py index 339e5ea..a14e80c 100644 --- a/src/notsotuf/utils/__init__.py +++ b/src/notsotuf/utils/__init__.py @@ -1,9 +1,10 @@ import logging import pathlib import shutil +import sys from typing import List, Optional, Union -logger = logging.getLogger(__name__) +utils_logger = logging.getLogger(__name__) _INPUT_SEPARATOR = ' ' @@ -22,30 +23,58 @@ def remove_path(path: Union[pathlib.Path, str]) -> bool: try: if path.is_dir(): shutil.rmtree(path=path) - logger.debug(f'Removed directory {path}') + utils_logger.debug(f'Removed directory {path}') elif path.is_file(): path.unlink() - logger.debug(f'Removed file {path}') + utils_logger.debug(f'Removed file {path}') except Exception as e: - logger.error(f'Failed to remove {path}: {e}') + utils_logger.error(f'Failed to remove {path}: {e}') return False return True +def log_print(message: str, logger: logging.Logger, level: int = logging.INFO): + """ + Log message with specified level. + + Print message too, if logger is not enabled for specified level, + or if logger does not have a handler that streams to stdout. + """ + message_logged_to_stdout = False + current_logger = logger + while current_logger and not message_logged_to_stdout: + is_enabled = current_logger.isEnabledFor(level) + logs_to_stdout = any( + getattr(handler, 'stream', None) == sys.stdout + for handler in current_logger.handlers + ) + message_logged_to_stdout = is_enabled and logs_to_stdout + if not current_logger.propagate: + current_logger = None + else: + current_logger = current_logger.parent + if not message_logged_to_stdout: + print(message) + logger.log(level=level, msg=message) + + def input_bool(prompt: str, default: bool) -> bool: true_inputs = ['y'] default_str = ' (y/[n])' if default: default_str = ' ([y]/n)' true_inputs.append('') - return input(prompt + default_str + _INPUT_SEPARATOR) in true_inputs + prompt += default_str + _INPUT_SEPARATOR + answer = input(prompt) + utils_logger.debug(f'{prompt}: {answer}') + return answer in true_inputs def input_list( prompt: str, default: List[str], item_default: Optional[str] = None ) -> List[str]: new_list = [] - print(prompt) + log_print(message=prompt, level=logging.DEBUG, logger=utils_logger) # handle existing items for existing_item in default or []: if input_bool(f'{existing_item}\nKeep this item?', default=True): @@ -60,8 +89,10 @@ def input_list( def input_numeric(prompt: str, default: int) -> int: answer = 'not empty, not numeric' default_str = f' (default: {default})' + prompt += default_str + _INPUT_SEPARATOR while answer and not answer.isnumeric(): - answer = input(prompt + default_str + _INPUT_SEPARATOR) + answer = input(prompt) + utils_logger.debug(f'{prompt}: {answer}') if answer: return int(answer) else: @@ -77,6 +108,7 @@ def input_text( prompt += _INPUT_SEPARATOR while not answer: answer = input(prompt) or default + utils_logger.debug(f'{prompt}: {answer}') if optional: break return answer