Skip to content

Commit

Permalink
issue9:
Browse files Browse the repository at this point in the history
add log_print util and add basic cli feedback
  • Loading branch information
dennisvang committed Jun 30, 2022
1 parent d540445 commit 0092f8d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 26 deletions.
47 changes: 28 additions & 19 deletions src/notsotuf/repo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import packaging.version
from tuf.api.metadata import TOP_LEVEL_ROLE_NAMES

from notsotuf.utils import input_bool, input_numeric, input_text, input_list
from notsotuf.utils import (
log_print, input_bool, input_numeric, input_text, input_list
)
from notsotuf.repo import Repository

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -32,11 +34,17 @@
)


def _print_info(message: str):
return log_print(message=message, level=logging.INFO, logger=logger)


def _get_repo():
try:
return Repository.from_config()
except TypeError:
print('Failed to load config. Did you initialize the repository?')
_print_info(
'Failed to load config. Did you initialize the repository?'
)


def _add_key_dirs_argument(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -207,17 +215,18 @@ def _cmd_init(options: argparse.Namespace):
message = 'Modifying existing configuration.'
else:
message = 'Using existing configuration.'
logger.info(message)
_print_info(message)
if modify:
config_dict = _get_config_from_user(**config_dict)
# create repository instance
repository = Repository(**config_dict)
# save new or updated configuration
_print_info('Saving config...')
repository.save_config()
logger.info('Config saved.')
# create directories, keys, and root metadata file
_print_info('Initializing repository...')
repository.initialize()
logger.info('Repository initialized.')
_print_info('Done.')


def _cmd_keys(options: argparse.Namespace):
Expand All @@ -230,53 +239,49 @@ def _cmd_keys(options: argparse.Namespace):
key_name=options.new_key_name
)
if options.create:
logger.info(f'Creating key pair for {options.new_key_name}...')
_print_info(f'Creating key pair for {options.new_key_name}...')
repository.keys.create_key_pair(
private_key_path=private_key_path, encrypted=options.encrypted
)
logger.info(f'Key pair created.')
_print_info(f'Key pair created.')
replace = hasattr(options, 'old_key_name')
add = hasattr(options, 'role_name')
if replace:
logger.info(
_print_info(
f'Replacing key {options.old_key_name} by {options.new_key_name}...'
)
repository.replace_key(
old_key_name=options.old_key_name,
new_public_key_path=public_key_path,
new_private_key_encrypted=options.encrypted,
)
logger.info('Key replaced.')
elif add:
logger.info(f'Adding key {options.new_key_name}...')
_print_info(f'Adding key {options.new_key_name}...')
repository.add_key(
role_name=options.role_name,
public_key_path=public_key_path,
encrypted=options.encrypted,
)
logger.info('Key added.')
if replace or add:
logger.info('Publishing changes...')
_print_info('Publishing changes...')
repository.publish_changes(private_key_dirs=options.key_dirs)
logger.info('Changes published.')
_print_info('Done.')


def _cmd_targets(options: argparse.Namespace):
logger.debug(f'command targets: {vars(options)}')
repository = _get_repo()
if hasattr(options, 'app_version') and hasattr(options, 'bundle_dir'):
logger.info('Adding bundle...')
_print_info('Adding bundle...')
repository.add_bundle(
new_version=options.app_version, new_bundle_dir=options.bundle_dir
)
logger.info('Bundle added.')
else:
logger.debug('Removing latest bundle...')
_print_info('Removing latest bundle...')
repository.remove_latest_bundle()
logger.info('Latest bundle removed.')
logger.info('Publishing changes...')
_print_info('Publishing changes...')
repository.publish_changes(private_key_dirs=options.key_dirs)
logger.info('Changes published.')
_print_info('Done.')


def _cmd_sign(options: argparse.Namespace):
Expand All @@ -288,15 +293,19 @@ def _cmd_sign(options: argparse.Namespace):
if options.expiration_days.isnumeric():
days = int(options.expiration_days)
# change expiration date in signed metadata
_print_info(f'Setting expiration date {days} days from now...')
repository.refresh_expiration_date(
role_name=options.role_name, days=days
)
# also update version and expiration date for dependent roles, and sign
# modified roles
_print_info('Publishing changes...')
repository.publish_changes(private_key_dirs=options.key_dirs)
else:
# sign without changing the signed metadata (for threshold signing)
_print_info('Adding signature...')
repository.threshold_sign(
role_name=options.role_name,
private_key_dirs=options.key_dirs,
)
_print_info('Done.')
46 changes: 39 additions & 7 deletions src/notsotuf/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import pathlib
import shutil
import sys
from typing import List, Optional, Union

logger = logging.getLogger(__name__)
utils_logger = logging.getLogger(__name__)

_INPUT_SEPARATOR = ' '

Expand All @@ -22,30 +23,58 @@ def remove_path(path: Union[pathlib.Path, str]) -> bool:
try:
if path.is_dir():
shutil.rmtree(path=path)
logger.debug(f'Removed directory {path}')
utils_logger.debug(f'Removed directory {path}')
elif path.is_file():
path.unlink()
logger.debug(f'Removed file {path}')
utils_logger.debug(f'Removed file {path}')
except Exception as e:
logger.error(f'Failed to remove {path}: {e}')
utils_logger.error(f'Failed to remove {path}: {e}')
return False
return True


def log_print(message: str, logger: logging.Logger, level: int = logging.INFO):
"""
Log message with specified level.
Print message too, if logger is not enabled for specified level,
or if logger does not have a handler that streams to stdout.
"""
message_logged_to_stdout = False
current_logger = logger
while current_logger and not message_logged_to_stdout:
is_enabled = current_logger.isEnabledFor(level)
logs_to_stdout = any(
getattr(handler, 'stream', None) == sys.stdout
for handler in current_logger.handlers
)
message_logged_to_stdout = is_enabled and logs_to_stdout
if not current_logger.propagate:
current_logger = None
else:
current_logger = current_logger.parent
if not message_logged_to_stdout:
print(message)
logger.log(level=level, msg=message)


def input_bool(prompt: str, default: bool) -> bool:
true_inputs = ['y']
default_str = ' (y/[n])'
if default:
default_str = ' ([y]/n)'
true_inputs.append('')
return input(prompt + default_str + _INPUT_SEPARATOR) in true_inputs
prompt += default_str + _INPUT_SEPARATOR
answer = input(prompt)
utils_logger.debug(f'{prompt}: {answer}')
return answer in true_inputs


def input_list(
prompt: str, default: List[str], item_default: Optional[str] = None
) -> List[str]:
new_list = []
print(prompt)
log_print(message=prompt, level=logging.DEBUG, logger=utils_logger)
# handle existing items
for existing_item in default or []:
if input_bool(f'{existing_item}\nKeep this item?', default=True):
Expand All @@ -60,8 +89,10 @@ def input_list(
def input_numeric(prompt: str, default: int) -> int:
answer = 'not empty, not numeric'
default_str = f' (default: {default})'
prompt += default_str + _INPUT_SEPARATOR
while answer and not answer.isnumeric():
answer = input(prompt + default_str + _INPUT_SEPARATOR)
answer = input(prompt)
utils_logger.debug(f'{prompt}: {answer}')
if answer:
return int(answer)
else:
Expand All @@ -77,6 +108,7 @@ def input_text(
prompt += _INPUT_SEPARATOR
while not answer:
answer = input(prompt) or default
utils_logger.debug(f'{prompt}: {answer}')
if optional:
break
return answer

0 comments on commit 0092f8d

Please sign in to comment.