diff --git a/.github/workflows/test_code.yaml b/.github/workflows/test_code.yaml index a1fa737f67..f25d44b1f4 100644 --- a/.github/workflows/test_code.yaml +++ b/.github/workflows/test_code.yaml @@ -4,24 +4,24 @@ name: Test code # Run workflow on pushes to matching branches on: # yamllint disable-line rule:truthy push: - branches: [develop] + branches: [develop, python-migration] pull_request: - branches: [develop] + branches: [develop, python-migration] jobs: - test_powershell: + test_python: runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v3 - - name: Install requirements - shell: pwsh - run: | - Set-PSRepository PSGallery -InstallationPolicy Trusted - deployment/CheckRequirements.ps1 -InstallMissing -IncludeDev - - name: Test PowerShell - shell: pwsh - run: ./tests/Run_Pester_Tests.ps1 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.11 + - name: Install hatch + run: pip install hatch + - name: Test Python + run: hatch run test:test test_markdown_links: runs-on: ubuntu-latest diff --git a/data_safe_haven/README.md b/data_safe_haven/README.md index 83feed0443..03d55a48a8 100644 --- a/data_safe_haven/README.md +++ b/data_safe_haven/README.md @@ -9,16 +9,14 @@ Install the following requirements before starting - Run the following to initialise the deployment [approx 5 minutes]: -```bash -> dsh init +```console +> dsh context add ... +> dsh context create ``` -You will be prompted for various project settings. -If you prefer to enter these at the command line, run `dsh init -h` to see the necessary command line flags. - - Next deploy the Safe Haven Management (SHM) infrastructure [approx 30 minutes]: -```bash +```console > dsh deploy shm ``` @@ -29,13 +27,13 @@ Run `dsh deploy shm -h` to see the necessary command line flags and provide them Note that the phone number must be in full international format. Note that the country code is the two letter `ISO 3166-1 Alpha-2` code. 
-```bash +```console > dsh admin add-users ``` - Next deploy the infrastructure for one or more Secure Research Environments (SREs) [approx 30 minutes]: -```bash +```console > dsh deploy sre ``` @@ -44,7 +42,7 @@ Run `dsh deploy sre -h` to see the necessary command line flags and provide them - Next add one or more existing users to your SRE -```bash +```console > dsh admin register-users -s ``` @@ -54,7 +52,7 @@ where you must specify the usernames for each user you want to add to this SRE - Run the following to list the currently available users -```bash +```console > dsh admin list-users ``` @@ -62,20 +60,20 @@ where you must specify the usernames for each user you want to add to this SRE - Run the following if you want to teardown a deployed SRE: -```bash +```console > dsh teardown sre ``` - Run the following if you want to teardown the deployed SHM: -```bash +```console > dsh teardown shm ``` -- Run the following if you want to teardown the deployed Data Safe Haven backend: +- Run the following if you want to teardown the deployed Data Safe Haven context: -```bash -> dsh teardown backend +```console +> dsh context teardown ``` ## Code structure @@ -83,15 +81,16 @@ where you must specify the usernames for each user you want to add to this SRE - administration - this is where we keep utility commands for adminstrators of a deployed DSH - eg. "add a user"; "remove a user from an SRE" -- backend +- context - in order to use the Pulumi Azure backend we need a KeyVault, Identity and Storage Account - this code deploys those resources to bootstrap the rest of the Pulumi-based code + - the storage account is also used to store configuration, so that it can be shared by admins - commands - the main `dsh` command line entrypoint lives in `cli.py` - the subsidiary `typer` command line entrypoints (eg. 
`dsh deploy shm`) live here - config - serialises and deserialises a config file from Azure - - `backend_settings` manages basic settings related to the Azure backend: arguably this could/should live in `backend` + - `context_settings` manages basic settings related to the context: arguably this could/should live in `context` - exceptions - definitions of a Python exception hierarchy - external diff --git a/data_safe_haven/backend/__init__.py b/data_safe_haven/backend/__init__.py deleted file mode 100644 index 7401760657..0000000000 --- a/data_safe_haven/backend/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .backend import Backend - -__all__ = [ - "Backend", -] diff --git a/data_safe_haven/cli.py b/data_safe_haven/cli.py index ae2a5bdf20..a48a3e38af 100644 --- a/data_safe_haven/cli.py +++ b/data_safe_haven/cli.py @@ -7,8 +7,8 @@ from data_safe_haven import __version__ from data_safe_haven.commands import ( admin_command_group, + context_command_group, deploy_command_group, - initialise_command, teardown_command_group, ) from data_safe_haven.exceptions import DataSafeHavenError @@ -69,6 +69,9 @@ def main() -> None: name="admin", help="Perform administrative tasks for a Data Safe Haven deployment.", ) + application.add_typer( + context_command_group, name="context", help="Manage Data Safe Haven contexts." 
+ ) application.add_typer( deploy_command_group, name="deploy", @@ -80,11 +83,6 @@ def main() -> None: help="Tear down a Data Safe Haven component.", ) - # Register direct subcommands - application.command(name="init", help="Initialise a Data Safe Haven deployment.")( - initialise_command - ) - # Start the application try: application() diff --git a/data_safe_haven/commands/__init__.py b/data_safe_haven/commands/__init__.py index 5ba59d66de..299c982302 100644 --- a/data_safe_haven/commands/__init__.py +++ b/data_safe_haven/commands/__init__.py @@ -1,11 +1,11 @@ from .admin import admin_command_group +from .context import context_command_group from .deploy import deploy_command_group -from .init import initialise_command from .teardown import teardown_command_group __all__ = [ "admin_command_group", + "context_command_group", "deploy_command_group", - "initialise_command", "teardown_command_group", ] diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py new file mode 100644 index 0000000000..6a34a74562 --- /dev/null +++ b/data_safe_haven/commands/context.py @@ -0,0 +1,174 @@ +"""Command group and entrypoints for managing a DSH context""" +from typing import Annotated, Optional + +import typer +from rich import print + +from data_safe_haven.config import Config, ContextSettings +from data_safe_haven.config.context_settings import default_config_file_path +from data_safe_haven.context import Context +from data_safe_haven.functions import validate_aad_guid + +context_command_group = typer.Typer() + + +@context_command_group.command() +def show() -> None: + """Show information about the selected context.""" + settings = ContextSettings.from_file() + + current_context_key = settings.selected + current_context = settings.context + + print(f"Current context: [green]{current_context_key}") + print(f"\tName: {current_context.name}") + print(f"\tAdmin Group ID: {current_context.admin_group_id}") + print(f"\tSubscription name: 
{current_context.subscription_name}") + print(f"\tLocation: {current_context.location}") + + +@context_command_group.command() +def available() -> None: + """Show the available contexts.""" + settings = ContextSettings.from_file() + + current_context_key = settings.selected + available = settings.available + + available.remove(current_context_key) + available = [f"[green]{current_context_key}*[/]", *available] + + print("\n".join(available)) + + +@context_command_group.command() +def switch( + key: Annotated[str, typer.Argument(help="Key of the context to switch to.")] +) -> None: + """Switch the selected context.""" + settings = ContextSettings.from_file() + settings.selected = key + settings.write() + + +@context_command_group.command() +def add( + key: Annotated[str, typer.Argument(help="Key of the context to add.")], + admin_group: Annotated[ + str, + typer.Option( + help="The ID of an Azure group containing all administrators.", + callback=validate_aad_guid, + ), + ], + location: Annotated[ + str, + typer.Option( + help="The Azure location to deploy resources into.", + ), + ], + name: Annotated[ + str, + typer.Option( + help="The human friendly name to give this Data Safe Haven deployment.", + ), + ], + subscription: Annotated[ + str, + typer.Option( + help="The name of an Azure subscription to deploy resources into.", + ), + ], +) -> None: + """Add a new context to the context list.""" + if default_config_file_path().exists(): + settings = ContextSettings.from_file() + settings.add( + key=key, + admin_group_id=admin_group, + location=location, + name=name, + subscription_name=subscription, + ) + else: + # Bootstrap context settings file + settings = ContextSettings( + { + "selected": key, + "contexts": { + key: { + "admin_group_id": admin_group, + "location": location, + "name": name, + "subscription_name": subscription, + } + }, + } + ) + settings.write() + + +@context_command_group.command() +def update( + admin_group: Annotated[ + Optional[str], # noqa: 
UP007 + typer.Option( + help="The ID of an Azure group containing all administrators.", + callback=validate_aad_guid, + ), + ] = None, + location: Annotated[ + Optional[str], # noqa: UP007 + typer.Option( + help="The Azure location to deploy resources into.", + ), + ] = None, + name: Annotated[ + Optional[str], # noqa: UP007 + typer.Option( + help="The human friendly name to give this Data Safe Haven deployment.", + ), + ] = None, + subscription: Annotated[ + Optional[str], # noqa: UP007 + typer.Option( + help="The name of an Azure subscription to deploy resources into.", + ), + ] = None, +) -> None: + """Update the selected context settings.""" + settings = ContextSettings.from_file() + settings.update( + admin_group_id=admin_group, + location=location, + name=name, + subscription_name=subscription, + ) + settings.write() + + +@context_command_group.command() +def remove( + key: Annotated[str, typer.Argument(help="Name of the context to remove.")], +) -> None: + """Remove the selected context.""" + settings = ContextSettings.from_file() + settings.remove(key) + settings.write() + + +@context_command_group.command() +def create() -> None: + """Create Data Safe Haven context infrastructure.""" + config = Config() + context = Context(config) + context.create() + context.config.upload() + + +@context_command_group.command() +def teardown() -> None: + """Tear down Data Safe Haven context infrastructure.""" + config = Config() + context = Context(config) + context.teardown() diff --git a/data_safe_haven/commands/init.py b/data_safe_haven/commands/init.py deleted file mode 100644 index 916687975f..0000000000 --- a/data_safe_haven/commands/init.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Command-line application for initialising a Data Safe Haven deployment""" -from typing import Annotated, Optional - -import typer - -from data_safe_haven.backend import Backend -from data_safe_haven.config import BackendSettings -from data_safe_haven.exceptions import DataSafeHavenError -from 
data_safe_haven.functions import validate_aad_guid - - -def initialise_command( - admin_group: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--admin-group", - "-a", - help="The ID of an Azure group containing all administrators.", - callback=validate_aad_guid, - ), - ] = None, - location: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--location", - "-l", - help="The Azure location to deploy resources into.", - ), - ] = None, - name: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--name", - "-n", - help="The name to give this Data Safe Haven deployment.", - ), - ] = None, - subscription: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--subscription", - "-s", - help="The name of an Azure subscription to deploy resources into.", - ), - ] = None, -) -> None: - """Typer command line entrypoint""" - try: - # Load backend settings and update with command line arguments - settings = BackendSettings() - settings.update( - admin_group_id=admin_group, - location=location, - name=name, - subscription_name=subscription, - ) - - # Ensure that the Pulumi backend exists - backend = Backend() - backend.create() - - # Load the generated configuration file and upload it to blob storage - backend.config.upload() - - except DataSafeHavenError as exc: - msg = f"Could not initialise Data Safe Haven.\n{exc}" - raise DataSafeHavenError(msg) from exc diff --git a/data_safe_haven/commands/teardown.py b/data_safe_haven/commands/teardown.py index 2895d20ee7..5e6e7d08ec 100644 --- a/data_safe_haven/commands/teardown.py +++ b/data_safe_haven/commands/teardown.py @@ -3,18 +3,12 @@ import typer -from .teardown_backend import teardown_backend from .teardown_shm import teardown_shm from .teardown_sre import teardown_sre teardown_command_group = typer.Typer() -@teardown_command_group.command(help="Tear down a deployed Data Safe Haven backend.") -def backend() -> None: - teardown_backend() - - @teardown_command_group.command( help="Tear down a 
deployed a Safe Haven Management component." ) diff --git a/data_safe_haven/commands/teardown_backend.py b/data_safe_haven/commands/teardown_backend.py deleted file mode 100644 index f90e4a6c16..0000000000 --- a/data_safe_haven/commands/teardown_backend.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Tear down a deployed Data Safe Haven backend""" -from data_safe_haven.backend import Backend -from data_safe_haven.exceptions import ( - DataSafeHavenError, - DataSafeHavenInputError, -) - - -def teardown_backend() -> None: - """Tear down a deployed Data Safe Haven backend""" - try: - # Remove the Pulumi backend - try: - backend = Backend() - backend.teardown() - except Exception as exc: - msg = f"Unable to teardown Pulumi backend.\n{exc}" - raise DataSafeHavenInputError(msg) from exc - except DataSafeHavenError as exc: - msg = f"Could not teardown Data Safe Haven backend.\n{exc}" - raise DataSafeHavenError(msg) from exc diff --git a/data_safe_haven/config/__init__.py b/data_safe_haven/config/__init__.py index d7132460da..4723c0bf08 100644 --- a/data_safe_haven/config/__init__.py +++ b/data_safe_haven/config/__init__.py @@ -1,7 +1,7 @@ -from .backend_settings import BackendSettings from .config import Config +from .context_settings import ContextSettings __all__ = [ - "BackendSettings", "Config", + "ContextSettings", ] diff --git a/data_safe_haven/config/backend_settings.py b/data_safe_haven/config/backend_settings.py deleted file mode 100644 index 61816839cf..0000000000 --- a/data_safe_haven/config/backend_settings.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Load global and local settings from dotfiles""" -import pathlib - -import appdirs -import yaml -from yaml.parser import ParserError - -from data_safe_haven.exceptions import ( - DataSafeHavenConfigError, - DataSafeHavenParameterError, -) -from data_safe_haven.utility import LoggingSingleton - - -class BackendSettings: - """Load global and local settings from dotfiles with structure like the following - - azure: - admin_group_id: 
d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - location: uksouth - subscription_name: Data Safe Haven (Acme) - current: - name: Acme Deployment - """ - - def __init__( - self, - ) -> None: - # Define instance variables - self._admin_group_id: str | None = None - self._location: str | None = None - self._name: str | None = None - self._subscription_name: str | None = None - self.logger = LoggingSingleton() - - # Load previous backend settings (if any) - self.config_directory = pathlib.Path( - appdirs.user_config_dir(appname="data_safe_haven") - ).resolve() - self.config_file_path = self.config_directory / "config.yaml" - self.read() - - def update( - self, - *, - admin_group_id: str | None = None, - location: str | None = None, - name: str | None = None, - subscription_name: str | None = None, - ) -> None: - """Overwrite defaults with provided parameters""" - if admin_group_id: - self.logger.debug( - f"Updating '[green]{admin_group_id}[/]' to '{admin_group_id}'." - ) - self._admin_group_id = admin_group_id - if location: - self.logger.debug(f"Updating '[green]{location}[/]' to '{location}'.") - self._location = location - if name: - self.logger.debug(f"Updating '[green]{name}[/]' to '{name}'.") - self._name = name - if subscription_name: - self.logger.debug( - f"Updating '[green]{subscription_name}[/]' to '{subscription_name}'." - ) - self._subscription_name = subscription_name - - # Write backend settings to disk (this will trigger errors for uninitialised parameters) - self.write() - - @property - def admin_group_id(self) -> str: - if not self._admin_group_id: - msg = "Azure administrator group not provided: use '[bright_cyan]--admin-group[/]' / '[green]-a[/]' to do so." - raise DataSafeHavenParameterError(msg) - return self._admin_group_id - - @property - def location(self) -> str: - if not self._location: - msg = "Azure location not provided: use '[bright_cyan]--location[/]' / '[green]-l[/]' to do so." 
- raise DataSafeHavenParameterError(msg) - return self._location - - @property - def name(self) -> str: - if not self._name: - msg = ( - "Data Safe Haven deployment name not provided:" - " use '[bright_cyan]--name[/]' / '[green]-n[/]' to do so." - ) - raise DataSafeHavenParameterError(msg) - return self._name - - @property - def subscription_name(self) -> str: - if not self._subscription_name: - msg = "Azure subscription not provided: use '[bright_cyan]--subscription[/]' / '[green]-s[/]' to do so." - raise DataSafeHavenParameterError(msg) - return self._subscription_name - - def read(self) -> None: - """Read settings from YAML file""" - try: - if self.config_file_path.exists(): - with open(self.config_file_path, encoding="utf-8") as f_yaml: - settings = yaml.safe_load(f_yaml) - if isinstance(settings, dict): - self.logger.info( - f"Reading project settings from '[green]{self.config_file_path}[/]'." - ) - if admin_group_id := settings.get("azure", {}).get( - "admin_group_id", None - ): - self._admin_group_id = admin_group_id - if location := settings.get("azure", {}).get("location", None): - self._location = location - if name := settings.get("current", {}).get("name", None): - self._name = name - if subscription_name := settings.get("azure", {}).get( - "subscription_name", None - ): - self._subscription_name = subscription_name - except ParserError as exc: - msg = f"Could not load settings from {self.config_file_path}.\n{exc}" - raise DataSafeHavenConfigError(msg) from exc - - def write(self) -> None: - """Write settings to YAML file""" - settings = { - "azure": { - "admin_group_id": self.admin_group_id, - "location": self.location, - "subscription_name": self.subscription_name, - }, - "current": { - "name": self.name, - }, - } - # Create the parent directory if it does not exist then write YAML - self.config_file_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.config_file_path, "w", encoding="utf-8") as f_yaml: - yaml.dump(settings, f_yaml, 
indent=2) - self.logger.info( - f"Saved project settings to '[green]{self.config_file_path}[/]'." - ) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 1acc30ec90..c0f32247df 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -38,9 +38,10 @@ DatabaseSystem, LoggingSingleton, SoftwarePackageCategory, + config_dir, ) -from .backend_settings import BackendSettings +from .context_settings import ContextSettings class Validator: @@ -94,7 +95,7 @@ class ConfigSectionAzure(ConfigSection): @dataclass -class ConfigSectionBackend(ConfigSection): +class ConfigSectionContext(ConfigSection): key_vault_name: str = "" managed_identity_name: str = "" resource_group_name: str = "" @@ -349,31 +350,32 @@ class Config: def __init__(self) -> None: # Initialise config sections self.azure_: ConfigSectionAzure | None = None - self.backend_: ConfigSectionBackend | None = None + self.context_: ConfigSectionContext | None = None self.pulumi_: ConfigSectionPulumi | None = None self.shm_: ConfigSectionSHM | None = None self.tags_: ConfigSectionTags | None = None self.sres: dict[str, ConfigSectionSRE] = defaultdict(ConfigSectionSRE) - # Read backend settings - settings = BackendSettings() + # Read context settings + settings = ContextSettings.from_file() + context = settings.context # Check if backend exists and was loaded try: - self.name = settings.name + self.name = context.name except DataSafeHavenParameterError as exc: msg = "Data Safe Haven has not been initialised: run '[bright_cyan]dsh init[/]' before continuing." 
raise DataSafeHavenConfigError(msg) from exc - self.subscription_name = settings.subscription_name - self.azure.location = settings.location - self.azure.admin_group_id = settings.admin_group_id - self.backend_storage_container_name = "config" + self.subscription_name = context.subscription_name + self.azure.location = context.location + self.azure.admin_group_id = context.admin_group_id + self.context_storage_container_name = "config" # Set derived names self.shm_name_ = alphanumeric(self.name).lower() self.filename = f"config-{self.shm_name_}.yaml" - self.backend_resource_group_name = f"shm-{self.shm_name_}-rg-backend" - self.backend_storage_account_name = ( - f"shm{self.shm_name_[:14]}backend" # maximum of 24 characters allowed + self.context_resource_group_name = f"shm-{self.shm_name_}-rg-context" + self.context_storage_account_name = ( + f"shm{self.shm_name_[:14]}context" # maximum of 24 characters allowed ) - self.work_directory = settings.config_directory / self.shm_name_ + self.work_directory = config_dir() / self.shm_name_ self.azure_api = AzureApi(subscription_name=self.subscription_name) # Attempt to load YAML dictionary from blob storage yaml_input = {} @@ -381,18 +383,18 @@ def __init__(self) -> None: yaml_input = yaml.safe_load( self.azure_api.download_blob( self.filename, - self.backend_resource_group_name, - self.backend_storage_account_name, - self.backend_storage_container_name, + self.context_resource_group_name, + self.context_storage_account_name, + self.context_storage_container_name, ) ) # Attempt to decode each config section if yaml_input: if "azure" in yaml_input: self.azure_ = chili.decode(yaml_input["azure"], ConfigSectionAzure) - if "backend" in yaml_input: - self.backend_ = chili.decode( - yaml_input["backend"], ConfigSectionBackend + if "context" in yaml_input: + self.context_ = chili.decode( + yaml_input["context"], ConfigSectionContext ) if "pulumi" in yaml_input: self.pulumi_ = chili.decode(yaml_input["pulumi"], 
ConfigSectionPulumi) @@ -409,16 +411,16 @@ def azure(self) -> ConfigSectionAzure: return self.azure_ @property - def backend(self) -> ConfigSectionBackend: - if not self.backend_: - self.backend_ = ConfigSectionBackend( - key_vault_name=f"shm-{self.shm_name_[:9]}-kv-backend", - managed_identity_name=f"shm-{self.shm_name_}-identity-reader-backend", - resource_group_name=self.backend_resource_group_name, - storage_account_name=self.backend_storage_account_name, - storage_container_name=self.backend_storage_container_name, + def context(self) -> ConfigSectionContext: + if not self.context_: + self.context_ = ConfigSectionContext( + key_vault_name=f"shm-{self.shm_name_[:9]}-kv-context", + managed_identity_name=f"shm-{self.shm_name_}-identity-reader-context", + resource_group_name=self.context_resource_group_name, + storage_account_name=self.context_storage_account_name, + storage_container_name=self.context_storage_container_name, ) - return self.backend_ + return self.context_ @property def pulumi(self) -> ConfigSectionPulumi: @@ -443,8 +445,8 @@ def __str__(self) -> str: contents: dict[str, Any] = {} if self.azure_: contents["azure"] = self.azure.to_dict() - if self.backend_: - contents["backend"] = self.backend.to_dict() + if self.context_: + contents["context"] = self.context.to_dict() if self.pulumi_: contents["pulumi"] = self.pulumi.to_dict() if self.shm_: @@ -483,9 +485,9 @@ def upload(self) -> None: self.azure_api.upload_blob( str(self), self.filename, - self.backend_resource_group_name, - self.backend_storage_account_name, - self.backend_storage_container_name, + self.context_resource_group_name, + self.context_storage_account_name, + self.context_storage_container_name, ) def write_stack(self, name: str, path: pathlib.Path) -> None: diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py new file mode 100644 index 0000000000..32c0cb0b75 --- /dev/null +++ b/data_safe_haven/config/context_settings.py @@ -0,0 +1,187 
@@ +"""Load global and local settings from dotfiles""" +# For postponed evaluation of annotations https://peps.python.org/pep-0563 +from __future__ import ( + annotations, +) + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import yaml +from schema import Schema, SchemaError +from yaml.parser import ParserError + +from data_safe_haven.exceptions import ( + DataSafeHavenConfigError, + DataSafeHavenParameterError, +) +from data_safe_haven.utility import LoggingSingleton, config_dir + + +def default_config_file_path() -> Path: + return config_dir() / "contexts.yaml" + + +@dataclass +class Context: + admin_group_id: str + location: str + name: str + subscription_name: str + + +class ContextSettings: + """Load global and local settings from dotfiles with structure like the following + + selected: acme_deployment + contexts: + acme_deployment: + name: Acme Deployment + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_name: Data Safe Haven (Acme) + ... 
+ """ + + def __init__(self, settings_dict: dict[Any, Any]) -> None: + self.logger = LoggingSingleton() + + context_schema = Schema( + { + "name": str, + "admin_group_id": str, + "location": str, + "subscription_name": str, + } + ) + + schema = Schema( + { + "selected": str, + "contexts": Schema( + { + str: context_schema, + } + ), + } + ) + + try: + self._settings: dict[Any, Any] = schema.validate(settings_dict) + except SchemaError as exc: + msg = f"Invalid context configuration file.\n{exc}" + raise DataSafeHavenParameterError(msg) from exc + + @property + def settings(self) -> dict[Any, Any]: + return self._settings + + @property + def selected(self) -> str: + return str(self.settings["selected"]) + + @selected.setter + def selected(self, context_name: str) -> None: + if context_name in self.available: + self.settings["selected"] = context_name + self.logger.info(f"Switched context to '{context_name}'.") + else: + msg = f"Context '{context_name}' is not defined." + raise DataSafeHavenParameterError(msg) + + @property + def context(self) -> Context: + return Context(**self.settings["contexts"][self.selected]) + + @property + def available(self) -> list[str]: + return list(self.settings["contexts"].keys()) + + def update( + self, + *, + admin_group_id: str | None = None, + location: str | None = None, + name: str | None = None, + subscription_name: str | None = None, + ) -> None: + context_dict = self.settings["contexts"][self.selected] + + if admin_group_id: + self.logger.debug( + f"Updating '[green]{admin_group_id}[/]' to '{admin_group_id}'." + ) + context_dict["admin_group_id"] = admin_group_id + if location: + self.logger.debug(f"Updating '[green]{location}[/]' to '{location}'.") + context_dict["location"] = location + if name: + self.logger.debug(f"Updating '[green]{name}[/]' to '{name}'.") + context_dict["name"] = name + if subscription_name: + self.logger.debug( + f"Updating '[green]{subscription_name}[/]' to '{subscription_name}'." 
+ ) + context_dict["subscription_name"] = subscription_name + + def add( + self, + *, + key: str, + name: str, + admin_group_id: str, + location: str, + subscription_name: str, + ) -> None: + # Ensure context is not already present + if key in self.available: + msg = f"A context with key '{key}' is already defined." + raise DataSafeHavenParameterError(msg) + + self.settings["contexts"][key] = { + "name": name, + "admin_group_id": admin_group_id, + "location": location, + "subscription_name": subscription_name, + } + + def remove(self, key: str) -> None: + if key not in self.available: + msg = f"No context with key '{key}'." + raise DataSafeHavenParameterError(msg) + del self.settings["contexts"][key] + + @classmethod + def from_file(cls, config_file_path: Path | None = None) -> ContextSettings: + if config_file_path is None: + config_file_path = default_config_file_path() + logger = LoggingSingleton() + try: + with open(config_file_path, encoding="utf-8") as f_yaml: + settings = yaml.safe_load(f_yaml) + if isinstance(settings, dict): + logger.info( + f"Reading project settings from '[green]{config_file_path}[/]'." + ) + return cls(settings) + else: + msg = f"Unable to parse {config_file_path} as a dict." 
+ raise DataSafeHavenConfigError(msg) + except FileNotFoundError as exc: + msg = f"Could not find file {config_file_path}.\n{exc}" + raise DataSafeHavenConfigError(msg) from exc + except ParserError as exc: + msg = f"Could not load settings from {config_file_path}.\n{exc}" + raise DataSafeHavenConfigError(msg) from exc + + def write(self, config_file_path: Path | None = None) -> None: + """Write settings to YAML file""" + if config_file_path is None: + config_file_path = default_config_file_path() + # Create the parent directory if it does not exist then write YAML + config_file_path.parent.mkdir(parents=True, exist_ok=True) + + with open(config_file_path, "w", encoding="utf-8") as f_yaml: + yaml.dump(self.settings, f_yaml, indent=2) + self.logger.info(f"Saved context settings to '[green]{config_file_path}[/]'.") diff --git a/data_safe_haven/context/__init__.py b/data_safe_haven/context/__init__.py new file mode 100644 index 0000000000..94370d9ba1 --- /dev/null +++ b/data_safe_haven/context/__init__.py @@ -0,0 +1,5 @@ +from .context import Context + +__all__ = [ + "Context", +] diff --git a/data_safe_haven/backend/backend.py b/data_safe_haven/context/context.py similarity index 79% rename from data_safe_haven/backend/backend.py rename to data_safe_haven/context/context.py index a303955af1..234dc6c4c0 100644 --- a/data_safe_haven/backend/backend.py +++ b/data_safe_haven/context/context.py @@ -1,17 +1,15 @@ -"""Azure backend for a Data Safe Haven deployment""" - from data_safe_haven.config import Config from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external import AzureApi -class Backend: - """Azure backend for a Data Safe Haven deployment""" +class Context: + """Azure resources to support Data Safe Haven context""" - def __init__(self) -> None: + def __init__(self, config: Config) -> None: self.azure_api_: AzureApi | None = None - self.config = Config() - self.tags = {"component": "backend"} | self.config.tags.to_dict() + 
self.config = config + self.tags = {"component": "context"} | self.config.tags.to_dict() @property def azure_api(self) -> AzureApi: @@ -37,28 +35,28 @@ def create(self) -> None: self.config.azure.tenant_id = self.azure_api.tenant_id resource_group = self.azure_api.ensure_resource_group( location=self.config.azure.location, - resource_group_name=self.config.backend.resource_group_name, + resource_group_name=self.config.context.resource_group_name, tags=self.tags, ) if not resource_group.name: - msg = f"Resource group '{self.config.backend.resource_group_name}' was not created." + msg = f"Resource group '{self.config.context.resource_group_name}' was not created." raise DataSafeHavenAzureError(msg) identity = self.azure_api.ensure_managed_identity( - identity_name=self.config.backend.managed_identity_name, + identity_name=self.config.context.managed_identity_name, location=resource_group.location, resource_group_name=resource_group.name, ) storage_account = self.azure_api.ensure_storage_account( location=resource_group.location, resource_group_name=resource_group.name, - storage_account_name=self.config.backend.storage_account_name, + storage_account_name=self.config.context.storage_account_name, tags=self.tags, ) if not storage_account.name: - msg = f"Storage account '{self.config.backend.storage_account_name}' was not created." + msg = f"Storage account '{self.config.context.storage_account_name}' was not created." 
raise DataSafeHavenAzureError(msg) _ = self.azure_api.ensure_storage_blob_container( - container_name=self.config.backend.storage_container_name, + container_name=self.config.context.storage_container_name, resource_group_name=resource_group.name, storage_account_name=storage_account.name, ) @@ -69,7 +67,7 @@ def create(self) -> None: ) keyvault = self.azure_api.ensure_keyvault( admin_group_id=self.config.azure.admin_group_id, - key_vault_name=self.config.backend.key_vault_name, + key_vault_name=self.config.context.key_vault_name, location=resource_group.location, managed_identity=identity, resource_group_name=resource_group.name, @@ -77,7 +75,7 @@ def create(self) -> None: ) if not keyvault.name: msg = ( - f"Keyvault '{self.config.backend.key_vault_name}' was not created." + f"Keyvault '{self.config.context.key_vault_name}' was not created." ) raise DataSafeHavenAzureError(msg) pulumi_encryption_key = self.azure_api.ensure_keyvault_key( @@ -87,7 +85,7 @@ def create(self) -> None: key_version = pulumi_encryption_key.id.split("/")[-1] self.config.pulumi.encryption_key_version = key_version except Exception as exc: - msg = f"Failed to create backend resources.\n{exc}" + msg = f"Failed to create context resources.\n{exc}" raise DataSafeHavenAzureError(msg) from exc def teardown(self) -> None: @@ -98,8 +96,8 @@ def teardown(self) -> None: """ try: self.azure_api.remove_resource_group( - self.config.backend.resource_group_name + self.config.context.resource_group_name ) except Exception as exc: - msg = f"Failed to destroy backend resources.\n{exc}" + msg = f"Failed to destroy context resources.\n{exc}" raise DataSafeHavenAzureError(msg) from exc diff --git a/data_safe_haven/external/interface/azure_container_instance.py b/data_safe_haven/external/interface/azure_container_instance.py index 623c837555..1f696f4b04 100644 --- a/data_safe_haven/external/interface/azure_container_instance.py +++ b/data_safe_haven/external/interface/azure_container_instance.py @@ -1,4 +1,3 @@ 
-"""Backend for a Data Safe Haven environment""" import contextlib import time diff --git a/data_safe_haven/external/interface/azure_postgresql_database.py b/data_safe_haven/external/interface/azure_postgresql_database.py index 395791105a..ad0edc30da 100644 --- a/data_safe_haven/external/interface/azure_postgresql_database.py +++ b/data_safe_haven/external/interface/azure_postgresql_database.py @@ -1,4 +1,3 @@ -"""Backend for a Data Safe Haven environment""" import datetime import pathlib import time diff --git a/data_safe_haven/infrastructure/stack_manager.py b/data_safe_haven/infrastructure/stack_manager.py index 48d6c02d1b..f63026ff1b 100644 --- a/data_safe_haven/infrastructure/stack_manager.py +++ b/data_safe_haven/infrastructure/stack_manager.py @@ -39,11 +39,11 @@ def env(self) -> dict[str, Any]: if not self.env_: azure_api = AzureApi(self.cfg.subscription_name) backend_storage_account_keys = azure_api.get_storage_account_keys( - self.cfg.backend.resource_group_name, - self.cfg.backend.storage_account_name, + self.cfg.context.resource_group_name, + self.cfg.context.storage_account_name, ) self.env_ = { - "AZURE_STORAGE_ACCOUNT": self.cfg.backend.storage_account_name, + "AZURE_STORAGE_ACCOUNT": self.cfg.context.storage_account_name, "AZURE_STORAGE_KEY": str(backend_storage_account_keys[0].value), "AZURE_KEYVAULT_AUTH_VIA_CLI": "true", "PULUMI_BACKEND_URL": f"azblob://{self.cfg.pulumi.storage_container_name}", @@ -68,7 +68,9 @@ def __init__( self.program = program self.project_name = replace_separators(self.cfg.tags.project.lower(), "-") self.stack_name = self.program.stack_name - self.work_dir = config.work_directory / "pulumi" / self.program.short_name + self.work_dir: pathlib.Path = ( + config.work_directory / "pulumi" / self.program.short_name + ) self.work_dir.mkdir(parents=True, exist_ok=True) self.initialise_workdir() self.install_plugins() @@ -98,7 +100,7 @@ def stack(self) -> automation.Stack: stack_name=self.stack_name, program=self.program.run, 
opts=automation.LocalWorkspaceOptions( - secrets_provider=f"azurekeyvault://{self.cfg.backend.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.pulumi.encryption_key_version}", + secrets_provider=f"azurekeyvault://{self.cfg.context.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.pulumi.encryption_key_version}", work_dir=str(self.work_dir), env_vars=self.account.env, ), @@ -211,8 +213,8 @@ def destroy(self) -> None: azure_api = AzureApi(self.cfg.subscription_name) azure_api.remove_blob( blob_name=f".pulumi/stacks/{self.project_name}/{stack_backup_name}", - resource_group_name=self.cfg.backend.resource_group_name, - storage_account_name=self.cfg.backend.storage_account_name, + resource_group_name=self.cfg.context.resource_group_name, + storage_account_name=self.cfg.context.storage_account_name, storage_container_name=self.cfg.pulumi.storage_container_name, ) except DataSafeHavenAzureError as exc: diff --git a/data_safe_haven/utility/__init__.py b/data_safe_haven/utility/__init__.py index 73ab4620ec..923e6f31f5 100644 --- a/data_safe_haven/utility/__init__.py +++ b/data_safe_haven/utility/__init__.py @@ -1,3 +1,4 @@ +from .directories import config_dir from .enums import DatabaseSystem, SoftwarePackageCategory from .file_reader import FileReader from .logger import LoggingSingleton, NonLoggingSingleton @@ -5,6 +6,7 @@ from .types import PathType __all__ = [ + "config_dir", "DatabaseSystem", "FileReader", "LoggingSingleton", diff --git a/data_safe_haven/utility/directories.py b/data_safe_haven/utility/directories.py new file mode 100644 index 0000000000..593f64bb50 --- /dev/null +++ b/data_safe_haven/utility/directories.py @@ -0,0 +1,15 @@ +from os import getenv +from pathlib import Path + +import appdirs + + +def config_dir() -> Path: + if config_directory_env := getenv("DSH_CONFIG_DIRECTORY"): + config_directory = Path(config_directory_env).resolve() + else: + config_directory = Path( + 
appdirs.user_config_dir(appname="data_safe_haven") + ).resolve() + + return config_directory diff --git a/pyproject.toml b/pyproject.toml index 95afe80918..7bd45ee03e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "pytz~=2022.7.0", "PyYAML~=6.0", "rich~=13.4.2", + "schema~=0.7.0", "simple-acme-dns~=1.2.0", "typer~=0.9.0", "websocket-client~=1.5.0", @@ -85,6 +86,14 @@ all = [ "typing", ] +[tool.hatch.envs.test] +dependencies = [ + "pytest~=7.4.3" +] + +[tool.hatch.envs.test.scripts] +test = "pytest {args:-vvv tests_}" + [tool.black] target-version = ["py310", "py311"] @@ -166,3 +175,9 @@ module = [ "websocket.*", ] ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib", + "--disable-warnings", +] diff --git a/tests_/commands/test_context.py b/tests_/commands/test_context.py new file mode 100644 index 0000000000..0cdf878c44 --- /dev/null +++ b/tests_/commands/test_context.py @@ -0,0 +1,228 @@ +from data_safe_haven.commands.context import context_command_group +from data_safe_haven.config import Config +from data_safe_haven.context import Context + +from pytest import fixture +from typer.testing import CliRunner + +context_settings = """\ + selected: acme_deployment + contexts: + acme_deployment: + name: Acme Deployment + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_name: Data Safe Haven (Acme) + gems: + name: Gems + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_name: Data Safe Haven (Gems)""" + + +@fixture +def tmp_contexts(tmp_path): + config_file_path = tmp_path / "contexts.yaml" + with open(config_file_path, "w") as f: + f.write(context_settings) + return tmp_path + + +@fixture +def runner(tmp_contexts): + runner = CliRunner( + env={ + "DSH_CONFIG_DIRECTORY": str(tmp_contexts), + "COLUMNS": "500", # Set large number of columns to avoid rich wrapping text + "TERM": "dumb", # Disable colours, style and 
interactive rich features + }, + mix_stderr=False, + ) + return runner + + +class TestShow: + def test_show(self, runner): + result = runner.invoke(context_command_group, ["show"]) + assert result.exit_code == 0 + assert "Current context: acme_deployment" in result.stdout + assert "Name: Acme Deployment" in result.stdout + + +class TestAvailable: + def test_available(self, runner): + result = runner.invoke(context_command_group, ["available"]) + assert result.exit_code == 0 + assert "acme_deployment*" in result.stdout + assert "gems" in result.stdout + + +class TestSwitch: + def test_switch(self, runner): + result = runner.invoke(context_command_group, ["switch", "gems"]) + assert result.exit_code == 0 + assert "Switched context to 'gems'." in result.stdout + result = runner.invoke(context_command_group, ["available"]) + assert result.exit_code == 0 + assert "gems*" in result.stdout + + def test_invalid_switch(self, runner): + result = runner.invoke(context_command_group, ["switch", "invalid"]) + assert result.exit_code == 1 + # Unable to check error as this is written outside of any Typer + # assert "Context 'invalid' is not defined " in result.stdout + + +class TestAdd: + def test_add(self, runner): + result = runner.invoke( + context_command_group, + [ + "add", + "example", + "--name", + "Example", + "--admin-group", + "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "--location", + "uksouth", + "--subscription", + "Data Safe Haven (Example)", + ] + ) + assert result.exit_code == 0 + result = runner.invoke(context_command_group, ["switch", "example"]) + assert result.exit_code == 0 + + def test_add_duplicate(self, runner): + result = runner.invoke( + context_command_group, + [ + "add", + "acme_deployment", + "--name", + "Acme Deployment", + "--admin-group", + "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "--location", + "uksouth", + "--subscription", + "Data Safe Haven (Acme)", + ] + ) + assert result.exit_code == 1 + # Unable to check error as this is written outside of 
any Typer + # assert "A context with key 'acme_deployment' is already defined." in result.stdout + + def test_add_invalid_uuid(self, runner): + result = runner.invoke( + context_command_group, + [ + "add", + "example", + "--name", + "Example", + "--admin-group", + "not a uuid", + "--location", + "uksouth", + "--subscription", + "Data Safe Haven (Example)", + ] + ) + assert result.exit_code == 2 + # This works because the context_command_group Typer writes this error + assert "Invalid value for '--admin-group': Expected GUID" in result.stderr + + def test_add_missing_ags(self, runner): + result = runner.invoke( + context_command_group, + [ + "add", + "example", + "--name", + "Example", + ] + ) + assert result.exit_code == 2 + assert "Missing option" in result.stderr + + def test_add_bootstrap(self, tmp_contexts, runner): + (tmp_contexts / "contexts.yaml").unlink() + result = runner.invoke( + context_command_group, + [ + "add", + "acme_deployment", + "--name", + "Acme Deployment", + "--admin-group", + "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "--location", + "uksouth", + "--subscription", + "Data Safe Haven (Acme)", + ] + ) + assert result.exit_code == 0 + assert (tmp_contexts / "contexts.yaml").exists() + result = runner.invoke(context_command_group, ["show"]) + assert result.exit_code == 0 + assert "Name: Acme Deployment" in result.stdout + result = runner.invoke(context_command_group, ["available"]) + assert result.exit_code == 0 + assert "acme_deployment*" in result.stdout + assert "gems" not in result.stdout + + +class TestUpdate: + def test_update(self, runner): + result = runner.invoke(context_command_group, ["update", "--name", "New Name"]) + assert result.exit_code == 0 + result = runner.invoke(context_command_group, ["show"]) + assert result.exit_code == 0 + assert "Name: New Name" in result.stdout + + +class TestRemove: + def test_remove(self, runner): + result = runner.invoke(context_command_group, ["remove", "gems"]) + assert result.exit_code == 0 + 
result = runner.invoke(context_command_group, ["available"]) + assert result.exit_code == 0 + assert "gems" not in result.stdout + + def test_remove_invalid(self, runner): + result = runner.invoke(context_command_group, ["remove", "invalid"]) + assert result.exit_code == 1 + # Unable to check error as this is written outside of any Typer + # assert "No context with key 'invalid'." in result.stdout + + +class TestCreate: + def test_create(self, runner, monkeypatch): + def mock_create(self): + print("mock create") + + def mock_upload(self): + print("mock upload") + + monkeypatch.setattr(Context, "create", mock_create) + monkeypatch.setattr(Config, "upload", mock_upload) + + result = runner.invoke(context_command_group, ["create"]) + assert "mock create" in result.stdout + assert "mock upload" in result.stdout + assert result.exit_code == 0 + + +class TestTeardown: + def test_teardown(self, runner, monkeypatch): + def mock_teardown(self): + print("mock teardown") + + monkeypatch.setattr(Context, "teardown", mock_teardown) + + result = runner.invoke(context_command_group, ["teardown"]) + assert "mock teardown" in result.stdout + assert result.exit_code == 0 diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py new file mode 100644 index 0000000000..66783d788e --- /dev/null +++ b/tests_/config/test_context_settings.py @@ -0,0 +1,160 @@ +from data_safe_haven.config.context_settings import Context, ContextSettings +from data_safe_haven.exceptions import DataSafeHavenParameterError + +import pytest +import yaml +from pytest import fixture + + +class TestContext: + def test_constructor(self): + context_dict = { + "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "location": "uksouth", + "name": "Acme Deployment", + "subscription_name": "Data Safe Haven (Acme)" + } + context = Context(**context_dict) + assert isinstance(context, Context) + assert all([ + getattr(context, item) == context_dict[item] for item in 
context_dict.keys() + ]) + + +@fixture +def context_yaml(): + context_yaml = """\ + selected: acme_deployment + contexts: + acme_deployment: + name: Acme Deployment + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_name: Data Safe Haven (Acme) + gems: + name: Gems + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_name: Data Safe Haven (Gems)""" + return context_yaml + + +@fixture +def context_settings(context_yaml): + return ContextSettings(yaml.safe_load(context_yaml)) + + +class TestContextSettings: + def test_constructor(self, context_yaml): + settings = ContextSettings(yaml.safe_load(context_yaml)) + assert isinstance(settings, ContextSettings) + + def test_missing_selected(self, context_yaml): + context_yaml = "\n".join(context_yaml.splitlines()[1:]) + + with pytest.raises(DataSafeHavenParameterError) as exc: + ContextSettings(yaml.safe_load(context_yaml)) + assert "Missing Key: 'selected'" in exc + + def test_settings(self, context_settings): + assert isinstance(context_settings.settings, dict) + + def test_selected(self, context_settings): + assert context_settings.selected == "acme_deployment" + + def test_set_selected(self, context_settings): + assert context_settings.selected == "acme_deployment" + context_settings.selected = "gems" + assert context_settings.selected == "gems" + + def test_invalid_selected(self, context_settings): + with pytest.raises(DataSafeHavenParameterError) as exc: + context_settings.selected = "invalid" + assert "Context invalid is not defined." 
in exc + + def test_context(self, context_yaml, context_settings): + yaml_dict = yaml.safe_load(context_yaml) + assert isinstance(context_settings.context, Context) + assert all([ + getattr(context_settings.context, item) == yaml_dict["contexts"]["acme_deployment"][item] + for item in yaml_dict["contexts"]["acme_deployment"].keys() + ]) + + def test_set_context(self, context_yaml, context_settings): + yaml_dict = yaml.safe_load(context_yaml) + context_settings.selected = "gems" + assert isinstance(context_settings.context, Context) + assert all([ + getattr(context_settings.context, item) == yaml_dict["contexts"]["gems"][item] + for item in yaml_dict["contexts"]["gems"].keys() + ]) + + def test_available(self, context_settings): + available = context_settings.available + assert isinstance(available, list) + assert all([isinstance(item, str) for item in available]) + assert available == ["acme_deployment", "gems"] + + def test_update(self, context_settings): + assert context_settings.context.name == "Acme Deployment" + context_settings.update(name="replaced") + assert context_settings.context.name == "replaced" + + def test_set_update(self, context_settings): + context_settings.selected = "gems" + assert context_settings.context.name == "Gems" + context_settings.update(name="replaced") + assert context_settings.context.name == "replaced" + + def test_add(self, context_settings): + context_settings.add( + key="example", + name="Example", + subscription_name="Data Safe Haven (Example)", + admin_group_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + location="uksouth", + ) + context_settings.selected = "example" + assert context_settings.selected == "example" + assert context_settings.context.name == "Example" + assert context_settings.context.subscription_name == "Data Safe Haven (Example)" + + def test_invalid_add(self, context_settings): + with pytest.raises(DataSafeHavenParameterError) as exc: + context_settings.add( + key="acme_deployment", + name="Acme Deployment", + 
subscription_name="Data Safe Haven (Acme)",
+                admin_group_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd",
+                location="uksouth",
+            )
+        assert "A context with key 'acme_deployment' is already defined." in exc
+
+    def test_remove(self, context_settings):
+        context_settings.remove("acme_deployment")
+        assert "acme_deployment" not in context_settings.available
+
+    def test_invalid_remove(self, context_settings):
+        with pytest.raises(DataSafeHavenParameterError) as exc:
+            context_settings.remove("invalid")
+        assert "No context with key 'invalid'." in exc
+
+    def test_from_file(self, tmp_path, context_yaml):
+        config_file_path = tmp_path / "config.yaml"
+        with open(config_file_path, "w") as f:
+            f.write(context_yaml)
+        settings = ContextSettings.from_file(config_file_path=config_file_path)
+        assert settings.context.name == "Acme Deployment"
+
+    def test_write(self, tmp_path, context_yaml):
+        config_file_path = tmp_path / "config.yaml"
+        with open(config_file_path, "w") as f:
+            f.write(context_yaml)
+        settings = ContextSettings.from_file(config_file_path=config_file_path)
+        settings.selected = "gems"
+        settings.update(name="replaced")
+        settings.write(config_file_path)
+        with open(config_file_path, "r") as f:
+            context_dict = yaml.safe_load(f)
+        assert context_dict["selected"] == "gems"
+        assert context_dict["contexts"]["gems"]["name"] == "replaced"
diff --git a/typings/schema/__init__.pyi b/typings/schema/__init__.pyi
new file mode 100644
index 0000000000..2bd78ba644
--- /dev/null
+++ b/typings/schema/__init__.pyi
@@ -0,0 +1,10 @@
+from typing import Any
+
+
+class SchemaError(Exception):
+    ...
+
+
+class Schema:
+    def __init__(self, schema: dict[Any, Any]) -> None: ...
+    def validate(self, data: Any) -> Any: ...