From f3acf7248e9563afda2af4f51c1db6d4d2890dc8 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 7 Nov 2023 09:53:53 +0000 Subject: [PATCH 01/65] WIP: Use Pydantic for context settings --- data_safe_haven/config/context_settings.py | 118 +++++++++------------ pyproject.toml | 2 +- 2 files changed, 51 insertions(+), 69 deletions(-) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 32c0cb0b75..3c6023c140 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -4,12 +4,11 @@ annotations, ) -from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, ClassVar import yaml -from schema import Schema, SchemaError +from pydantic import BaseModel, Field, ValidationError from yaml.parser import ParserError from data_safe_haven.exceptions import ( @@ -23,15 +22,14 @@ def default_config_file_path() -> Path: return config_dir() / "contexts.yaml" -@dataclass -class Context: +class Context(BaseModel): admin_group_id: str location: str name: str subscription_name: str -class ContextSettings: +class ContextSettings(BaseModel): """Load global and local settings from dotfiles with structure like the following selected: acme_deployment @@ -44,47 +42,18 @@ class ContextSettings: ... 
""" - def __init__(self, settings_dict: dict[Any, Any]) -> None: - self.logger = LoggingSingleton() - - context_schema = Schema( - { - "name": str, - "admin_group_id": str, - "location": str, - "subscription_name": str, - } - ) - - schema = Schema( - { - "selected": str, - "contexts": Schema( - { - str: context_schema, - } - ), - } - ) - - try: - self._settings: dict[Any, Any] = schema.validate(settings_dict) - except SchemaError as exc: - msg = f"Invalid context configuration file.\n{exc}" - raise DataSafeHavenParameterError(msg) from exc - - @property - def settings(self) -> dict[Any, Any]: - return self._settings + selected_: str = Field(..., alias="selected") + contexts: dict[str, Context] + logger: ClassVar[LoggingSingleton] = LoggingSingleton() @property def selected(self) -> str: - return str(self.settings["selected"]) + return str(self.selected_) @selected.setter def selected(self, context_name: str) -> None: if context_name in self.available: - self.settings["selected"] = context_name + self.selected_ = context_name self.logger.info(f"Switched context to '{context_name}'.") else: msg = f"Context '{context_name}' is not defined." @@ -92,11 +61,11 @@ def selected(self, context_name: str) -> None: @property def context(self) -> Context: - return Context(**self.settings["contexts"][self.selected]) + return self.contexts[self.selected] @property def available(self) -> list[str]: - return list(self.settings["contexts"].keys()) + return list(self.contexts.keys()) def update( self, @@ -106,24 +75,24 @@ def update( name: str | None = None, subscription_name: str | None = None, ) -> None: - context_dict = self.settings["contexts"][self.selected] + context = self.contexts[self.selected] if admin_group_id: self.logger.debug( f"Updating '[green]{admin_group_id}[/]' to '{admin_group_id}'." 
) - context_dict["admin_group_id"] = admin_group_id + context.admin_group_id = admin_group_id if location: self.logger.debug(f"Updating '[green]{location}[/]' to '{location}'.") - context_dict["location"] = location + context.location = location if name: self.logger.debug(f"Updating '[green]{name}[/]' to '{name}'.") - context_dict["name"] = name + context.name = name if subscription_name: self.logger.debug( f"Updating '[green]{subscription_name}[/]' to '{subscription_name}'." ) - context_dict["subscription_name"] = subscription_name + context.subscription_name = subscription_name def add( self, @@ -139,41 +108,54 @@ def add( msg = f"A context with key '{key}' is already defined." raise DataSafeHavenParameterError(msg) - self.settings["contexts"][key] = { - "name": name, - "admin_group_id": admin_group_id, - "location": location, - "subscription_name": subscription_name, - } + self.contexts[key] = Context( + name=name, + admin_group_id=admin_group_id, + location=location, + subscription_name=subscription_name, + ) def remove(self, key: str) -> None: if key not in self.available: msg = f"No context with key '{key}'." raise DataSafeHavenParameterError(msg) - del self.settings["contexts"][key] + del self.contexts[key] + + @classmethod + def from_yaml(cls, settings_yaml: str) -> ContextSettings: + try: + settings_dict = yaml.safe_load(settings_yaml) + except ParserError as exc: + msg = f"Could not parse context settings as YAML.\n{exc}" + raise DataSafeHavenConfigError(msg) from exc + + if not isinstance(settings_dict, dict): + msg = "Unable to parse context settings as a dict." 
+ raise DataSafeHavenConfigError(msg) + + try: + return ContextSettings.model_validate(settings_dict) + except ValidationError as exc: + cls.logger.error(f"{exc.error_count()} errors found in context settings:") + for error in exc.errors(): + cls.logger.error(f"{error['msg']} at '{'->'.join(error['loc'])}'") + msg = f"Could not load context settings.\n{exc}" + raise DataSafeHavenConfigError(msg) from exc @classmethod def from_file(cls, config_file_path: Path | None = None) -> ContextSettings: if config_file_path is None: config_file_path = default_config_file_path() - logger = LoggingSingleton() + cls.logger.info( + f"Reading project settings from '[green]{config_file_path}[/]'." + ) try: with open(config_file_path, encoding="utf-8") as f_yaml: - settings = yaml.safe_load(f_yaml) - if isinstance(settings, dict): - logger.info( - f"Reading project settings from '[green]{config_file_path}[/]'." - ) - return cls(settings) - else: - msg = f"Unable to parse {config_file_path} as a dict." - raise DataSafeHavenConfigError(msg) + settings_yaml = f_yaml.read() + return cls.from_yaml(settings_yaml) except FileNotFoundError as exc: msg = f"Could not find file {config_file_path}.\n{exc}" raise DataSafeHavenConfigError(msg) from exc - except ParserError as exc: - msg = f"Could not load settings from {config_file_path}.\n{exc}" - raise DataSafeHavenConfigError(msg) from exc def write(self, config_file_path: Path | None = None) -> None: """Write settings to YAML file""" @@ -183,5 +165,5 @@ def write(self, config_file_path: Path | None = None) -> None: config_file_path.parent.mkdir(parents=True, exist_ok=True) with open(config_file_path, "w", encoding="utf-8") as f_yaml: - yaml.dump(self.settings, f_yaml, indent=2) + yaml.dump(self.model_dump(), f_yaml, indent=2) self.logger.info(f"Saved context settings to '[green]{config_file_path}[/]'.") diff --git a/pyproject.toml b/pyproject.toml index 2a009054b4..cabdd2c7fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,10 +41,10 @@ 
dependencies = [ "pulumi~=3.80", "pulumi-azure-native~=2.14", "pulumi-random~=4.14", + "pydantic~=2.4", "pytz~=2023.3", "PyYAML~=6.0", "rich~=13.4", - "schema~=0.7", "simple-acme-dns~=3.0", "typer~=0.9", "websocket-client~=1.5", From 12688ea160973f6343a2f811fef4ba271aeb20af Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 7 Nov 2023 13:43:16 +0000 Subject: [PATCH 02/65] WIP: Fix ContextSettings tests --- data_safe_haven/config/context_settings.py | 9 +++----- tests_/config/test_context_settings.py | 24 +++++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 3c6023c140..a7e17f3ef2 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -5,7 +5,7 @@ ) from pathlib import Path -from typing import Any, ClassVar +from typing import ClassVar import yaml from pydantic import BaseModel, Field, ValidationError @@ -136,11 +136,8 @@ def from_yaml(cls, settings_yaml: str) -> ContextSettings: try: return ContextSettings.model_validate(settings_dict) except ValidationError as exc: - cls.logger.error(f"{exc.error_count()} errors found in context settings:") - for error in exc.errors(): - cls.logger.error(f"{error['msg']} at '{'->'.join(error['loc'])}'") msg = f"Could not load context settings.\n{exc}" - raise DataSafeHavenConfigError(msg) from exc + raise DataSafeHavenParameterError(msg) from exc @classmethod def from_file(cls, config_file_path: Path | None = None) -> ContextSettings: @@ -165,5 +162,5 @@ def write(self, config_file_path: Path | None = None) -> None: config_file_path.parent.mkdir(parents=True, exist_ok=True) with open(config_file_path, "w", encoding="utf-8") as f_yaml: - yaml.dump(self.model_dump(), f_yaml, indent=2) + yaml.dump(self.model_dump(by_alias=True), f_yaml, indent=2) self.logger.info(f"Saved context settings to '[green]{config_file_path}[/]'.") diff --git 
a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 66783d788e..a1742af1c8 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -41,23 +41,33 @@ def context_yaml(): @fixture def context_settings(context_yaml): - return ContextSettings(yaml.safe_load(context_yaml)) + return ContextSettings.from_yaml(context_yaml) class TestContextSettings: def test_constructor(self, context_yaml): - settings = ContextSettings(yaml.safe_load(context_yaml)) + settings = ContextSettings( + selected="acme_deployment", + contexts={ + "acme_deployment": Context( + name="Acme Deployment", + admin_group_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + location="uksouth", + subscription_name="Data Safe Haven (Acme)", + ) + }, + ) assert isinstance(settings, ContextSettings) def test_missing_selected(self, context_yaml): context_yaml = "\n".join(context_yaml.splitlines()[1:]) with pytest.raises(DataSafeHavenParameterError) as exc: - ContextSettings(yaml.safe_load(context_yaml)) - assert "Missing Key: 'selected'" in exc - - def test_settings(self, context_settings): - assert isinstance(context_settings.settings, dict) + ContextSettings.from_yaml(context_yaml) + assert "Could not load context settings" in exc + assert "1 validation error for ContextSettings" in exc + assert "selected" in exc + assert "Field required" in exc def test_selected(self, context_settings): assert context_settings.selected == "acme_deployment" From 562a6b471abf755190ee786894d4499661929c47 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 7 Nov 2023 14:34:02 +0000 Subject: [PATCH 03/65] Fix context command tests --- data_safe_haven/commands/context.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index 6a34a74562..6a9489d8fd 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ 
-5,8 +5,9 @@ from rich import print from data_safe_haven.config import Config, ContextSettings +from data_safe_haven.config.context_settings import Context from data_safe_haven.config.context_settings import default_config_file_path -from data_safe_haven.context import Context +from data_safe_haven.context import Context as ContextInfra from data_safe_haven.functions import validate_aad_guid context_command_group = typer.Typer() @@ -93,17 +94,15 @@ def add( else: # Bootstrap context settings file settings = ContextSettings( - { - "selected": key, - "contexts": { - key: { - "admin_group_id": admin_group, - "location": location, - "name": name, - "subscription_name": subscription, - } - }, - } + selected=key, + contexts={ + key: Context( + admin_group_id=admin_group, + location=location, + name=name, + subscription_name=subscription, + ) + }, ) settings.write() @@ -161,7 +160,7 @@ def remove( def create() -> None: """Create Data Safe Haven context infrastructure.""" config = Config() - context = Context(config) + context = ContextInfra(config) context.create() context.config.upload() @@ -170,5 +169,5 @@ def create() -> None: def teardown() -> None: """Tear down Data Safe Haven context infrastructure.""" config = Config() - context = Context(config) + context = ContextInfra(config) context.teardown() From 026f14a842c07481ed51e619d7536ce0ccc06ec4 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 7 Nov 2023 14:42:01 +0000 Subject: [PATCH 04/65] Ensure selected context is defined --- data_safe_haven/config/context_settings.py | 8 +++++++- tests_/config/test_context_settings.py | 9 ++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index a7e17f3ef2..f180162682 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -8,7 +8,7 @@ from typing import ClassVar import yaml -from pydantic import BaseModel, Field, 
ValidationError +from pydantic import BaseModel, Field, ValidationError, model_validator from yaml.parser import ParserError from data_safe_haven.exceptions import ( @@ -46,6 +46,12 @@ class ContextSettings(BaseModel): contexts: dict[str, Context] logger: ClassVar[LoggingSingleton] = LoggingSingleton() + @model_validator(mode="after") + def ensure_selected_is_valid(self) -> ContextSettings: + if self.selected not in self.available: + raise ValueError(f"Selected context '{self.selected}' is not defined.") + return self + @property def selected(self) -> str: return str(self.selected_) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index a1742af1c8..8707b04036 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -45,7 +45,7 @@ def context_settings(context_yaml): class TestContextSettings: - def test_constructor(self, context_yaml): + def test_constructor(self): settings = ContextSettings( selected="acme_deployment", contexts={ @@ -69,6 +69,13 @@ def test_missing_selected(self, context_yaml): assert "selected" in exc assert "Field required" in exc + def test_invalid_selected_input(self, context_yaml): + context_yaml = context_yaml.replace("selected: acme_deployment", "selected: invalid") + + with pytest.raises(DataSafeHavenParameterError) as exc: + ContextSettings.from_yaml(context_yaml) + assert "Selected context 'invalid' is not defined." 
in exc + def test_selected(self, context_settings): assert context_settings.selected == "acme_deployment" From aed1a8644294e5dc797031e82ee0f27e1233cc5b Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 7 Nov 2023 15:07:09 +0000 Subject: [PATCH 05/65] Remove schema typings --- typings/schema/__init__.pyi | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 typings/schema/__init__.pyi diff --git a/typings/schema/__init__.pyi b/typings/schema/__init__.pyi deleted file mode 100644 index 2bd78ba644..0000000000 --- a/typings/schema/__init__.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any - - -class SchemaError(Exception): - ... - - -class Schema: - def __init__(self, schema: dict[Any, Any]) -> None: ... - def validate(self, data: Any) -> Any: ... From 1abe112a6aba57740c2ea8f82d8be2b63ec4f4f0 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 13 Nov 2023 15:40:29 +0000 Subject: [PATCH 06/65] Add test for invalid YAML --- data_safe_haven/config/context_settings.py | 4 ++-- tests_/config/test_context_settings.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index f180162682..545a1f72a7 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -9,7 +9,7 @@ import yaml from pydantic import BaseModel, Field, ValidationError, model_validator -from yaml.parser import ParserError +from yaml import YAMLError from data_safe_haven.exceptions import ( DataSafeHavenConfigError, @@ -131,7 +131,7 @@ def remove(self, key: str) -> None: def from_yaml(cls, settings_yaml: str) -> ContextSettings: try: settings_dict = yaml.safe_load(settings_yaml) - except ParserError as exc: + except YAMLError as exc: msg = f"Could not parse context settings as YAML.\n{exc}" raise DataSafeHavenConfigError(msg) from exc diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 
8707b04036..d98d6c2644 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -1,5 +1,5 @@ from data_safe_haven.config.context_settings import Context, ContextSettings -from data_safe_haven.exceptions import DataSafeHavenParameterError +from data_safe_haven.exceptions import DataSafeHavenConfigError, DataSafeHavenParameterError import pytest import yaml @@ -76,6 +76,12 @@ def test_invalid_selected_input(self, context_yaml): ContextSettings.from_yaml(context_yaml) assert "Selected context 'invalid' is not defined." in exc + def test_invalid_yaml(self): + invalid_yaml = "a: [1,2" + with pytest.raises(DataSafeHavenConfigError) as exc: + ContextSettings.from_yaml(invalid_yaml) + assert "Could not parse context settings as YAML." in exc + def test_selected(self, context_settings): assert context_settings.selected == "acme_deployment" From 7f314d53e17cd7981e88b2ffcdb2f4fd67e56ee1 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 13 Nov 2023 15:42:45 +0000 Subject: [PATCH 07/65] Add test for YAML not being a dict --- tests_/config/test_context_settings.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index d98d6c2644..6571ffb97f 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -82,6 +82,12 @@ def test_invalid_yaml(self): ContextSettings.from_yaml(invalid_yaml) assert "Could not parse context settings as YAML." in exc + def test_yaml_not_dict(self): + not_dict = "[1, 2, 3]" + with pytest.raises(DataSafeHavenConfigError) as exc: + ContextSettings.from_yaml(not_dict) + assert "Unable to parse context settings as a dict." 
in exc + def test_selected(self, context_settings): assert context_settings.selected == "acme_deployment" From d22bbd1c6fc137dfec847ad12e45c57c4463ed7b Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 13 Nov 2023 15:47:44 +0000 Subject: [PATCH 08/65] Add file not found test --- tests_/config/test_context_settings.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 6571ffb97f..e1cc3cfac5 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -175,6 +175,12 @@ def test_from_file(self, tmp_path, context_yaml): settings = ContextSettings.from_file(config_file_path=config_file_path) assert settings.context.name == "Acme Deployment" + def test_file_not_found(self, tmp_path): + config_file_path = tmp_path / "config.yaml" + with pytest.raises(DataSafeHavenConfigError) as exc: + ContextSettings.from_file(config_file_path=config_file_path) + assert "Could not find file" in exc + def test_write(self, tmp_path, context_yaml): config_file_path = tmp_path / "config.yaml" with open(config_file_path, "w") as f: From 502fe5b92ca3062e3cae270a997cd3410cebf041 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 13 Nov 2023 15:51:08 +0000 Subject: [PATCH 09/65] Fix linting errors --- data_safe_haven/commands/context.py | 3 +-- data_safe_haven/config/context_settings.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index 6a9489d8fd..1e17a0143d 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ -5,8 +5,7 @@ from rich import print from data_safe_haven.config import Config, ContextSettings -from data_safe_haven.config.context_settings import Context -from data_safe_haven.config.context_settings import default_config_file_path +from data_safe_haven.config.context_settings import Context, default_config_file_path 
from data_safe_haven.context import Context as ContextInfra from data_safe_haven.functions import validate_aad_guid diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 545a1f72a7..795bfa1d72 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -49,7 +49,8 @@ class ContextSettings(BaseModel): @model_validator(mode="after") def ensure_selected_is_valid(self) -> ContextSettings: if self.selected not in self.available: - raise ValueError(f"Selected context '{self.selected}' is not defined.") + msg = f"Selected context '{self.selected}' is not defined." + raise ValueError(msg) return self @property From 9093b9738fe35f90f295d2c8ea71bc8120971167 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 14 Nov 2023 10:20:21 +0000 Subject: [PATCH 10/65] Add Pydantic to linting environment --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index cabdd2c7fc..f901ea955e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ detached = true dependencies = [ "black>=23.1.0", "mypy>=1.0.0", + "pydantic>=2.4", "ruff>=0.0.243", "types-appdirs>=1.4.3.5", "types-chevron>=0.14.2.5", From d374b1c048a2dbd346a3ec269bbd365f079b7e2c Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 15 Nov 2023 10:58:02 +0000 Subject: [PATCH 11/65] WIP: Introduce annotated types for validation --- data_safe_haven/config/context_settings.py | 11 +- data_safe_haven/functions/__init__.py | 10 -- data_safe_haven/functions/validators.py | 119 +++++---------------- data_safe_haven/utility/annotated_types.py | 22 ++++ tests_/config/test_context_settings.py | 37 +++++-- 5 files changed, 86 insertions(+), 113 deletions(-) create mode 100644 data_safe_haven/utility/annotated_types.py diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 795bfa1d72..10dab336c0 100644 --- 
a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -16,6 +16,11 @@ DataSafeHavenParameterError, ) from data_safe_haven.utility import LoggingSingleton, config_dir +from data_safe_haven.utility.annotated_types import ( + AzureLocation, + AzureLongName, + Guid, +) def default_config_file_path() -> Path: @@ -23,10 +28,10 @@ def default_config_file_path() -> Path: class Context(BaseModel): - admin_group_id: str - location: str + admin_group_id: Guid + location: AzureLocation name: str - subscription_name: str + subscription_name: AzureLongName class ContextSettings(BaseModel): diff --git a/data_safe_haven/functions/__init__.py b/data_safe_haven/functions/__init__.py index fcdb02ce72..6b444728ff 100644 --- a/data_safe_haven/functions/__init__.py +++ b/data_safe_haven/functions/__init__.py @@ -24,12 +24,7 @@ validate_azure_vm_sku, validate_email_address, validate_ip_address, - validate_list, - validate_non_empty_list, - validate_non_empty_string, - validate_string_length, validate_timezone, - validate_type, ) __all__ = [ @@ -54,10 +49,5 @@ "validate_azure_vm_sku", "validate_email_address", "validate_ip_address", - "validate_list", - "validate_non_empty_list", - "validate_non_empty_string", - "validate_string_length", "validate_timezone", - "validate_type", ] diff --git a/data_safe_haven/functions/validators.py b/data_safe_haven/functions/validators.py index da41d3db24..a1526acd83 100644 --- a/data_safe_haven/functions/validators.py +++ b/data_safe_haven/functions/validators.py @@ -1,117 +1,50 @@ import ipaddress import re -from collections.abc import Callable -from typing import Any import pytz -import typer -def validate_aad_guid(aad_guid: str | None) -> str | None: - if aad_guid is not None: - if not re.match( - r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$", - aad_guid, - ): - msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'." 
- raise typer.BadParameter(msg) +def validate_aad_guid(aad_guid: str) -> str: + if not re.match( + r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$", + aad_guid, + ): + msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'." + raise ValueError(msg) return aad_guid -def validate_azure_location(azure_location: str | None) -> str | None: - if azure_location is not None: - if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location): - msg = "Expected valid Azure location, for example 'uksouth'." - raise typer.BadParameter(msg) +def validate_azure_location(azure_location: str) -> str: + if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location): + msg = "Expected valid Azure location, for example 'uksouth'." + raise ValueError(msg) return azure_location -def validate_azure_vm_sku(azure_vm_sku: str | None) -> str | None: - if azure_vm_sku is not None: - if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku): - msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'." - raise typer.BadParameter(msg) +def validate_azure_vm_sku(azure_vm_sku: str) -> str: + if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku): + msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'." + raise ValueError(msg) return azure_vm_sku -def validate_email_address(email_address: str | None) -> str | None: - if email_address is not None: - if not re.match(r"^\S+@\S+$", email_address): - msg = "Expected valid email address, for example 'sherlock@holmes.com'." - raise typer.BadParameter(msg) +def validate_email_address(email_address: str) -> str: + if not re.match(r"^\S+@\S+$", email_address): + msg = "Expected valid email address, for example 'sherlock@holmes.com'." 
+ raise ValueError(msg) return email_address -def validate_ip_address( - ip_address: str | None, -) -> str | None: +def validate_ip_address(ip_address: str) -> str: try: - if ip_address: - return str(ipaddress.ip_network(ip_address)) - return None + return str(ipaddress.ip_network(ip_address)) except Exception as exc: msg = "Expected valid IPv4 address, for example '1.1.1.1'." - raise typer.BadParameter(msg) from exc + raise ValueError(msg) from exc -def validate_list( - value: list[Any], - validator: Callable[[Any], Any] | None = None, - *, - allow_empty: bool = False, -) -> list[Any]: - try: - if not allow_empty: - validate_non_empty_list(value) - if validator: - for element in value: - validator(element) - return value - except Exception as exc: - msg = f"Expected valid list.\n{exc}" - raise typer.BadParameter(msg) from exc - - -def validate_non_empty_list(value: list[Any]) -> list[Any]: - if len(value) == 0: - msg = "Expected non-empty list." - raise typer.BadParameter(msg) - return value - - -def validate_non_empty_string(value: Any) -> str: - try: - return validate_string_length(value, min_length=1) - except Exception as exc: - msg = "Expected non-empty string." - raise typer.BadParameter(msg) from exc - - -def validate_string_length( - value: Any, min_length: int | None = None, max_length: int | None = None -) -> str: - if isinstance(value, str): - if min_length and len(value) < min_length: - msg = f"Expected string with minimum length {min_length}." - raise typer.BadParameter(msg) - if max_length and len(value) > max_length: - msg = f"Expected string with maximum length {max_length}." - raise typer.BadParameter(msg) - return str(value) - msg = "Expected string." - raise typer.BadParameter(msg) - - -def validate_timezone(timezone: str | None) -> str | None: - if timezone is not None: - if timezone not in pytz.all_timezones: - msg = "Expected valid timezone, for example 'Europe/London'." 
- raise typer.BadParameter(msg) +def validate_timezone(timezone: str) -> str: + if timezone not in pytz.all_timezones: + msg = "Expected valid timezone, for example 'Europe/London'." + raise ValueError(msg) return timezone - - -def validate_type(value: Any, type_: type) -> Any: - if not isinstance(value, type_): - msg = f"Expected type '{type_.__name__}' but received '{type(value).__name__}'." - raise typer.BadParameter(msg) - return value diff --git a/data_safe_haven/utility/annotated_types.py b/data_safe_haven/utility/annotated_types.py new file mode 100644 index 0000000000..1db246dabf --- /dev/null +++ b/data_safe_haven/utility/annotated_types.py @@ -0,0 +1,22 @@ +from typing import Annotated + +from pydantic import Field +from pydantic.functional_validators import AfterValidator + +from data_safe_haven.functions import ( + validate_aad_guid, + validate_azure_location, + validate_azure_vm_sku, + validate_email_address, + validate_ip_address, + validate_timezone, +) + +AzureShortName = Annotated[str, Field(min_length=1, max_length=24)] +AzureLongName = Annotated[str, Field(min_length=1, max_length=64)] +AzureLocation = Annotated[str, AfterValidator(validate_azure_location)] +AzureVmSku = Annotated[str, AfterValidator(validate_azure_vm_sku)] +EmailAdress = Annotated[str, AfterValidator(validate_email_address)] +Guid = Annotated[str, AfterValidator(validate_aad_guid)] +IpAddress = Annotated[str, AfterValidator(validate_ip_address)] +TimeZone = Annotated[str, AfterValidator(validate_timezone)] diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index e1cc3cfac5..8f04ed3fc1 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -3,23 +3,46 @@ import pytest import yaml +from pydantic import ValidationError from pytest import fixture +@fixture +def context_dict(): + return { + "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "location": "uksouth", + "name": "Acme 
Deployment", + "subscription_name": "Data Safe Haven (Acme)" + } + + class TestContext: - def test_constructor(self): - context_dict = { - "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - "location": "uksouth", - "name": "Acme Deployment", - "subscription_name": "Data Safe Haven (Acme)" - } + def test_constructor(self, context_dict): context = Context(**context_dict) assert isinstance(context, Context) assert all([ getattr(context, item) == context_dict[item] for item in context_dict.keys() ]) + def test_invalid_guid(self, context_dict): + context_dict["admin_group_id"] = "not a guid" + with pytest.raises(ValidationError) as exc: + Context(**context_dict) + assert "Value error, Expected GUID, for example" in exc + + def test_invalid_location(self, context_dict): + context_dict["location"] = "not_a_location" + with pytest.raises(ValidationError) as exc: + Context(**context_dict) + assert "Value error, Expected valid Azure location" in exc + + def test_invalid_subscription_name(self, context_dict): + context_dict["subscription_name"] = "very "*12 + "long name" + with pytest.raises(ValidationError) as exc: + Context(**context_dict) + assert "String should have at most 64 characters" in exc + @fixture def context_yaml(): From 7907e0c499cdb4618350f4f19fad987f28b1f666 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 15 Nov 2023 12:36:28 +0000 Subject: [PATCH 12/65] WIP: Add Typer validators factory --- data_safe_haven/commands/context.py | 6 ++-- data_safe_haven/commands/deploy.py | 26 ++++++++--------- data_safe_haven/functions/typer_validators.py | 29 +++++++++++++++++++ 3 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 data_safe_haven/functions/typer_validators.py diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index 1e17a0143d..c18db89e38 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ -7,7 +7,7 @@ from data_safe_haven.config import Config, 
ContextSettings from data_safe_haven.config.context_settings import Context, default_config_file_path from data_safe_haven.context import Context as ContextInfra -from data_safe_haven.functions import validate_aad_guid +from data_safe_haven.functions.typer_validators import typer_validate_aad_guid context_command_group = typer.Typer() @@ -58,7 +58,7 @@ def add( str, typer.Option( help="The ID of an Azure group containing all administrators.", - callback=validate_aad_guid, + callback=typer_validate_aad_guid, ), ], location: Annotated[ @@ -112,7 +112,7 @@ def update( Optional[str], # noqa: UP007 typer.Option( help="The ID of an Azure group containing all administrators.", - callback=validate_aad_guid, + callback=typer_validate_aad_guid, ), ] = None, location: Annotated[ diff --git a/data_safe_haven/commands/deploy.py b/data_safe_haven/commands/deploy.py index 9771a9a99c..46ef50f7f4 100644 --- a/data_safe_haven/commands/deploy.py +++ b/data_safe_haven/commands/deploy.py @@ -3,12 +3,12 @@ import typer -from data_safe_haven.functions import ( - validate_aad_guid, - validate_azure_vm_sku, - validate_email_address, - validate_ip_address, - validate_timezone, +from data_safe_haven.functions.typer_validators import ( + typer_validate_aad_guid, + typer_validate_azure_vm_sku, + typer_validate_email_address, + typer_validate_ip_address, + typer_validate_timezone, ) from data_safe_haven.utility import DatabaseSystem, SoftwarePackageCategory @@ -29,7 +29,7 @@ def shm( "The tenant ID for the AzureAD where users will be created," " for example '10de18e7-b238-6f1e-a4ad-772708929203'." 
), - callback=validate_aad_guid, + callback=typer_validate_aad_guid, ), ] = None, admin_email_address: Annotated[ @@ -38,7 +38,7 @@ def shm( "--email", "-e", help="The email address where your system deployers and administrators can be contacted.", - callback=validate_email_address, + callback=typer_validate_email_address, ), ] = None, admin_ip_addresses: Annotated[ @@ -50,7 +50,7 @@ def shm( "An IP address or range used by your system deployers and administrators." " [*may be specified several times*]" ), - callback=lambda ips: [validate_ip_address(ip) for ip in ips], + callback=lambda ips: [typer_validate_ip_address(ip) for ip in ips], ), ] = None, domain: Annotated[ @@ -75,7 +75,7 @@ def shm( "--timezone", "-t", help="The timezone that this Data Safe Haven deployment will use.", - callback=validate_timezone, + callback=typer_validate_timezone, ), ] = None, ) -> None: @@ -115,7 +115,7 @@ def sre( "--data-provider-ip-address", "-d", help="An IP address or range used by your data providers. [*may be specified several times*]", - callback=lambda vms: [validate_ip_address(vm) for vm in vms], + callback=lambda vms: [typer_validate_ip_address(vm) for vm in vms], ), ] = None, databases: Annotated[ @@ -148,7 +148,7 @@ def sre( "--user-ip-address", "-u", help="An IP address or range used by your users. [*may be specified several times*]", - callback=lambda ips: [validate_ip_address(ip) for ip in ips], + callback=lambda ips: [typer_validate_ip_address(ip) for ip in ips], ), ] = None, workspace_skus: Annotated[ @@ -160,7 +160,7 @@ def sre( "A virtual machine SKU to make available to your users as a workspace." 
" [*may be specified several times*]" ), - callback=lambda ips: [validate_azure_vm_sku(ip) for ip in ips], + callback=lambda ips: [typer_validate_azure_vm_sku(ip) for ip in ips], ), ] = None, ) -> None: diff --git a/data_safe_haven/functions/typer_validators.py b/data_safe_haven/functions/typer_validators.py new file mode 100644 index 0000000000..30d6227249 --- /dev/null +++ b/data_safe_haven/functions/typer_validators.py @@ -0,0 +1,29 @@ +from collections.abc import Callable +from typing import Any + +from typer import BadParameter + +from data_safe_haven.functions.validators import ( + validate_aad_guid, + validate_azure_vm_sku, + validate_email_address, + validate_ip_address, + validate_timezone, +) + + +def typer_validator_factory(validator: Callable[[Any], Any]) -> Callable[[Any], Any]: + def typer_validator(x: Any) -> Any: + try: + validator(x) + except ValueError as exc: + raise BadParameter(str(exc)) from exc + + return typer_validator + + +typer_validate_aad_guid = typer_validator_factory(validate_aad_guid) +typer_validate_email_address = typer_validator_factory(validate_email_address) +typer_validate_ip_address = typer_validator_factory(validate_ip_address) +typer_validate_azure_vm_sku = typer_validator_factory(validate_azure_vm_sku) +typer_validate_timezone = typer_validator_factory(validate_timezone) From 36f0e1c24e996cb3f18b0dc32ec9a04a926d0ae2 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 15 Nov 2023 13:48:21 +0000 Subject: [PATCH 13/65] Validate Context and ContextSettings on assignment --- data_safe_haven/config/context_settings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 10dab336c0..65ca048bba 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -27,14 +27,14 @@ def default_config_file_path() -> Path: return config_dir() / "contexts.yaml" -class Context(BaseModel): 
+class Context(BaseModel, validate_assignment=True): admin_group_id: Guid location: AzureLocation name: str subscription_name: AzureLongName -class ContextSettings(BaseModel): +class ContextSettings(BaseModel, validate_assignment=True): """Load global and local settings from dotfiles with structure like the following selected: acme_deployment From 1136b4994f74bc57686bc252769fe515056c8cac Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 20 Nov 2023 15:29:50 +0000 Subject: [PATCH 14/65] WIP: rewrite Config (and related) as BaseModel --- data_safe_haven/config/config.py | 446 ++++++--------------- data_safe_haven/config/context_settings.py | 25 +- data_safe_haven/functions/__init__.py | 1 - data_safe_haven/functions/miscellaneous.py | 12 - 4 files changed, 146 insertions(+), 338 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index c0f32247df..88826e590d 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -1,156 +1,69 @@ """Configuration file backed by blob storage""" +from __future__ import annotations + import pathlib -from collections import defaultdict -from collections.abc import Callable -from contextlib import suppress -from dataclasses import dataclass, field -from functools import partial -from typing import Any, ClassVar - -import chili + import yaml -from yaml.parser import ParserError +from pydantic import BaseModel, Field, ValidationError, computed_field +from yaml import YAMLError from data_safe_haven import __version__ +from data_safe_haven.config.context_settings import Context from data_safe_haven.exceptions import ( - DataSafeHavenAzureError, DataSafeHavenConfigError, DataSafeHavenParameterError, ) from data_safe_haven.external import AzureApi from data_safe_haven.functions import ( - alphanumeric, - as_dict, b64decode, b64encode, - validate_aad_guid, - validate_azure_location, - validate_azure_vm_sku, - validate_email_address, - validate_ip_address, - validate_list, 
- validate_non_empty_string, - validate_string_length, - validate_timezone, - validate_type, ) from data_safe_haven.utility import ( DatabaseSystem, LoggingSingleton, SoftwarePackageCategory, - config_dir, +) +from data_safe_haven.utility.annotated_types import ( + AzureLocation, + AzureLongName, + AzureShortName, + AzureVmSku, + EmailAdress, + Guid, + IpAddress, + TimeZone, ) -from .context_settings import ContextSettings - - -class Validator: - validation_functions: ClassVar[dict[str, Callable[[Any], Any]]] = {} - - def validate(self) -> None: - """Validate instance attributes. - Validation fails if the provided validation function raises an exception. - """ - try: - for attr_name in self.validation_functions.keys(): - self.validate_attribute(attr_name) - except Exception as exc: - msg = f"Failed to validate command line arguments.\n{exc}" - raise DataSafeHavenConfigError(msg) from exc +class ConfigSectionAzure(BaseModel, validate_assignment=True): + admin_group_id: Guid + location: AzureLocation + subscription_id: Guid + tenant_id: Guid - def validate_attribute(self, attribute_name: str) -> None: - """Validate single instance attribute. - Validation fails if the provided validation function raises an exception. 
- """ - try: - validator = self.validation_functions[attribute_name] - validator(getattr(self, attribute_name)) - except Exception as exc: - msg = f"Invalid value for '{attribute_name}': '{getattr(self, attribute_name)}'.\n{exc}" - raise DataSafeHavenConfigError(msg) from exc +class ConfigSectionContext(BaseModel, validate_assignment=True): + key_vault_name: AzureShortName + managed_identity_name: AzureLongName + resource_group_name: AzureLongName + storage_account_name: AzureShortName + storage_container_name: AzureLongName -class ConfigSection(Validator): - def to_dict(self) -> dict[str, Any]: - """Dictionary representation of this object.""" - self.validate() - return as_dict(chili.encode(self)) - - -@dataclass -class ConfigSectionAzure(ConfigSection): - admin_group_id: str = "" - location: str = "" - subscription_id: str = "" - tenant_id: str = "" - - validation_functions = { # noqa: RUF012 - "admin_group_id": validate_aad_guid, - "location": validate_azure_location, - "subscription_id": validate_aad_guid, - "tenant_id": validate_aad_guid, - } - - -@dataclass -class ConfigSectionContext(ConfigSection): - key_vault_name: str = "" - managed_identity_name: str = "" - resource_group_name: str = "" - storage_account_name: str = "" - storage_container_name: str = "" - - validation_functions = { # noqa: RUF012 - "key_vault_name": partial(validate_string_length, min_length=3, max_length=24), - "managed_identity_name": partial( - validate_string_length, min_length=1, max_length=64 - ), - "resource_group_name": partial( - validate_string_length, min_length=1, max_length=64 - ), - "storage_account_name": partial( - validate_string_length, min_length=1, max_length=24 - ), - "storage_container_name": partial( - validate_string_length, min_length=1, max_length=64 - ), - } - - -@dataclass -class ConfigSectionPulumi(ConfigSection): +class ConfigSectionPulumi(BaseModel, validate_assignment=True): encryption_key_name: str = "pulumi-encryption-key" - encryption_key_version: str = 
"" - stacks: dict[str, str] = field(default_factory=dict) + encryption_key_version: str + stacks: dict[str, str] = Field(default_factory=dict) storage_container_name: str = "pulumi" - validation_functions = { # noqa: RUF012 - "encryption_key_name": validate_non_empty_string, - "encryption_key_version": validate_non_empty_string, - "stacks": lambda stacks: isinstance(stacks, dict), - "storage_container_name": validate_non_empty_string, - } - - -@dataclass -class ConfigSectionSHM(ConfigSection): - aad_tenant_id: str = "" - admin_email_address: str = "" - admin_ip_addresses: list[str] = field(default_factory=list) - fqdn: str = "" - name: str = "" - timezone: str = "" - - validation_functions = { # noqa: RUF012 - "aad_tenant_id": validate_aad_guid, - "admin_email_address": validate_email_address, - "admin_ip_addresses": partial(validate_list, validator=validate_ip_address), - "fqdn": validate_non_empty_string, - "name": validate_non_empty_string, - "timezone": validate_timezone, - } + +class ConfigSectionSHM(BaseModel, validate_assignment=True): + aad_tenant_id: Guid + admin_email_address: EmailAdress + admin_ip_addresses: list[IpAddress] + fqdn: str + name: str + timezone: TimeZone def update( self, @@ -177,96 +90,65 @@ def update( logger.info( f"[bold]AzureAD tenant ID[/] will be [green]{self.aad_tenant_id}[/]." ) - self.validate_attribute("aad_tenant_id") # Set admin email address if admin_email_address: self.admin_email_address = admin_email_address logger.info( f"[bold]Admin email address[/] will be [green]{self.admin_email_address}[/]." ) - self.validate_attribute("admin_email_address") # Set admin IP addresses if admin_ip_addresses: self.admin_ip_addresses = admin_ip_addresses logger.info( f"[bold]IP addresses used by administrators[/] will be [green]{self.admin_ip_addresses}[/]." 
) - self.validate_attribute("admin_ip_addresses") # Set fully-qualified domain name if fqdn: self.fqdn = fqdn logger.info( f"[bold]Fully-qualified domain name[/] will be [green]{self.fqdn}[/]." ) - self.validate_attribute("fqdn") # Set timezone if timezone: self.timezone = timezone logger.info(f"[bold]Timezone[/] will be [green]{self.timezone}[/].") - self.validate_attribute("timezone") - - -@dataclass -class ConfigSectionSRE(ConfigSection): - @dataclass - class ConfigSubsectionRemoteDesktopOpts(Validator): - allow_copy: bool = False - allow_paste: bool = False - - validation_functions = { # noqa: RUF012 - "allow_copy": partial(validate_type, type_=bool), - "allow_paste": partial(validate_type, type_=bool), - } - - def update( - self, *, allow_copy: bool | None = None, allow_paste: bool | None = None - ) -> None: - """Update SRE remote desktop settings - - Args: - allow_copy: Allow/deny copying text out of the SRE - allow_paste: Allow/deny pasting text into the SRE - """ - # Set whether copying text out of the SRE is allowed - if allow_copy: - self.allow_copy = allow_copy - LoggingSingleton().info( - f"[bold]Copying text out of the SRE[/] will be [green]{'allowed' if self.allow_copy else 'forbidden'}[/]." - ) - # Set whether pasting text into the SRE is allowed - if allow_paste: - self.allow_paste = allow_paste - LoggingSingleton().info( - f"[bold]Pasting text into the SRE[/] will be [green]{'allowed' if self.allow_paste else 'forbidden'}[/]." 
- ) - - databases: list[DatabaseSystem] = field(default_factory=list) - data_provider_ip_addresses: list[str] = field(default_factory=list) - index: int = 0 - remote_desktop: ConfigSubsectionRemoteDesktopOpts = field( - default_factory=ConfigSubsectionRemoteDesktopOpts - ) - workspace_skus: list[str] = field(default_factory=list) - research_user_ip_addresses: list[str] = field(default_factory=list) - software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE - validation_functions = { # noqa: RUF012 - "data_provider_ip_addresses": partial( - validate_list, validator=validate_ip_address - ), - "databases": partial( - validate_list, - validator=lambda pkg: isinstance(pkg, DatabaseSystem), - allow_empty=True, - ), - "index": lambda idx: isinstance(idx, int) and idx >= 0, - "remote_desktop": lambda dsktop: dsktop.validate(), - "workspace_skus": partial(validate_list, validator=validate_azure_vm_sku), - "research_user_ip_addresses": partial( - validate_list, validator=validate_ip_address - ), - "software_packages": lambda pkg: isinstance(pkg, SoftwarePackageCategory), - } + +class ConfigSubsectionRemoteDesktopOpts(BaseModel, validate_assignment=True): + allow_copy: bool = False + allow_paste: bool = False + + def update( + self, *, allow_copy: bool | None = None, allow_paste: bool | None = None + ) -> None: + """Update SRE remote desktop settings + + Args: + allow_copy: Allow/deny copying text out of the SRE + allow_paste: Allow/deny pasting text into the SRE + """ + # Set whether copying text out of the SRE is allowed + if allow_copy: + self.allow_copy = allow_copy + LoggingSingleton().info( + f"[bold]Copying text out of the SRE[/] will be [green]{'allowed' if self.allow_copy else 'forbidden'}[/]." + ) + # Set whether pasting text into the SRE is allowed + if allow_paste: + self.allow_paste = allow_paste + LoggingSingleton().info( + f"[bold]Pasting text into the SRE[/] will be [green]{'allowed' if self.allow_paste else 'forbidden'}[/]." 
+ ) + + +class ConfigSectionSRE(BaseModel, validate_assignment=True): + databases: list[DatabaseSystem] + data_provider_ip_addresses: list[IpAddress] + index: int = Field(ge=0) + remote_desktop: ConfigSubsectionRemoteDesktopOpts + workspace_skus: list[AzureVmSku] + research_user_ip_addresses: list[IpAddress] + software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE def update( self, @@ -297,7 +179,6 @@ def update( logger.info( f"[bold]IP addresses used by data providers[/] will be [green]{self.data_provider_ip_addresses}[/]." ) - self.validate_attribute("data_provider_ip_addresses") # Set which databases to deploy if databases: self.databases = sorted(set(databases)) @@ -306,156 +187,44 @@ def update( logger.info( f"[bold]Databases available to users[/] will be [green]{[database.value for database in self.databases]}[/]." ) - self.validate_attribute("databases") # Pass allow_copy and allow_paste to remote desktop self.remote_desktop.update(allow_copy=allow_copy, allow_paste=allow_paste) - self.validate_attribute("remote_desktop") # Set research desktop SKUs if workspace_skus: self.workspace_skus = workspace_skus logger.info(f"[bold]Workspace SKUs[/] will be [green]{self.workspace_skus}[/].") - self.validate_attribute("remote_desktop") # Select which software packages can be installed by users if software_packages: self.software_packages = software_packages logger.info( f"[bold]Software packages[/] from [green]{self.software_packages.value}[/] sources will be installable." ) - self.validate_attribute("software_packages") # Set user IP addresses if user_ip_addresses: self.research_user_ip_addresses = user_ip_addresses logger.info( f"[bold]IP addresses used by users[/] will be [green]{self.research_user_ip_addresses}[/]." 
) - self.validate_attribute("research_user_ip_addresses") -@dataclass -class ConfigSectionTags(ConfigSection): - deployment: str = "" +class ConfigSectionTags(BaseModel, validate_assignment=True): + deployment: str deployed_by: str = "Python" project: str = "Data Safe Haven" version: str = __version__ - validation_functions = { # noqa: RUF012 - "deployment": validate_non_empty_string, - "deployed_by": validate_non_empty_string, - "project": validate_non_empty_string, - "version": validate_non_empty_string, - } - - -class Config: - def __init__(self) -> None: - # Initialise config sections - self.azure_: ConfigSectionAzure | None = None - self.context_: ConfigSectionContext | None = None - self.pulumi_: ConfigSectionPulumi | None = None - self.shm_: ConfigSectionSHM | None = None - self.tags_: ConfigSectionTags | None = None - self.sres: dict[str, ConfigSectionSRE] = defaultdict(ConfigSectionSRE) - # Read context settings - settings = ContextSettings.from_file() - context = settings.context - # Check if backend exists and was loaded - try: - self.name = context.name - except DataSafeHavenParameterError as exc: - msg = "Data Safe Haven has not been initialised: run '[bright_cyan]dsh init[/]' before continuing." 
- raise DataSafeHavenConfigError(msg) from exc - self.subscription_name = context.subscription_name - self.azure.location = context.location - self.azure.admin_group_id = context.admin_group_id - self.context_storage_container_name = "config" - # Set derived names - self.shm_name_ = alphanumeric(self.name).lower() - self.filename = f"config-{self.shm_name_}.yaml" - self.context_resource_group_name = f"shm-{self.shm_name_}-rg-context" - self.context_storage_account_name = ( - f"shm{self.shm_name_[:14]}context" # maximum of 24 characters allowed - ) - self.work_directory = config_dir() / self.shm_name_ - self.azure_api = AzureApi(subscription_name=self.subscription_name) - # Attempt to load YAML dictionary from blob storage - yaml_input = {} - with suppress(DataSafeHavenAzureError, ParserError): - yaml_input = yaml.safe_load( - self.azure_api.download_blob( - self.filename, - self.context_resource_group_name, - self.context_storage_account_name, - self.context_storage_container_name, - ) - ) - # Attempt to decode each config section - if yaml_input: - if "azure" in yaml_input: - self.azure_ = chili.decode(yaml_input["azure"], ConfigSectionAzure) - if "context" in yaml_input: - self.context_ = chili.decode( - yaml_input["context"], ConfigSectionContext - ) - if "pulumi" in yaml_input: - self.pulumi_ = chili.decode(yaml_input["pulumi"], ConfigSectionPulumi) - if "shm" in yaml_input: - self.shm_ = chili.decode(yaml_input["shm"], ConfigSectionSHM) - if "sre" in yaml_input: - for sre_name, sre_details in dict(yaml_input["sre"]).items(): - self.sres[sre_name] = chili.decode(sre_details, ConfigSectionSRE) - - @property - def azure(self) -> ConfigSectionAzure: - if not self.azure_: - self.azure_ = ConfigSectionAzure() - return self.azure_ - - @property - def context(self) -> ConfigSectionContext: - if not self.context_: - self.context_ = ConfigSectionContext( - key_vault_name=f"shm-{self.shm_name_[:9]}-kv-context", - 
managed_identity_name=f"shm-{self.shm_name_}-identity-reader-context", - resource_group_name=self.context_resource_group_name, - storage_account_name=self.context_storage_account_name, - storage_container_name=self.context_storage_container_name, - ) - return self.context_ - - @property - def pulumi(self) -> ConfigSectionPulumi: - if not self.pulumi_: - self.pulumi_ = ConfigSectionPulumi() - return self.pulumi_ - - @property - def shm(self) -> ConfigSectionSHM: - if not self.shm_: - self.shm_ = ConfigSectionSHM(name=self.shm_name_) - return self.shm_ - - @property - def tags(self) -> ConfigSectionTags: - if not self.tags_: - self.tags_ = ConfigSectionTags(deployment=self.name) - return self.tags_ - - def __str__(self) -> str: - """String representation of the Config object""" - contents: dict[str, Any] = {} - if self.azure_: - contents["azure"] = self.azure.to_dict() - if self.context_: - contents["context"] = self.context.to_dict() - if self.pulumi_: - contents["pulumi"] = self.pulumi.to_dict() - if self.shm_: - contents["shm"] = self.shm.to_dict() - if self.sres: - contents["sre"] = {k: v.to_dict() for k, v in self.sres.items()} - if self.tags: - contents["tags"] = self.tags.to_dict() - return str(yaml.dump(contents, indent=2)) + +class Config(BaseModel, validate_assignment=True): + azure: ConfigSectionAzure | None = None + context: Context + pulumi: ConfigSectionPulumi | None = None + shm: ConfigSectionSHM | None = None + tags: ConfigSectionTags | None = None + sres: dict[str, ConfigSectionSRE] | None = None + + @computed_field + def work_directory(self) -> str: + return self.context.work_directory def read_stack(self, name: str, path: pathlib.Path) -> None: """Add a Pulumi stack file to config""" @@ -480,14 +249,43 @@ def sre(self, name: str) -> ConfigSectionSRE: self.sres[name].index = highest_index + 1 return self.sres[name] + @classmethod + def from_yaml(cls, config_yaml: str) -> Config: + try: + config_dict = yaml.safe_load(config_yaml) + except YAMLError 
as exc: + msg = f"Could not parse configuration as YAML.\n{exc}" + raise DataSafeHavenConfigError(msg) from exc + + if not isinstance(config_dict, dict): + msg = "Unable to parse configuration as a dict." + raise DataSafeHavenConfigError(msg) + + try: + return Config.model_validate(config_dict) + except ValidationError as exc: + msg = f"Could not load configuration.\n{exc}" + raise DataSafeHavenParameterError(msg) from exc + + @classmethod + def from_remote(cls, context: Context) -> Config: + azure_api = AzureApi(subscription_name=context.subscription_name) + config_yaml = azure_api.download_blob( + context.config_filename, + context.resource_group_name, + context.storage_account_name, + context.storage_container_name, + ) + return Config.from_yaml(config_yaml) + def upload(self) -> None: """Upload config to Azure storage""" self.azure_api.upload_blob( - str(self), - self.filename, - self.context_resource_group_name, - self.context_storage_account_name, - self.context_storage_container_name, + yaml.dump(self.model_dump, indent=2), + self.context.config_filename, + self.context.resource_group_name, + self.context.storage_account_name, + self.context.storage_container_name, ) def write_stack(self, name: str, path: pathlib.Path) -> None: diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 65ca048bba..00edc4c9ac 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -8,13 +8,14 @@ from typing import ClassVar import yaml -from pydantic import BaseModel, Field, ValidationError, model_validator +from pydantic import BaseModel, Field, ValidationError, computed_field, model_validator from yaml import YAMLError from data_safe_haven.exceptions import ( DataSafeHavenConfigError, DataSafeHavenParameterError, ) +from data_safe_haven.functions import alphanumeric from data_safe_haven.utility import LoggingSingleton, config_dir from data_safe_haven.utility.annotated_types 
import ( AzureLocation, @@ -32,6 +33,28 @@ class Context(BaseModel, validate_assignment=True): location: AzureLocation name: str subscription_name: AzureLongName + storage_container_name: ClassVar[str] = "config" + + @computed_field + def shm_name(self) -> str: + return alphanumeric(self.name).lower() + + @computed_field + def work_directory(self) -> Path: + return config_dir() / self.shm_name + + @computed_field + def config_filename(self) -> str: + return f"config-{self.shm_name}.yaml" + + @computed_field + def resource_group_name(self) -> str: + return f"shm-{self.shm_name}-rg-context" + + @computed_field + def storage_account_name(self) -> str: + # maximum of 24 characters allowed + return f"shm{self.shm_name[:14]}context" class ContextSettings(BaseModel, validate_assignment=True): diff --git a/data_safe_haven/functions/__init__.py b/data_safe_haven/functions/__init__.py index 6b444728ff..179816e4f6 100644 --- a/data_safe_haven/functions/__init__.py +++ b/data_safe_haven/functions/__init__.py @@ -1,6 +1,5 @@ from .miscellaneous import ( allowed_dns_lookups, - as_dict, ordered_private_dns_zones, time_as_string, ) diff --git a/data_safe_haven/functions/miscellaneous.py b/data_safe_haven/functions/miscellaneous.py index 4b609f260e..8703fd39fe 100644 --- a/data_safe_haven/functions/miscellaneous.py +++ b/data_safe_haven/functions/miscellaneous.py @@ -1,5 +1,4 @@ import datetime -from typing import Any import pytz @@ -19,17 +18,6 @@ def allowed_dns_lookups() -> list[str]: return sorted({zone for zones in dns_lookups.values() for zone in zones}) -def as_dict(container: Any) -> dict[str, Any]: - if ( - not isinstance(container, dict) - and hasattr(container, "keys") - and all(isinstance(x, str) for x in container.keys()) - ): - msg = f"{container} {type(container)} is not a valid dict[str, Any]" - raise TypeError(msg) - return {str(k): v for k, v in container.items()} - - def ordered_private_dns_zones(resource_type: str | None = None) -> list[str]: """ Return required 
DNS zones for a given resource type. From 515668e87beaf892778a3a867cb0d4419b89bad1 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 21 Nov 2023 10:10:17 +0000 Subject: [PATCH 15/65] Update Context tests --- data_safe_haven/config/context_settings.py | 12 +++++----- tests_/config/test_context_settings.py | 26 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 00edc4c9ac..fb89f448d2 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -8,7 +8,7 @@ from typing import ClassVar import yaml -from pydantic import BaseModel, Field, ValidationError, computed_field, model_validator +from pydantic import BaseModel, Field, ValidationError, model_validator from yaml import YAMLError from data_safe_haven.exceptions import ( @@ -35,23 +35,23 @@ class Context(BaseModel, validate_assignment=True): subscription_name: AzureLongName storage_container_name: ClassVar[str] = "config" - @computed_field + @property def shm_name(self) -> str: return alphanumeric(self.name).lower() - @computed_field + @property def work_directory(self) -> Path: return config_dir() / self.shm_name - @computed_field + @property def config_filename(self) -> str: return f"config-{self.shm_name}.yaml" - @computed_field + @property def resource_group_name(self) -> str: return f"shm-{self.shm_name}-rg-context" - @computed_field + @property def storage_account_name(self) -> str: # maximum of 24 characters allowed return f"shm{self.shm_name[:14]}context" diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 8f04ed3fc1..2b7339b628 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -17,6 +17,11 @@ def context_dict(): } +@fixture +def context(context_dict): + return Context(**context_dict) + + class TestContext: def test_constructor(self, 
context_dict): context = Context(**context_dict) @@ -24,6 +29,7 @@ def test_constructor(self, context_dict): assert all([ getattr(context, item) == context_dict[item] for item in context_dict.keys() ]) + assert context.storage_container_name == "config" def test_invalid_guid(self, context_dict): context_dict["admin_group_id"] = "not a guid" @@ -43,6 +49,26 @@ def test_invalid_subscription_name(self, context_dict): Context(**context_dict) assert "String should have at most 64 characters" in exc + def test_shm_name(self, context): + assert context.shm_name == "acmedeployment" + + def test_work_directory(self, context): + assert "data_safe_haven/acmedeployment" in str(context.work_directory) + + def test_config_filename(self, context): + assert context.config_filename == "config-acmedeployment.yaml" + + def test_resource_group_name(self, context): + assert context.resource_group_name == "shm-acmedeployment-rg-context" + + def test_storage_account_name(self, context): + assert context.storage_account_name == "shmacmedeploymentcontext" + + def test_long_storage_account_name(self, context_dict): + context_dict["name"] = "very "*5 + "long name" + context = Context(**context_dict) + assert context.storage_account_name == "shmveryveryveryvecontext" + @fixture def context_yaml(): From a74154264720b8b30e6503e34fd9cb5d0e330608 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 21 Nov 2023 10:15:16 +0000 Subject: [PATCH 16/65] Remove ConfigSectionContext --- data_safe_haven/config/config.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 88826e590d..979d0546a8 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -25,8 +25,6 @@ ) from data_safe_haven.utility.annotated_types import ( AzureLocation, - AzureLongName, - AzureShortName, AzureVmSku, EmailAdress, Guid, @@ -42,14 +40,6 @@ class ConfigSectionAzure(BaseModel, validate_assignment=True): tenant_id: Guid 
-class ConfigSectionContext(BaseModel, validate_assignment=True): - key_vault_name: AzureShortName - managed_identity_name: AzureLongName - resource_group_name: AzureLongName - storage_account_name: AzureShortName - storage_container_name: AzureLongName - - class ConfigSectionPulumi(BaseModel, validate_assignment=True): encryption_key_name: str = "pulumi-encryption-key" encryption_key_version: str From acb01b8af6b308161e854fe6340f3701458ff030 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 21 Nov 2023 14:58:54 +0000 Subject: [PATCH 17/65] Improve design of config sections --- data_safe_haven/config/config.py | 79 +++++++++++++++++++------- pyproject.toml | 12 ++-- tests_/commands/test_context.py | 28 ++++----- tests_/config/test_context_settings.py | 52 +++++++---------- 4 files changed, 104 insertions(+), 67 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 979d0546a8..552395a498 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -2,9 +2,10 @@ from __future__ import annotations import pathlib +from typing import ClassVar import yaml -from pydantic import BaseModel, Field, ValidationError, computed_field +from pydantic import BaseModel, Field, ValidationError, computed_field, field_validator from yaml import YAMLError from data_safe_haven import __version__ @@ -39,6 +40,17 @@ class ConfigSectionAzure(BaseModel, validate_assignment=True): subscription_id: Guid tenant_id: Guid + @classmethod + def from_context( + cls, context: Context, subscription_id: Guid, tenant_id: Guid + ) -> ConfigSectionAzure: + return ConfigSectionAzure( + admin_group_id=context.admin_group_id, + location=context.location, + subscription_id=subscription_id, + tenant_id=tenant_id, + ) + class ConfigSectionPulumi(BaseModel, validate_assignment=True): encryption_key_name: str = "pulumi-encryption-key" @@ -55,6 +67,25 @@ class ConfigSectionSHM(BaseModel, validate_assignment=True): name: str timezone: 
TimeZone + @classmethod + def from_context( + cls, + context: Context, + aad_tenant_id: Guid, + admin_email_address: EmailAdress, + admin_ip_addresses: list[IpAddress], + fqdn: str, + timezone: TimeZone, + ) -> ConfigSectionSHM: + return ConfigSectionSHM( + aad_tenant_id=aad_tenant_id, + admin_email_address=admin_email_address, + admin_ip_addresses=admin_ip_addresses, + fqdn=fqdn, + name=context.shm_name, + timezone=timezone, + ) + def update( self, *, @@ -62,7 +93,7 @@ def update( admin_email_address: str | None = None, admin_ip_addresses: list[str] | None = None, fqdn: str | None = None, - timezone: str | None = None, + timezone: TimeZone | None = None, ) -> None: """Update SHM settings @@ -132,30 +163,38 @@ def update( class ConfigSectionSRE(BaseModel, validate_assignment=True): - databases: list[DatabaseSystem] - data_provider_ip_addresses: list[IpAddress] + databases: list[DatabaseSystem] = Field(default_factory=list[DatabaseSystem]) + data_provider_ip_addresses: list[IpAddress] = Field(default_factory=list[IpAddress]) index: int = Field(ge=0) - remote_desktop: ConfigSubsectionRemoteDesktopOpts - workspace_skus: list[AzureVmSku] - research_user_ip_addresses: list[IpAddress] + remote_desktop: ConfigSubsectionRemoteDesktopOpts = Field( + default_factory=ConfigSubsectionRemoteDesktopOpts + ) + workspace_skus: list[AzureVmSku] = Field(default_factory=list[AzureVmSku]) + research_user_ip_addresses: list[IpAddress] = Field(default_factory=list[IpAddress]) software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE + @field_validator("databases") + @classmethod + def all_databases_must_be_unique( + cls, v: list[DatabaseSystem] + ) -> list[DatabaseSystem]: + if len(v) != len(set(v)): + msg = "all databases must be unique" + raise ValueError(msg) + return v + def update( self, *, - allow_copy: bool | None = None, - allow_paste: bool | None = None, - data_provider_ip_addresses: list[str] | None = None, + data_provider_ip_addresses: list[IpAddress] | None = 
None, databases: list[DatabaseSystem] | None = None, - workspace_skus: list[str] | None = None, + workspace_skus: list[AzureVmSku] | None = None, software_packages: SoftwarePackageCategory | None = None, - user_ip_addresses: list[str] | None = None, + user_ip_addresses: list[IpAddress] | None = None, ) -> None: """Update SRE settings Args: - allow_copy: Allow/deny copying text out of the SRE - allow_paste: Allow/deny pasting text into the SRE databases: List of database systems to deploy data_provider_ip_addresses: List of IP addresses belonging to data providers workspace_skus: List of VM SKUs for workspaces @@ -177,8 +216,6 @@ def update( logger.info( f"[bold]Databases available to users[/] will be [green]{[database.value for database in self.databases]}[/]." ) - # Pass allow_copy and allow_paste to remote desktop - self.remote_desktop.update(allow_copy=allow_copy, allow_paste=allow_paste) # Set research desktop SKUs if workspace_skus: self.workspace_skus = workspace_skus @@ -199,9 +236,13 @@ def update( class ConfigSectionTags(BaseModel, validate_assignment=True): deployment: str - deployed_by: str = "Python" - project: str = "Data Safe Haven" - version: str = __version__ + deployed_by: ClassVar[str] = "Python" + project: ClassVar[str] = "Data Safe Haven" + version: ClassVar[str] = __version__ + + @classmethod + def from_context(cls, context: Context) -> ConfigSectionTags: + return ConfigSectionTags(deployment=context.name) class Config(BaseModel, validate_assignment=True): diff --git a/pyproject.toml b/pyproject.toml index f901ea955e..f648c705bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,12 +77,12 @@ dependencies = [ typing = "mypy {args:data_safe_haven}" style = [ - "ruff {args:data_safe_haven}", - "black --check --diff {args:data_safe_haven}", + "ruff {args:data_safe_haven tests_}", + "black --check --diff {args:data_safe_haven tests_}", ] fmt = [ - "black {args:data_safe_haven}", - "ruff --fix {args:data_safe_haven}", + "black 
{args:data_safe_haven tests_}", + "ruff --fix {args:data_safe_haven tests_}", "style", ] all = [ @@ -148,6 +148,10 @@ known-first-party = ["data_safe_haven"] [tool.ruff.flake8-tidy-imports] ban-relative-imports = "parents" +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests_/**/*" = ["PLR2004", "S101", "TID252"] + [tool.mypy] disallow_subclassing_any = false # allow subclassing of types from third-party libraries files = "data_safe_haven" # run mypy over this directory diff --git a/tests_/commands/test_context.py b/tests_/commands/test_context.py index 0cdf878c44..5d8525f4c5 100644 --- a/tests_/commands/test_context.py +++ b/tests_/commands/test_context.py @@ -1,10 +1,10 @@ +from pytest import fixture +from typer.testing import CliRunner + from data_safe_haven.commands.context import context_command_group from data_safe_haven.config import Config from data_safe_haven.context import Context -from pytest import fixture -from typer.testing import CliRunner - context_settings = """\ selected: acme_deployment contexts: @@ -88,7 +88,7 @@ def test_add(self, runner): "uksouth", "--subscription", "Data Safe Haven (Example)", - ] + ], ) assert result.exit_code == 0 result = runner.invoke(context_command_group, ["switch", "example"]) @@ -108,7 +108,7 @@ def test_add_duplicate(self, runner): "uksouth", "--subscription", "Data Safe Haven (Acme)", - ] + ], ) assert result.exit_code == 1 # Unable to check error as this is written outside of any Typer @@ -128,7 +128,7 @@ def test_add_invalid_uuid(self, runner): "uksouth", "--subscription", "Data Safe Haven (Example)", - ] + ], ) assert result.exit_code == 2 # This works because the context_command_group Typer writes this error @@ -142,7 +142,7 @@ def test_add_missing_ags(self, runner): "example", "--name", "Example", - ] + ], ) assert result.exit_code == 2 assert "Missing option" in result.stderr @@ -162,7 +162,7 @@ def test_add_bootstrap(self, tmp_contexts, runner): "uksouth", 
"--subscription", "Data Safe Haven (Acme)", - ] + ], ) assert result.exit_code == 0 assert (tmp_contexts / "contexts.yaml").exists() @@ -201,11 +201,11 @@ def test_remove_invalid(self, runner): class TestCreate: def test_create(self, runner, monkeypatch): - def mock_create(self): - print("mock create") + def mock_create(): + print("mock create") # noqa: T201 - def mock_upload(self): - print("mock upload") + def mock_upload(): + print("mock upload") # noqa: T201 monkeypatch.setattr(Context, "create", mock_create) monkeypatch.setattr(Config, "upload", mock_upload) @@ -218,8 +218,8 @@ def mock_upload(self): class TestTeardown: def test_teardown(self, runner, monkeypatch): - def mock_teardown(self): - print("mock teardown") + def mock_teardown(): + print("mock teardown") # noqa: T201 monkeypatch.setattr(Context, "teardown", mock_teardown) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 2b7339b628..41901c4c7e 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -1,34 +1,22 @@ -from data_safe_haven.config.context_settings import Context, ContextSettings -from data_safe_haven.exceptions import DataSafeHavenConfigError, DataSafeHavenParameterError - import pytest import yaml from pydantic import ValidationError from pytest import fixture - -@fixture -def context_dict(): - return { - "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - "location": "uksouth", - "name": "Acme Deployment", - "subscription_name": "Data Safe Haven (Acme)" - } - - -@fixture -def context(context_dict): - return Context(**context_dict) +from data_safe_haven.config.context_settings import Context, ContextSettings +from data_safe_haven.exceptions import ( + DataSafeHavenConfigError, + DataSafeHavenParameterError, +) class TestContext: def test_constructor(self, context_dict): context = Context(**context_dict) assert isinstance(context, Context) - assert all([ + assert all( getattr(context, item) 
== context_dict[item] for item in context_dict.keys() - ]) + ) assert context.storage_container_name == "config" def test_invalid_guid(self, context_dict): @@ -44,7 +32,7 @@ def test_invalid_location(self, context_dict): assert "Value error, Expected valid Azure location" in exc def test_invalid_subscription_name(self, context_dict): - context_dict["subscription_name"] = "very "*12 + "long name" + context_dict["subscription_name"] = "very " * 12 + "long name" with pytest.raises(ValidationError) as exc: Context(**context_dict) assert "String should have at most 64 characters" in exc @@ -65,7 +53,7 @@ def test_storage_account_name(self, context): assert context.storage_account_name == "shmacmedeploymentcontext" def test_long_storage_account_name(self, context_dict): - context_dict["name"] = "very "*5 + "long name" + context_dict["name"] = "very " * 5 + "long name" context = Context(**context_dict) assert context.storage_account_name == "shmveryveryveryvecontext" @@ -119,7 +107,9 @@ def test_missing_selected(self, context_yaml): assert "Field required" in exc def test_invalid_selected_input(self, context_yaml): - context_yaml = context_yaml.replace("selected: acme_deployment", "selected: invalid") + context_yaml = context_yaml.replace( + "selected: acme_deployment", "selected: invalid" + ) with pytest.raises(DataSafeHavenParameterError) as exc: ContextSettings.from_yaml(context_yaml) @@ -153,24 +143,26 @@ def test_invalid_selected(self, context_settings): def test_context(self, context_yaml, context_settings): yaml_dict = yaml.safe_load(context_yaml) assert isinstance(context_settings.context, Context) - assert all([ - getattr(context_settings.context, item) == yaml_dict["contexts"]["acme_deployment"][item] + assert all( + getattr(context_settings.context, item) + == yaml_dict["contexts"]["acme_deployment"][item] for item in yaml_dict["contexts"]["acme_deployment"].keys() - ]) + ) def test_set_context(self, context_yaml, context_settings): yaml_dict = 
yaml.safe_load(context_yaml) context_settings.selected = "gems" assert isinstance(context_settings.context, Context) - assert all([ - getattr(context_settings.context, item) == yaml_dict["contexts"]["gems"][item] + assert all( + getattr(context_settings.context, item) + == yaml_dict["contexts"]["gems"][item] for item in yaml_dict["contexts"]["gems"].keys() - ]) + ) def test_available(self, context_settings): available = context_settings.available assert isinstance(available, list) - assert all([isinstance(item, str) for item in available]) + assert all(isinstance(item, str) for item in available) assert available == ["acme_deployment", "gems"] def test_update(self, context_settings): @@ -238,7 +230,7 @@ def test_write(self, tmp_path, context_yaml): settings.selected = "gems" settings.update(name="replaced") settings.write(config_file_path) - with open(config_file_path, "r") as f: + with open(config_file_path) as f: context_dict = yaml.safe_load(f) assert context_dict["selected"] == "gems" assert context_dict["contexts"]["gems"]["name"] == "replaced" From e7ca180818efb2ec319685cbfe77cb14b3a679a9 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 22 Nov 2023 10:05:38 +0000 Subject: [PATCH 18/65] Add missing tests --- tests_/config/conftest.py | 18 ++++ tests_/config/test_config.py | 178 +++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 tests_/config/conftest.py create mode 100644 tests_/config/test_config.py diff --git a/tests_/config/conftest.py b/tests_/config/conftest.py new file mode 100644 index 0000000000..1ac39c86df --- /dev/null +++ b/tests_/config/conftest.py @@ -0,0 +1,18 @@ +from pytest import fixture + +from data_safe_haven.config.context_settings import Context + + +@fixture +def context_dict(): + return { + "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "location": "uksouth", + "name": "Acme Deployment", + "subscription_name": "Data Safe Haven (Acme)", + } + + +@fixture +def context(context_dict): + 
return Context(**context_dict) diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py new file mode 100644 index 0000000000..2b55b8b827 --- /dev/null +++ b/tests_/config/test_config.py @@ -0,0 +1,178 @@ +import pytest +from pydantic import ValidationError +from pytest import fixture + +from data_safe_haven.config.config import ( + ConfigSectionAzure, + ConfigSectionPulumi, + ConfigSectionSHM, + ConfigSectionSRE, + ConfigSectionTags, + ConfigSubsectionRemoteDesktopOpts, +) +from data_safe_haven.utility.enums import DatabaseSystem, SoftwarePackageCategory +from data_safe_haven.version import __version__ + + +class TestConfigSectionAzure: + def test_constructor(self): + ConfigSectionAzure( + admin_group_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + location="uksouth", + subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + ) + + def test_from_context(self, context): + azure_config = ConfigSectionAzure.from_context( + context=context, + subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + ) + assert azure_config.location == context.location + + +class TestConfigSectionPulumi: + def test_constructor_defaults(self): + pulumi_config = ConfigSectionPulumi(encryption_key_version="lorem") + assert pulumi_config.encryption_key_name == "pulumi-encryption-key" + assert pulumi_config.stacks == {} + assert pulumi_config.storage_container_name == "pulumi" + + +@fixture +def shm_config(): + return ConfigSectionSHM( + aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + admin_email_address="admin@example.com", + admin_ip_addresses=["0.0.0.0"], # noqa: S104 + fqdn="shm.acme.com", + name="ACME SHM", + timezone="UTC", + ) + + +class TestConfigSectionSHM: + def test_constructor(self): + ConfigSectionSHM( + aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + admin_email_address="admin@example.com", + admin_ip_addresses=["0.0.0.0"], # noqa: S104 + 
fqdn="shm.acme.com", + name="ACME SHM", + timezone="UTC", + ) + + def test_from_context(self, context): + shm_config = ConfigSectionSHM.from_context( + context=context, + aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + admin_email_address="admin@example.com", + admin_ip_addresses=["0.0.0.0"], # noqa: S104 + fqdn="shm.acme.com", + timezone="UTC", + ) + assert shm_config.name == context.shm_name + + def test_update(self, shm_config): + assert shm_config.fqdn == "shm.acme.com" + shm_config.update(fqdn="modified") + assert shm_config.fqdn == "modified" + + def test_update_validation(self, shm_config): + with pytest.raises(ValidationError) as exc: + shm_config.update(admin_email_address="not an email address") + assert "Value error, Expected valid email address" in exc + assert "not an email address" in exc + + +@fixture +def remote_desktop_config(): + return ConfigSubsectionRemoteDesktopOpts() + + +class TestConfigSubsectionRemoteDesktopOpts: + def test_constructor(self): + ConfigSubsectionRemoteDesktopOpts(allow_copy=True, allow_paste=True) + + def test_constructor_defaults(self): + remote_desktop_config = ConfigSubsectionRemoteDesktopOpts() + assert not all( + [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + ) + + def test_update(self, remote_desktop_config): + assert not all( + [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + ) + remote_desktop_config.update(allow_copy=True, allow_paste=True) + assert all( + [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + ) + + +class TestConfigSectionSRE: + def test_constructor(self, remote_desktop_config): + sre_config = ConfigSectionSRE( + databases=[DatabaseSystem.POSTGRESQL], + data_provider_ip_addresses=["0.0.0.0"], # noqa: S104 + index=0, + remote_desktop=remote_desktop_config, + workspace_skus=["Standard_D2s_v4"], + research_user_ip_addresses=["0.0.0.0"], # noqa: S104 + software_packages=SoftwarePackageCategory.ANY, + ) + assert 
sre_config.data_provider_ip_addresses[0] == "0.0.0.0/32" + + def test_constructor_defaults(self, remote_desktop_config): + sre_config = ConfigSectionSRE(index=0) + assert sre_config.databases == [] + assert sre_config.data_provider_ip_addresses == [] + assert sre_config.remote_desktop == remote_desktop_config + assert sre_config.workspace_skus == [] + assert sre_config.research_user_ip_addresses == [] + assert sre_config.software_packages == SoftwarePackageCategory.NONE + + def test_all_databases_must_be_unique(self): + with pytest.raises(ValueError) as exc: + ConfigSectionSRE( + index=0, + databases=[DatabaseSystem.POSTGRESQL, DatabaseSystem.POSTGRESQL], + ) + assert "all databases must be unique" in exc + + def test_update(self): + sre_config = ConfigSectionSRE(index=0) + assert sre_config.databases == [] + assert sre_config.data_provider_ip_addresses == [] + assert sre_config.workspace_skus == [] + assert sre_config.research_user_ip_addresses == [] + assert sre_config.software_packages == SoftwarePackageCategory.NONE + sre_config.update( + data_provider_ip_addresses=["0.0.0.0"], # noqa: S104 + databases=[DatabaseSystem.MICROSOFT_SQL_SERVER], + workspace_skus=["Standard_D8s_v4"], + software_packages=SoftwarePackageCategory.ANY, + user_ip_addresses=["0.0.0.0"], # noqa: S104 + ) + assert sre_config.databases == [DatabaseSystem.MICROSOFT_SQL_SERVER] + assert sre_config.data_provider_ip_addresses == ["0.0.0.0/32"] + assert sre_config.workspace_skus == ["Standard_D8s_v4"] + assert sre_config.research_user_ip_addresses == ["0.0.0.0/32"] + assert sre_config.software_packages == SoftwarePackageCategory.ANY + + +class TestConfigSectionTags: + def test_constructor(self): + tags_config = ConfigSectionTags(deployment="Test Deployment") + assert tags_config.deployment == "Test Deployment" + assert tags_config.deployed_by == "Python" + assert tags_config.project == "Data Safe Haven" + assert tags_config.version == __version__ + + def test_from_context(self, context): + 
tags_config = ConfigSectionTags.from_context(context) + assert tags_config.deployment == "Acme Deployment" + assert tags_config.deployed_by == "Python" + assert tags_config.project == "Data Safe Haven" + assert tags_config.version == __version__ From e8e2e5b3eb6acee52eafccc379abd4a6f50fb8e8 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 22 Nov 2023 11:05:30 +0000 Subject: [PATCH 19/65] Add basic tests for Config --- data_safe_haven/config/config.py | 16 ++++-- tests_/config/test_config.py | 92 +++++++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 552395a498..3e68f9205d 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -5,7 +5,7 @@ from typing import ClassVar import yaml -from pydantic import BaseModel, Field, ValidationError, computed_field, field_validator +from pydantic import BaseModel, Field, ValidationError, field_validator from yaml import YAMLError from data_safe_haven import __version__ @@ -251,12 +251,22 @@ class Config(BaseModel, validate_assignment=True): pulumi: ConfigSectionPulumi | None = None shm: ConfigSectionSHM | None = None tags: ConfigSectionTags | None = None - sres: dict[str, ConfigSectionSRE] | None = None + sres: dict[str, ConfigSectionSRE] = Field( + default_factory=dict[str, ConfigSectionSRE] + ) - @computed_field + @property def work_directory(self) -> str: return self.context.work_directory + def is_complete(self, *, require_sres: bool) -> bool: + if require_sres: + if not self.sres: + return False + if not all((self.azure, self.pulumi, self.shm, self.tags)): + return False + return True + def read_stack(self, name: str, path: pathlib.Path) -> None: """Add a Pulumi stack file to config""" with open(path, encoding="utf-8") as f_stack: diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 2b55b8b827..5542735cc1 100644 --- a/tests_/config/test_config.py +++ 
b/tests_/config/test_config.py @@ -3,6 +3,7 @@ from pytest import fixture from data_safe_haven.config.config import ( + Config, ConfigSectionAzure, ConfigSectionPulumi, ConfigSectionSHM, @@ -14,6 +15,15 @@ from data_safe_haven.version import __version__ +@fixture +def azure_config(context): + return ConfigSectionAzure.from_context( + context=context, + subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + ) + + class TestConfigSectionAzure: def test_constructor(self): ConfigSectionAzure( @@ -32,6 +42,11 @@ def test_from_context(self, context): assert azure_config.location == context.location +@fixture +def pulumi_config(): + return ConfigSectionPulumi(encryption_key_version="lorem") + + class TestConfigSectionPulumi: def test_constructor_defaults(self): pulumi_config = ConfigSectionPulumi(encryption_key_version="lorem") @@ -41,13 +56,13 @@ def test_constructor_defaults(self): @fixture -def shm_config(): - return ConfigSectionSHM( +def shm_config(context): + return ConfigSectionSHM.from_context( + context=context, aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", admin_email_address="admin@example.com", admin_ip_addresses=["0.0.0.0"], # noqa: S104 fqdn="shm.acme.com", - name="ACME SHM", timezone="UTC", ) @@ -98,16 +113,16 @@ def test_constructor(self): def test_constructor_defaults(self): remote_desktop_config = ConfigSubsectionRemoteDesktopOpts() assert not all( - [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + (remote_desktop_config.allow_copy, remote_desktop_config.allow_paste) ) def test_update(self, remote_desktop_config): assert not all( - [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + (remote_desktop_config.allow_copy, remote_desktop_config.allow_paste) ) remote_desktop_config.update(allow_copy=True, allow_paste=True) assert all( - [remote_desktop_config.allow_copy, remote_desktop_config.allow_paste] + (remote_desktop_config.allow_copy, 
remote_desktop_config.allow_paste) ) @@ -162,6 +177,11 @@ def test_update(self): assert sre_config.software_packages == SoftwarePackageCategory.ANY +@fixture +def tags_config(context): + return ConfigSectionTags.from_context(context) + + class TestConfigSectionTags: def test_constructor(self): tags_config = ConfigSectionTags(deployment="Test Deployment") @@ -176,3 +196,63 @@ def test_from_context(self, context): assert tags_config.deployed_by == "Python" assert tags_config.project == "Data Safe Haven" assert tags_config.version == __version__ + + +@fixture +def config_no_sres(context, azure_config, pulumi_config, shm_config, tags_config): + return Config( + context=context, + azure=azure_config, + pulumi=pulumi_config, + shm=shm_config, + tags=tags_config + ) + + +@fixture +def config_sres(context, azure_config, pulumi_config, shm_config, tags_config): + sre_config_1 = ConfigSectionSRE(index=0) + sre_config_2 = ConfigSectionSRE(index=1) + return Config( + context=context, + azure=azure_config, + pulumi=pulumi_config, + shm=shm_config, + sres={ + "sre1": sre_config_1, + "sre2": sre_config_2, + }, + tags=tags_config + ) + + +class TestConfig: + def test_constructor_defaults(self, context): + config = Config(context=context) + assert config.context == context + assert not any( + (config.azure, config.pulumi, config.shm, config.tags, config.sres) + ) + + @pytest.mark.parametrize("require_sres", [False, True]) + def test_is_complete_bare(self, context, require_sres): + config = Config(context=context) + assert config.is_complete(require_sres=require_sres) is False + + def test_constructor(self, context, azure_config, pulumi_config, shm_config, tags_config): + config = Config( + context=context, + azure=azure_config, + pulumi=pulumi_config, + shm=shm_config, + tags=tags_config + ) + assert not config.sres + + @pytest.mark.parametrize("require_sres,expected", [(False, True), (True, False)]) + def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): + 
assert config_no_sres.is_complete(require_sres=require_sres) is expected + + @pytest.mark.parametrize("require_sres", [False, True]) + def test_is_complete_sres(self, config_sres, require_sres): + assert config_sres.is_complete(require_sres=require_sres) From 8f98a9e8087c33c1935c2f5357a49c8c13bfca74 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Thu, 23 Nov 2023 13:37:12 +0000 Subject: [PATCH 20/65] Add to_yaml method --- data_safe_haven/config/config.py | 24 ++++++++++-- data_safe_haven/utility/enums.py | 4 +- tests_/config/test_config.py | 63 ++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 4 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 3e68f9205d..3fa624c0e7 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -5,7 +5,14 @@ from typing import ClassVar import yaml -from pydantic import BaseModel, Field, ValidationError, field_validator +from pydantic import ( + BaseModel, + Field, + FieldSerializationInfo, + ValidationError, + field_serializer, + field_validator, +) from yaml import YAMLError from data_safe_haven import __version__ @@ -55,7 +62,7 @@ def from_context( class ConfigSectionPulumi(BaseModel, validate_assignment=True): encryption_key_name: str = "pulumi-encryption-key" encryption_key_version: str - stacks: dict[str, str] = Field(default_factory=dict) + stacks: dict[str, str] = Field(default_factory=dict[str, str]) storage_container_name: str = "pulumi" @@ -183,6 +190,14 @@ def all_databases_must_be_unique( raise ValueError(msg) return v + @field_serializer("software_packages") + def software_packages_serializer( + self, + packages: SoftwarePackageCategory, + info: FieldSerializationInfo, # noqa: ARG002 + ) -> str: + return packages.value + def update( self, *, @@ -319,10 +334,13 @@ def from_remote(cls, context: Context) -> Config: ) return Config.from_yaml(config_yaml) + def to_yaml(self) -> str: + return yaml.dump(self.model_dump(), indent=2) + 
def upload(self) -> None: """Upload config to Azure storage""" self.azure_api.upload_blob( - yaml.dump(self.model_dump, indent=2), + self.to_yaml(), self.context.config_filename, self.context.resource_group_name, self.context.storage_account_name, diff --git a/data_safe_haven/utility/enums.py b/data_safe_haven/utility/enums.py index 06fc204392..0e1d15ce4a 100644 --- a/data_safe_haven/utility/enums.py +++ b/data_safe_haven/utility/enums.py @@ -1,11 +1,13 @@ -from enum import Enum +from enum import UNIQUE, Enum, verify +@verify(UNIQUE) class DatabaseSystem(str, Enum): MICROSOFT_SQL_SERVER = "mssql" POSTGRESQL = "postgresql" +@verify(UNIQUE) class SoftwarePackageCategory(str, Enum): ANY = "any" PRE_APPROVED = "pre-approved" diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 5542735cc1..9da08b8cc9 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -226,6 +226,57 @@ def config_sres(context, azure_config, pulumi_config, shm_config, tags_config): ) +@fixture +def config_yaml(): + return """azure: + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd +context: + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + name: Acme Deployment + subscription_name: Data Safe Haven (Acme) +pulumi: + encryption_key_name: pulumi-encryption-key + encryption_key_version: lorem + stacks: {} + storage_container_name: pulumi +shm: + aad_tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + admin_email_address: admin@example.com + admin_ip_addresses: + - 0.0.0.0/32 + fqdn: shm.acme.com + name: acmedeployment + timezone: UTC +sres: + sre1: + data_provider_ip_addresses: [] + databases: [] + index: 0 + remote_desktop: + allow_copy: false + allow_paste: false + research_user_ip_addresses: [] + software_packages: none + workspace_skus: [] + sre2: + data_provider_ip_addresses: [] + databases: [] + 
index: 1 + remote_desktop: + allow_copy: false + allow_paste: false + research_user_ip_addresses: [] + software_packages: none + workspace_skus: [] +tags: + deployment: Acme Deployment +""" + + class TestConfig: def test_constructor_defaults(self, context): config = Config(context=context) @@ -256,3 +307,15 @@ def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): @pytest.mark.parametrize("require_sres", [False, True]) def test_is_complete_sres(self, config_sres, require_sres): assert config_sres.is_complete(require_sres=require_sres) + + def test_work_directory(self, config_sres): + config = config_sres + assert config.work_directory == config.context.work_directory + + def test_to_yaml(self, config_sres, config_yaml): + assert config_sres.to_yaml() == config_yaml + + def test_from_yaml(self, config_sres, config_yaml): + config = Config.from_yaml(config_yaml) + assert config == config_sres + assert isinstance(config.sres["sre1"].software_packages, SoftwarePackageCategory) From bde5fd4fe17f20609ffe0566dc23b35a31cfb51b Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Thu, 23 Nov 2023 13:59:13 +0000 Subject: [PATCH 21/65] Add upload and from remote tests --- data_safe_haven/config/config.py | 3 ++- tests_/config/test_config.py | 33 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 3fa624c0e7..76ff4a4f61 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -339,7 +339,8 @@ def to_yaml(self) -> str: def upload(self) -> None: """Upload config to Azure storage""" - self.azure_api.upload_blob( + azure_api = AzureApi(subscription_name=self.context.subscription_name) + azure_api.upload_blob( self.to_yaml(), self.context.config_filename, self.context.resource_group_name, diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 9da08b8cc9..b331f43cc7 100644 --- 
a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -11,6 +11,7 @@ ConfigSectionTags, ConfigSubsectionRemoteDesktopOpts, ) +from data_safe_haven.external import AzureApi from data_safe_haven.utility.enums import DatabaseSystem, SoftwarePackageCategory from data_safe_haven.version import __version__ @@ -319,3 +320,35 @@ def test_from_yaml(self, config_sres, config_yaml): config = Config.from_yaml(config_yaml) assert config == config_sres assert isinstance(config.sres["sre1"].software_packages, SoftwarePackageCategory) + + def test_upload(self, config_sres, monkeypatch): + def mock_upload_blob( + self, + blob_data: bytes | str, + blob_name: str, + resource_group_name: str, + storage_account_name: str, + storage_container_name: str, + ): + pass + + monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) + config_sres.upload() + + def test_from_remote(self, context, config_sres, config_yaml, monkeypatch): + def mock_download_blob( + self, + blob_name: str, + resource_group_name: str, + storage_account_name: str, + storage_container_name: str, + ): + assert blob_name == context.config_filename + assert resource_group_name == context.resource_group_name + assert storage_account_name == context.storage_account_name + assert storage_container_name == context.storage_container_name + return config_yaml + + monkeypatch.setattr(AzureApi, "download_blob", mock_download_blob) + config = Config.from_remote(context) + assert config == config_sres From e908111b16c9d92900c7a133bf81a7dcbbf386d2 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Thu, 23 Nov 2023 14:01:34 +0000 Subject: [PATCH 22/65] Run lint:fmt --- tests_/config/test_config.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index b331f43cc7..426f325585 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -206,7 +206,7 @@ def config_no_sres(context, azure_config, 
pulumi_config, shm_config, tags_config azure=azure_config, pulumi=pulumi_config, shm=shm_config, - tags=tags_config + tags=tags_config, ) @@ -223,7 +223,7 @@ def config_sres(context, azure_config, pulumi_config, shm_config, tags_config): "sre1": sre_config_1, "sre2": sre_config_2, }, - tags=tags_config + tags=tags_config, ) @@ -291,13 +291,15 @@ def test_is_complete_bare(self, context, require_sres): config = Config(context=context) assert config.is_complete(require_sres=require_sres) is False - def test_constructor(self, context, azure_config, pulumi_config, shm_config, tags_config): + def test_constructor( + self, context, azure_config, pulumi_config, shm_config, tags_config + ): config = Config( context=context, azure=azure_config, pulumi=pulumi_config, shm=shm_config, - tags=tags_config + tags=tags_config, ) assert not config.sres @@ -319,16 +321,18 @@ def test_to_yaml(self, config_sres, config_yaml): def test_from_yaml(self, config_sres, config_yaml): config = Config.from_yaml(config_yaml) assert config == config_sres - assert isinstance(config.sres["sre1"].software_packages, SoftwarePackageCategory) + assert isinstance( + config.sres["sre1"].software_packages, SoftwarePackageCategory + ) def test_upload(self, config_sres, monkeypatch): def mock_upload_blob( - self, - blob_data: bytes | str, - blob_name: str, - resource_group_name: str, - storage_account_name: str, - storage_container_name: str, + self, # noqa: ARG001 + blob_data: bytes | str, # noqa: ARG001 + blob_name: str, # noqa: ARG001 + resource_group_name: str, # noqa: ARG001 + storage_account_name: str, # noqa: ARG001 + storage_container_name: str, # noqa: ARG001 ): pass @@ -337,7 +341,7 @@ def mock_upload_blob( def test_from_remote(self, context, config_sres, config_yaml, monkeypatch): def mock_download_blob( - self, + self, # noqa: ARG001 blob_name: str, resource_group_name: str, storage_account_name: str, From e44acd6c2c8e8dd5bc5cf6f27fc1ef61b427fa08 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: 
Thu, 23 Nov 2023 14:08:41 +0000 Subject: [PATCH 23/65] Tidy Config --- data_safe_haven/config/config.py | 36 +++++++++++----------- tests_/config/test_config.py | 52 ++++++++++++++++---------------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 76ff4a4f61..6e1c222dd3 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -1,7 +1,7 @@ """Configuration file backed by blob storage""" from __future__ import annotations -import pathlib +from pathlib import Path from typing import ClassVar import yaml @@ -282,28 +282,34 @@ def is_complete(self, *, require_sres: bool) -> bool: return False return True - def read_stack(self, name: str, path: pathlib.Path) -> None: - """Add a Pulumi stack file to config""" - with open(path, encoding="utf-8") as f_stack: - pulumi_cfg = f_stack.read() - self.pulumi.stacks[name] = b64encode(pulumi_cfg) + def sre(self, name: str) -> ConfigSectionSRE: + """Return the config entry for this SRE creating it if it does not exist""" + if name not in self.sres.keys(): + highest_index = max(0 + sre.index for sre in self.sres.values()) + self.sres[name].index = highest_index + 1 + return self.sres[name] def remove_sre(self, name: str) -> None: """Remove SRE config section by name""" if name in self.sres.keys(): del self.sres[name] + def add_stack(self, name: str, path: Path) -> None: + """Add a Pulumi stack file to config""" + with open(path, encoding="utf-8") as f_stack: + pulumi_cfg = f_stack.read() + self.pulumi.stacks[name] = b64encode(pulumi_cfg) + def remove_stack(self, name: str) -> None: """Remove Pulumi stack section by name""" if name in self.pulumi.stacks.keys(): del self.pulumi.stacks[name] - def sre(self, name: str) -> ConfigSectionSRE: - """Return the config entry for this SRE creating it if it does not exist""" - if name not in self.sres.keys(): - highest_index = max([0] + [sre.index for sre in self.sres.values()]) 
- self.sres[name].index = highest_index + 1 - return self.sres[name] + def write_stack(self, name: str, path: Path) -> None: + """Write a Pulumi stack file from config""" + pulumi_cfg = b64decode(self.pulumi.stacks[name]) + with open(path, "w", encoding="utf-8") as f_stack: + f_stack.write(pulumi_cfg) @classmethod def from_yaml(cls, config_yaml: str) -> Config: @@ -347,9 +353,3 @@ def upload(self) -> None: self.context.storage_account_name, self.context.storage_container_name, ) - - def write_stack(self, name: str, path: pathlib.Path) -> None: - """Write a Pulumi stack file from config""" - pulumi_cfg = b64decode(self.pulumi.stacks[name]) - with open(path, "w", encoding="utf-8") as f_stack: - f_stack.write(pulumi_cfg) diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 426f325585..452aa24c04 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -286,11 +286,6 @@ def test_constructor_defaults(self, context): (config.azure, config.pulumi, config.shm, config.tags, config.sres) ) - @pytest.mark.parametrize("require_sres", [False, True]) - def test_is_complete_bare(self, context, require_sres): - config = Config(context=context) - assert config.is_complete(require_sres=require_sres) is False - def test_constructor( self, context, azure_config, pulumi_config, shm_config, tags_config ): @@ -303,6 +298,15 @@ def test_constructor( ) assert not config.sres + def test_work_directory(self, config_sres): + config = config_sres + assert config.work_directory == config.context.work_directory + + @pytest.mark.parametrize("require_sres", [False, True]) + def test_is_complete_bare(self, context, require_sres): + config = Config(context=context) + assert config.is_complete(require_sres=require_sres) is False + @pytest.mark.parametrize("require_sres,expected", [(False, True), (True, False)]) def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): assert config_no_sres.is_complete(require_sres=require_sres) is 
expected @@ -311,13 +315,6 @@ def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): def test_is_complete_sres(self, config_sres, require_sres): assert config_sres.is_complete(require_sres=require_sres) - def test_work_directory(self, config_sres): - config = config_sres - assert config.work_directory == config.context.work_directory - - def test_to_yaml(self, config_sres, config_yaml): - assert config_sres.to_yaml() == config_yaml - def test_from_yaml(self, config_sres, config_yaml): config = Config.from_yaml(config_yaml) assert config == config_sres @@ -325,20 +322,6 @@ def test_from_yaml(self, config_sres, config_yaml): config.sres["sre1"].software_packages, SoftwarePackageCategory ) - def test_upload(self, config_sres, monkeypatch): - def mock_upload_blob( - self, # noqa: ARG001 - blob_data: bytes | str, # noqa: ARG001 - blob_name: str, # noqa: ARG001 - resource_group_name: str, # noqa: ARG001 - storage_account_name: str, # noqa: ARG001 - storage_container_name: str, # noqa: ARG001 - ): - pass - - monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) - config_sres.upload() - def test_from_remote(self, context, config_sres, config_yaml, monkeypatch): def mock_download_blob( self, # noqa: ARG001 @@ -356,3 +339,20 @@ def mock_download_blob( monkeypatch.setattr(AzureApi, "download_blob", mock_download_blob) config = Config.from_remote(context) assert config == config_sres + + def test_to_yaml(self, config_sres, config_yaml): + assert config_sres.to_yaml() == config_yaml + + def test_upload(self, config_sres, monkeypatch): + def mock_upload_blob( + self, # noqa: ARG001 + blob_data: bytes | str, # noqa: ARG001 + blob_name: str, # noqa: ARG001 + resource_group_name: str, # noqa: ARG001 + storage_account_name: str, # noqa: ARG001 + storage_container_name: str, # noqa: ARG001 + ): + pass + + monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) + config_sres.upload() From 6344a920745a06b2558f5f713a504600246306f8 Mon Sep 17 00:00:00 
2001 From: Jim Madge Date: Thu, 23 Nov 2023 14:32:10 +0000 Subject: [PATCH 24/65] Correct sre method, add tests --- data_safe_haven/config/config.py | 2 +- tests_/config/test_config.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 6e1c222dd3..ebaace04e0 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -286,7 +286,7 @@ def sre(self, name: str) -> ConfigSectionSRE: """Return the config entry for this SRE creating it if it does not exist""" if name not in self.sres.keys(): highest_index = max(0 + sre.index for sre in self.sres.values()) - self.sres[name].index = highest_index + 1 + self.sres[name] = ConfigSectionSRE(index=highest_index+1) return self.sres[name] def remove_sre(self, name: str) -> None: diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 452aa24c04..7edbc44960 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -315,6 +315,27 @@ def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): def test_is_complete_sres(self, config_sres, require_sres): assert config_sres.is_complete(require_sres=require_sres) + def test_sre(self, config_sres): + sre1, sre2 = config_sres.sre("sre1"), config_sres.sre("sre2") + assert sre1.index == 0 + assert sre2.index == 1 + assert sre1 != sre2 + + def test_sre_create(self, config_sres): + sre1 = config_sres.sre("sre1") + sre3 = config_sres.sre("sre3") + assert isinstance(sre3, ConfigSectionSRE) + assert sre3.index == 2 + assert sre3 != sre1 + assert len(config_sres.sres) == 3 + + def test_remove_sre(self, config_sres): + assert len(config_sres.sres) == 2 + config_sres.remove_sre("sre1") + assert len(config_sres.sres) == 1 + assert "sre2" in config_sres.sres.keys() + assert "sre1" not in config_sres.sres.keys() + def test_from_yaml(self, config_sres, config_yaml): config = Config.from_yaml(config_yaml) 
assert config == config_sres From bf9435eea3ef534b94b3b6c79a5b63b5b5e63dd1 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 24 Nov 2023 14:18:22 +0000 Subject: [PATCH 25/65] Dynamically derive config items from context This simplifies the construction of a Config object and makes the serialised configuration less verbose and redundant. --- data_safe_haven/config/config.py | 64 ++++++++-------------- data_safe_haven/config/context_settings.py | 8 ++- tests_/config/test_config.py | 57 ++++--------------- 3 files changed, 42 insertions(+), 87 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index ebaace04e0..dafd205e26 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -2,7 +2,7 @@ from __future__ import annotations from pathlib import Path -from typing import ClassVar +from typing import Any, ClassVar import yaml from pydantic import ( @@ -42,20 +42,14 @@ class ConfigSectionAzure(BaseModel, validate_assignment=True): - admin_group_id: Guid - location: AzureLocation + admin_group_id: Guid = Field(exclude=True) + location: AzureLocation = Field(exclude=True) subscription_id: Guid tenant_id: Guid - @classmethod - def from_context( - cls, context: Context, subscription_id: Guid, tenant_id: Guid - ) -> ConfigSectionAzure: - return ConfigSectionAzure( - admin_group_id=context.admin_group_id, - location=context.location, - subscription_id=subscription_id, - tenant_id=tenant_id, + def __init__(self, context: Context, **kwargs: dict[Any, Any]): + super().__init__( + admin_group_id=context.admin_group_id, location=context.location, **kwargs ) @@ -71,27 +65,11 @@ class ConfigSectionSHM(BaseModel, validate_assignment=True): admin_email_address: EmailAdress admin_ip_addresses: list[IpAddress] fqdn: str - name: str + name: str = Field(exclude=True) timezone: TimeZone - @classmethod - def from_context( - cls, - context: Context, - aad_tenant_id: Guid, - admin_email_address: EmailAdress, - 
admin_ip_addresses: list[IpAddress], - fqdn: str, - timezone: TimeZone, - ) -> ConfigSectionSHM: - return ConfigSectionSHM( - aad_tenant_id=aad_tenant_id, - admin_email_address=admin_email_address, - admin_ip_addresses=admin_ip_addresses, - fqdn=fqdn, - name=context.shm_name, - timezone=timezone, - ) + def __init__(self, context: Context, **kwargs: dict[Any, Any]): + super().__init__(name=context.shm_name, **kwargs) def update( self, @@ -250,25 +228,24 @@ def update( class ConfigSectionTags(BaseModel, validate_assignment=True): - deployment: str + deployment: str = Field(exclude=True) deployed_by: ClassVar[str] = "Python" project: ClassVar[str] = "Data Safe Haven" version: ClassVar[str] = __version__ - @classmethod - def from_context(cls, context: Context) -> ConfigSectionTags: - return ConfigSectionTags(deployment=context.name) + def __init__(self, context: Context, **kwargs: dict[Any, Any]): + super().__init__(deployment=context.name, **kwargs) class Config(BaseModel, validate_assignment=True): azure: ConfigSectionAzure | None = None - context: Context + context: Context = Field(exclude=True) pulumi: ConfigSectionPulumi | None = None shm: ConfigSectionSHM | None = None - tags: ConfigSectionTags | None = None sres: dict[str, ConfigSectionSRE] = Field( default_factory=dict[str, ConfigSectionSRE] ) + tags: ConfigSectionTags | None = Field(exclude=True, default=None) @property def work_directory(self) -> str: @@ -286,7 +263,7 @@ def sre(self, name: str) -> ConfigSectionSRE: """Return the config entry for this SRE creating it if it does not exist""" if name not in self.sres.keys(): highest_index = max(0 + sre.index for sre in self.sres.values()) - self.sres[name] = ConfigSectionSRE(index=highest_index+1) + self.sres[name] = ConfigSectionSRE(index=highest_index + 1) return self.sres[name] def remove_sre(self, name: str) -> None: @@ -312,7 +289,7 @@ def write_stack(self, name: str, path: Path) -> None: f_stack.write(pulumi_cfg) @classmethod - def from_yaml(cls, 
config_yaml: str) -> Config: + def from_yaml(cls, context: Context, config_yaml: str) -> Config: try: config_dict = yaml.safe_load(config_yaml) except YAMLError as exc: @@ -323,6 +300,13 @@ def from_yaml(cls, config_yaml: str) -> Config: msg = "Unable to parse configuration as a dict." raise DataSafeHavenConfigError(msg) + # Add context for constructors that require it + # context_dict = context.model_dump() + config_dict["context"] = context + config_dict["tags"] = {} + for section in ["azure", "shm", "tags"]: + config_dict[section]["context"] = context + try: return Config.model_validate(config_dict) except ValidationError as exc: @@ -338,7 +322,7 @@ def from_remote(cls, context: Context) -> Config: context.storage_account_name, context.storage_container_name, ) - return Config.from_yaml(config_yaml) + return Config.from_yaml(context, config_yaml) def to_yaml(self) -> str: return yaml.dump(self.model_dump(), indent=2) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index fb89f448d2..a9f0b5f218 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -56,6 +56,9 @@ def storage_account_name(self) -> str: # maximum of 24 characters allowed return f"shm{self.shm_name[:14]}context" + def to_yaml(self) -> str: + return yaml.dump(self.model_dump(), indent=2) + class ContextSettings(BaseModel, validate_assignment=True): """Load global and local settings from dotfiles with structure like the following @@ -189,6 +192,9 @@ def from_file(cls, config_file_path: Path | None = None) -> ContextSettings: msg = f"Could not find file {config_file_path}.\n{exc}" raise DataSafeHavenConfigError(msg) from exc + def to_yaml(self) -> str: + return yaml.dump(self.model_dump(by_alias=True), indent=2) + def write(self, config_file_path: Path | None = None) -> None: """Write settings to YAML file""" if config_file_path is None: @@ -197,5 +203,5 @@ def write(self, config_file_path: Path | 
None = None) -> None: config_file_path.parent.mkdir(parents=True, exist_ok=True) with open(config_file_path, "w", encoding="utf-8") as f_yaml: - yaml.dump(self.model_dump(by_alias=True), f_yaml, indent=2) + f_yaml.write(self.to_yaml()) self.logger.info(f"Saved context settings to '[green]{config_file_path}[/]'.") diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 7edbc44960..ac7ae7f8ba 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -18,7 +18,7 @@ @fixture def azure_config(context): - return ConfigSectionAzure.from_context( + return ConfigSectionAzure( context=context, subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", @@ -26,16 +26,8 @@ def azure_config(context): class TestConfigSectionAzure: - def test_constructor(self): - ConfigSectionAzure( - admin_group_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - location="uksouth", - subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - ) - - def test_from_context(self, context): - azure_config = ConfigSectionAzure.from_context( + def test_constructor(self, context): + azure_config = ConfigSectionAzure( context=context, subscription_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", @@ -58,7 +50,7 @@ def test_constructor_defaults(self): @fixture def shm_config(context): - return ConfigSectionSHM.from_context( + return ConfigSectionSHM( context=context, aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", admin_email_address="admin@example.com", @@ -69,18 +61,8 @@ def shm_config(context): class TestConfigSectionSHM: - def test_constructor(self): - ConfigSectionSHM( - aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - admin_email_address="admin@example.com", - admin_ip_addresses=["0.0.0.0"], # noqa: S104 - fqdn="shm.acme.com", - name="ACME SHM", - timezone="UTC", - ) - - def test_from_context(self, context): - 
shm_config = ConfigSectionSHM.from_context( + def test_constructor(self, context): + shm_config = ConfigSectionSHM( context=context, aad_tenant_id="d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", admin_email_address="admin@example.com", @@ -180,19 +162,12 @@ def test_update(self): @fixture def tags_config(context): - return ConfigSectionTags.from_context(context) + return ConfigSectionTags(context) class TestConfigSectionTags: - def test_constructor(self): - tags_config = ConfigSectionTags(deployment="Test Deployment") - assert tags_config.deployment == "Test Deployment" - assert tags_config.deployed_by == "Python" - assert tags_config.project == "Data Safe Haven" - assert tags_config.version == __version__ - - def test_from_context(self, context): - tags_config = ConfigSectionTags.from_context(context) + def test_constructor(self, context): + tags_config = ConfigSectionTags(context) assert tags_config.deployment == "Acme Deployment" assert tags_config.deployed_by == "Python" assert tags_config.project == "Data Safe Haven" @@ -230,15 +205,8 @@ def config_sres(context, azure_config, pulumi_config, shm_config, tags_config): @fixture def config_yaml(): return """azure: - admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - location: uksouth subscription_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd -context: - admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - location: uksouth - name: Acme Deployment - subscription_name: Data Safe Haven (Acme) pulumi: encryption_key_name: pulumi-encryption-key encryption_key_version: lorem @@ -250,7 +218,6 @@ def config_yaml(): admin_ip_addresses: - 0.0.0.0/32 fqdn: shm.acme.com - name: acmedeployment timezone: UTC sres: sre1: @@ -273,8 +240,6 @@ def config_yaml(): research_user_ip_addresses: [] software_packages: none workspace_skus: [] -tags: - deployment: Acme Deployment """ @@ -336,8 +301,8 @@ def test_remove_sre(self, config_sres): assert "sre2" in config_sres.sres.keys() assert "sre1" not 
in config_sres.sres.keys() - def test_from_yaml(self, config_sres, config_yaml): - config = Config.from_yaml(config_yaml) + def test_from_yaml(self, config_sres, context, config_yaml): + config = Config.from_yaml(context, config_yaml) assert config == config_sres assert isinstance( config.sres["sre1"].software_packages, SoftwarePackageCategory From c0002f6a021e7e0e3e2d114c8a0376b3ac1d5493 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 24 Nov 2023 14:48:57 +0000 Subject: [PATCH 26/65] Dynamically construct ConfigSectionTags --- data_safe_haven/config/config.py | 31 +++++++++++++++++++------------ tests_/config/test_config.py | 11 ++++------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index dafd205e26..3ccc3c29cc 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -245,10 +245,14 @@ class Config(BaseModel, validate_assignment=True): sres: dict[str, ConfigSectionSRE] = Field( default_factory=dict[str, ConfigSectionSRE] ) - tags: ConfigSectionTags | None = Field(exclude=True, default=None) + tags: ConfigSectionTags = Field(exclude=True) + + def __init__(self, context, **kwargs: dict[Any, Any]): + tags = ConfigSectionTags(context) + super().__init__(context=context, tags=tags, **kwargs) @property - def work_directory(self) -> str: + def work_directory(self) -> Path: return self.context.work_directory def is_complete(self, *, require_sres: bool) -> bool: @@ -273,20 +277,24 @@ def remove_sre(self, name: str) -> None: def add_stack(self, name: str, path: Path) -> None: """Add a Pulumi stack file to config""" - with open(path, encoding="utf-8") as f_stack: - pulumi_cfg = f_stack.read() - self.pulumi.stacks[name] = b64encode(pulumi_cfg) + if self.pulumi: + with open(path, encoding="utf-8") as f_stack: + pulumi_cfg = f_stack.read() + self.pulumi.stacks[name] = b64encode(pulumi_cfg) def remove_stack(self, name: str) -> None: """Remove Pulumi stack 
section by name""" - if name in self.pulumi.stacks.keys(): - del self.pulumi.stacks[name] + if self.pulumi: + stacks = self.pulumi.stacks + if name in stacks.keys(): + del stacks[name] def write_stack(self, name: str, path: Path) -> None: """Write a Pulumi stack file from config""" - pulumi_cfg = b64decode(self.pulumi.stacks[name]) - with open(path, "w", encoding="utf-8") as f_stack: - f_stack.write(pulumi_cfg) + if self.pulumi: + pulumi_cfg = b64decode(self.pulumi.stacks[name]) + with open(path, "w", encoding="utf-8") as f_stack: + f_stack.write(pulumi_cfg) @classmethod def from_yaml(cls, context: Context, config_yaml: str) -> Config: @@ -303,8 +311,7 @@ def from_yaml(cls, context: Context, config_yaml: str) -> Config: # Add context for constructors that require it # context_dict = context.model_dump() config_dict["context"] = context - config_dict["tags"] = {} - for section in ["azure", "shm", "tags"]: + for section in ["azure", "shm"]: config_dict[section]["context"] = context try: diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index ac7ae7f8ba..f112036d48 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -175,18 +175,17 @@ def test_constructor(self, context): @fixture -def config_no_sres(context, azure_config, pulumi_config, shm_config, tags_config): +def config_no_sres(context, azure_config, pulumi_config, shm_config): return Config( context=context, azure=azure_config, pulumi=pulumi_config, shm=shm_config, - tags=tags_config, ) @fixture -def config_sres(context, azure_config, pulumi_config, shm_config, tags_config): +def config_sres(context, azure_config, pulumi_config, shm_config): sre_config_1 = ConfigSectionSRE(index=0) sre_config_2 = ConfigSectionSRE(index=1) return Config( @@ -198,7 +197,6 @@ def config_sres(context, azure_config, pulumi_config, shm_config, tags_config): "sre1": sre_config_1, "sre2": sre_config_2, }, - tags=tags_config, ) @@ -248,18 +246,17 @@ def test_constructor_defaults(self, 
context): config = Config(context=context) assert config.context == context assert not any( - (config.azure, config.pulumi, config.shm, config.tags, config.sres) + (config.azure, config.pulumi, config.shm, config.sres) ) def test_constructor( - self, context, azure_config, pulumi_config, shm_config, tags_config + self, context, azure_config, pulumi_config, shm_config ): config = Config( context=context, azure=azure_config, pulumi=pulumi_config, shm=shm_config, - tags=tags_config, ) assert not config.sres From bd358dc3f1761a5daeebb7e83e7b50863a2033c8 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 24 Nov 2023 15:13:39 +0000 Subject: [PATCH 27/65] Replace tags to_dict with model_dump --- data_safe_haven/config/config.py | 40 ++++++++++--------- data_safe_haven/context/context.py | 2 +- .../infrastructure/stacks/declarative_shm.py | 14 +++---- .../infrastructure/stacks/declarative_sre.py | 18 ++++----- tests_/config/test_config.py | 16 +++++--- 5 files changed, 49 insertions(+), 41 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 3ccc3c29cc..499145d4b5 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -2,7 +2,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, ClassVar +from typing import Any import yaml from pydantic import ( @@ -42,8 +42,8 @@ class ConfigSectionAzure(BaseModel, validate_assignment=True): - admin_group_id: Guid = Field(exclude=True) - location: AzureLocation = Field(exclude=True) + admin_group_id: Guid = Field(..., exclude=True) + location: AzureLocation = Field(..., exclude=True) subscription_id: Guid tenant_id: Guid @@ -56,7 +56,7 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]): class ConfigSectionPulumi(BaseModel, validate_assignment=True): encryption_key_name: str = "pulumi-encryption-key" encryption_key_version: str - stacks: dict[str, str] = Field(default_factory=dict[str, str]) + stacks: dict[str, 
str] = Field(..., default_factory=dict[str, str]) storage_container_name: str = "pulumi" @@ -65,7 +65,7 @@ class ConfigSectionSHM(BaseModel, validate_assignment=True): admin_email_address: EmailAdress admin_ip_addresses: list[IpAddress] fqdn: str - name: str = Field(exclude=True) + name: str = Field(..., exclude=True) timezone: TimeZone def __init__(self, context: Context, **kwargs: dict[Any, Any]): @@ -148,14 +148,18 @@ def update( class ConfigSectionSRE(BaseModel, validate_assignment=True): - databases: list[DatabaseSystem] = Field(default_factory=list[DatabaseSystem]) - data_provider_ip_addresses: list[IpAddress] = Field(default_factory=list[IpAddress]) - index: int = Field(ge=0) + databases: list[DatabaseSystem] = Field(..., default_factory=list[DatabaseSystem]) + data_provider_ip_addresses: list[IpAddress] = Field( + ..., default_factory=list[IpAddress] + ) + index: int = Field(..., ge=0) remote_desktop: ConfigSubsectionRemoteDesktopOpts = Field( - default_factory=ConfigSubsectionRemoteDesktopOpts + ..., default_factory=ConfigSubsectionRemoteDesktopOpts + ) + workspace_skus: list[AzureVmSku] = Field(..., default_factory=list[AzureVmSku]) + research_user_ip_addresses: list[IpAddress] = Field( + ..., default_factory=list[IpAddress] ) - workspace_skus: list[AzureVmSku] = Field(default_factory=list[AzureVmSku]) - research_user_ip_addresses: list[IpAddress] = Field(default_factory=list[IpAddress]) software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE @field_validator("databases") @@ -228,10 +232,10 @@ def update( class ConfigSectionTags(BaseModel, validate_assignment=True): - deployment: str = Field(exclude=True) - deployed_by: ClassVar[str] = "Python" - project: ClassVar[str] = "Data Safe Haven" - version: ClassVar[str] = __version__ + deployment: str + deployed_by: str = "Python" + project: str = "Data Safe Haven" + version: str = __version__ def __init__(self, context: Context, **kwargs: dict[Any, Any]): 
super().__init__(deployment=context.name, **kwargs) @@ -239,13 +243,13 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]): class Config(BaseModel, validate_assignment=True): azure: ConfigSectionAzure | None = None - context: Context = Field(exclude=True) + context: Context = Field(..., exclude=True) pulumi: ConfigSectionPulumi | None = None shm: ConfigSectionSHM | None = None sres: dict[str, ConfigSectionSRE] = Field( - default_factory=dict[str, ConfigSectionSRE] + ..., default_factory=dict[str, ConfigSectionSRE] ) - tags: ConfigSectionTags = Field(exclude=True) + tags: ConfigSectionTags = Field(..., exclude=True) def __init__(self, context, **kwargs: dict[Any, Any]): tags = ConfigSectionTags(context) diff --git a/data_safe_haven/context/context.py b/data_safe_haven/context/context.py index 234dc6c4c0..28175939d5 100644 --- a/data_safe_haven/context/context.py +++ b/data_safe_haven/context/context.py @@ -9,7 +9,7 @@ class Context: def __init__(self, config: Config) -> None: self.azure_api_: AzureApi | None = None self.config = config - self.tags = {"component": "context"} | self.config.tags.to_dict() + self.tags = {"component": "context"} | self.config.tags.model_dump() @property def azure_api(self) -> AzureApi: diff --git a/data_safe_haven/infrastructure/stacks/declarative_shm.py b/data_safe_haven/infrastructure/stacks/declarative_shm.py index 777f042b27..acceb69a59 100644 --- a/data_safe_haven/infrastructure/stacks/declarative_shm.py +++ b/data_safe_haven/infrastructure/stacks/declarative_shm.py @@ -40,7 +40,7 @@ def run(self) -> None: "verification-azuread-custom-domain" ), ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy firewall and routing @@ -57,7 +57,7 @@ def run(self) -> None: subnet_identity_servers=networking.subnet_identity_servers, subnet_update_servers=networking.subnet_update_servers, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy firewall and routing @@ -69,7 +69,7 @@ 
def run(self) -> None: resource_group_name=networking.resource_group_name, subnet=networking.subnet_bastion, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy data storage @@ -83,7 +83,7 @@ def run(self) -> None: pulumi_opts=self.pulumi_opts, tenant_id=self.cfg.azure.tenant_id, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy automated monitoring @@ -97,7 +97,7 @@ def run(self) -> None: subnet_monitoring=networking.subnet_monitoring, timezone=self.cfg.shm.timezone, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy update servers @@ -113,7 +113,7 @@ def run(self) -> None: virtual_network_name=networking.virtual_network.name, virtual_network_resource_group_name=networking.resource_group_name, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy domain controllers @@ -137,7 +137,7 @@ def run(self) -> None: virtual_network_name=networking.virtual_network.name, virtual_network_resource_group_name=networking.resource_group_name, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Export values for later use diff --git a/data_safe_haven/infrastructure/stacks/declarative_sre.py b/data_safe_haven/infrastructure/stacks/declarative_sre.py index 245339bf29..f4058b160c 100644 --- a/data_safe_haven/infrastructure/stacks/declarative_sre.py +++ b/data_safe_haven/infrastructure/stacks/declarative_sre.py @@ -91,7 +91,7 @@ def run(self) -> None: ), sre_index=self.cfg.sres[self.sre_name].index, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy networking @@ -129,7 +129,7 @@ def run(self) -> None: self.sre_name ].research_user_ip_addresses, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy automated monitoring @@ -150,7 +150,7 @@ def run(self) -> None: sre_index=self.cfg.sres[self.sre_name].index, timezone=self.cfg.shm.timezone, ), - tags=self.cfg.tags.to_dict(), + 
tags=self.cfg.tags.model_dump(), ) # Deploy data storage @@ -176,7 +176,7 @@ def run(self) -> None: subscription_name=self.cfg.subscription_name, tenant_id=self.cfg.azure.tenant_id, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy frontend application gateway @@ -191,7 +191,7 @@ def run(self) -> None: subnet_guacamole_containers=networking.subnet_guacamole_containers, sre_fqdn=networking.sre_fqdn, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy containerised remote desktop gateway @@ -222,7 +222,7 @@ def run(self) -> None: virtual_network_resource_group_name=networking.resource_group.name, virtual_network=networking.virtual_network, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy workspaces @@ -261,7 +261,7 @@ def run(self) -> None: virtual_network=networking.virtual_network, vm_details=list(enumerate(self.cfg.sres[self.sre_name].workspace_skus)), ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy containerised user services @@ -300,7 +300,7 @@ def run(self) -> None: virtual_network=networking.virtual_network, virtual_network_resource_group_name=networking.resource_group.name, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Deploy backup service @@ -312,7 +312,7 @@ def run(self) -> None: storage_account_data_private_sensitive_id=data.storage_account_data_private_sensitive_id, storage_account_data_private_sensitive_name=data.storage_account_data_private_sensitive_name, ), - tags=self.cfg.tags.to_dict(), + tags=self.cfg.tags.model_dump(), ) # Export values for later use diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index f112036d48..e58c213e25 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -173,6 +173,14 @@ def test_constructor(self, context): assert tags_config.project == "Data Safe Haven" assert tags_config.version == __version__ + def 
test_model_dump(self, tags_config): + tags_dict = tags_config.model_dump() + assert all( + ("deployment", "deployed_by", "project", "version" in tags_dict.keys()) + ) + assert tags_dict["deployment"] == "Acme Deployment" + assert tags_dict["version"] == __version__ + @fixture def config_no_sres(context, azure_config, pulumi_config, shm_config): @@ -245,13 +253,9 @@ class TestConfig: def test_constructor_defaults(self, context): config = Config(context=context) assert config.context == context - assert not any( - (config.azure, config.pulumi, config.shm, config.sres) - ) + assert not any((config.azure, config.pulumi, config.shm, config.sres)) - def test_constructor( - self, context, azure_config, pulumi_config, shm_config - ): + def test_constructor(self, context, azure_config, pulumi_config, shm_config): config = Config( context=context, azure=azure_config, From e6eb86ff6c81e09c4550b82e77c7351c19e436d5 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 27 Nov 2023 13:26:30 +0000 Subject: [PATCH 28/65] Add missing type annotation --- data_safe_haven/config/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 499145d4b5..b6d21f7519 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -251,7 +251,7 @@ class Config(BaseModel, validate_assignment=True): ) tags: ConfigSectionTags = Field(..., exclude=True) - def __init__(self, context, **kwargs: dict[Any, Any]): + def __init__(self, context: Context, **kwargs: dict[Any, Any]): tags = ConfigSectionTags(context) super().__init__(context=context, tags=tags, **kwargs) From 09e8427a4abf920f5ca9ac9cc3b73c6d88a9f2e8 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 27 Nov 2023 13:49:32 +0000 Subject: [PATCH 29/65] Correct config attributes --- .../administration/users/active_directory_users.py | 2 +- .../administration/users/guacamole_users.py | 2 +- data_safe_haven/commands/admin_add_users.py | 
2 +- data_safe_haven/commands/admin_list_users.py | 2 +- data_safe_haven/commands/admin_register_users.py | 2 +- data_safe_haven/commands/admin_remove_users.py | 2 +- data_safe_haven/commands/admin_unregister_users.py | 2 +- data_safe_haven/commands/deploy_shm.py | 4 ++-- data_safe_haven/commands/deploy_sre.py | 10 ++++++---- data_safe_haven/context/context.py | 2 +- data_safe_haven/infrastructure/stack_manager.py | 4 ++-- .../infrastructure/stacks/declarative_shm.py | 2 +- .../infrastructure/stacks/declarative_sre.py | 4 ++-- 13 files changed, 21 insertions(+), 19 deletions(-) diff --git a/data_safe_haven/administration/users/active_directory_users.py b/data_safe_haven/administration/users/active_directory_users.py index 6832c40f71..192cef3549 100644 --- a/data_safe_haven/administration/users/active_directory_users.py +++ b/data_safe_haven/administration/users/active_directory_users.py @@ -24,7 +24,7 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) shm_stack = SHMStackManager(config) - self.azure_api = AzureApi(config.subscription_name) + self.azure_api = AzureApi(config.context.subscription_name) self.logger = LoggingSingleton() self.resource_group_name = shm_stack.output("domain_controllers")[ "resource_group_name" diff --git a/data_safe_haven/administration/users/guacamole_users.py b/data_safe_haven/administration/users/guacamole_users.py index c8791db8fa..6b05a47c54 100644 --- a/data_safe_haven/administration/users/guacamole_users.py +++ b/data_safe_haven/administration/users/guacamole_users.py @@ -18,7 +18,7 @@ def __init__(self, config: Config, sre_name: str, *args: Any, **kwargs: Any): sre_stack.secret("password-user-database-admin"), sre_stack.output("remote_desktop")["connection_db_server_name"], sre_stack.output("remote_desktop")["resource_group_name"], - config.subscription_name, + config.context.subscription_name, ) self.users_: Sequence[ResearchUser] | None = None self.postgres_script_path: pathlib.Path = ( diff --git 
a/data_safe_haven/commands/admin_add_users.py b/data_safe_haven/commands/admin_add_users.py index e209e8e740..c6ca9bb24e 100644 --- a/data_safe_haven/commands/admin_add_users.py +++ b/data_safe_haven/commands/admin_add_users.py @@ -13,7 +13,7 @@ def admin_add_users(csv_path: pathlib.Path) -> None: try: # Load config file config = Config() - shm_name = config.name + shm_name = config.context.name # Load GraphAPI as this may require user-interaction that is not # possible as part of a Pulumi declarative command diff --git a/data_safe_haven/commands/admin_list_users.py b/data_safe_haven/commands/admin_list_users.py index 87033843a5..0e69eede16 100644 --- a/data_safe_haven/commands/admin_list_users.py +++ b/data_safe_haven/commands/admin_list_users.py @@ -11,7 +11,7 @@ def admin_list_users() -> None: try: # Load config file config = Config() - shm_name = config.name + shm_name = config.context.name # Load GraphAPI as this may require user-interaction that is not # possible as part of a Pulumi declarative command diff --git a/data_safe_haven/commands/admin_register_users.py b/data_safe_haven/commands/admin_register_users.py index 9517a38c70..af6b028dc0 100644 --- a/data_safe_haven/commands/admin_register_users.py +++ b/data_safe_haven/commands/admin_register_users.py @@ -20,7 +20,7 @@ def admin_register_users( # Load config file config = Config() - shm_name = config.name + shm_name = config.context.name # Check that SRE option has been provided if not sre_name: diff --git a/data_safe_haven/commands/admin_remove_users.py b/data_safe_haven/commands/admin_remove_users.py index ddfbe6d97c..5f3a464c25 100644 --- a/data_safe_haven/commands/admin_remove_users.py +++ b/data_safe_haven/commands/admin_remove_users.py @@ -13,7 +13,7 @@ def admin_remove_users( try: # Load config file config = Config() - shm_name = config.name + shm_name = config.context.name # Load GraphAPI as this may require user-interaction that is not # possible as part of a Pulumi declarative command diff 
--git a/data_safe_haven/commands/admin_unregister_users.py b/data_safe_haven/commands/admin_unregister_users.py index c3eed321b5..e2fbcb86b0 100644 --- a/data_safe_haven/commands/admin_unregister_users.py +++ b/data_safe_haven/commands/admin_unregister_users.py @@ -20,7 +20,7 @@ def admin_unregister_users( # Load config file config = Config() - shm_name = config.name + shm_name = config.context.name # Check that SRE option has been provided if not sre_name: diff --git a/data_safe_haven/commands/deploy_shm.py b/data_safe_haven/commands/deploy_shm.py index f23d52afae..d6864f3511 100644 --- a/data_safe_haven/commands/deploy_shm.py +++ b/data_safe_haven/commands/deploy_shm.py @@ -62,7 +62,7 @@ def deploy_shm( stack.deploy(force=force) # Add Pulumi infrastructure information to the config file - config.read_stack(stack.stack_name, stack.local_stack_path) + config.add_stack(stack.stack_name, stack.local_stack_path) # Upload config to blob storage config.upload() @@ -75,7 +75,7 @@ def deploy_shm( # Provision SHM with anything that could not be done in Pulumi manager = SHMProvisioningManager( - subscription_name=config.subscription_name, + subscription_name=config.context.subscription_name, stack=stack, ) manager.run() diff --git a/data_safe_haven/commands/deploy_sre.py b/data_safe_haven/commands/deploy_sre.py index 7f95227e17..5fe4b43bb8 100644 --- a/data_safe_haven/commands/deploy_sre.py +++ b/data_safe_haven/commands/deploy_sre.py @@ -31,14 +31,16 @@ def deploy_sre( # Load and validate config file config = Config() config.sre(sre_name).update( - allow_copy=allow_copy, - allow_paste=allow_paste, data_provider_ip_addresses=data_provider_ip_addresses, databases=databases, workspace_skus=workspace_skus, software_packages=software_packages, user_ip_addresses=user_ip_addresses, ) + config.sre(sre_name).remote_desktop.update( + allow_copy=allow_copy, + allow_paste=allow_paste, + ) # Load GraphAPI as this may require user-interaction that is not possible as # part of a Pulumi 
declarative command @@ -150,7 +152,7 @@ def deploy_sre( stack.deploy(force=force) # Add Pulumi infrastructure information to the config file - config.read_stack(stack.stack_name, stack.local_stack_path) + config.add_stack(stack.stack_name, stack.local_stack_path) # Upload config to blob storage config.upload() @@ -160,7 +162,7 @@ def deploy_sre( shm_stack=shm_stack, sre_name=sre_name, sre_stack=stack, - subscription_name=config.subscription_name, + subscription_name=config.context.subscription_name, timezone=config.shm.timezone, ) manager.run() diff --git a/data_safe_haven/context/context.py b/data_safe_haven/context/context.py index 28175939d5..aa50969c1d 100644 --- a/data_safe_haven/context/context.py +++ b/data_safe_haven/context/context.py @@ -20,7 +20,7 @@ def azure_api(self) -> AzureApi: """ if not self.azure_api_: self.azure_api_ = AzureApi( - subscription_name=self.config.subscription_name, + subscription_name=self.config.context.subscription_name, ) return self.azure_api_ diff --git a/data_safe_haven/infrastructure/stack_manager.py b/data_safe_haven/infrastructure/stack_manager.py index f63026ff1b..d8b4cc891c 100644 --- a/data_safe_haven/infrastructure/stack_manager.py +++ b/data_safe_haven/infrastructure/stack_manager.py @@ -37,7 +37,7 @@ def __init__(self, config: Config): def env(self) -> dict[str, Any]: """Get necessary Pulumi environment variables""" if not self.env_: - azure_api = AzureApi(self.cfg.subscription_name) + azure_api = AzureApi(self.cfg.context.subscription_name) backend_storage_account_keys = azure_api.get_storage_account_keys( self.cfg.context.resource_group_name, self.cfg.context.storage_account_name, @@ -210,7 +210,7 @@ def destroy(self) -> None: self.logger.info( f"Removing Pulumi stack backup [green]{stack_backup_name}[/]." 
) - azure_api = AzureApi(self.cfg.subscription_name) + azure_api = AzureApi(self.cfg.context.subscription_name) azure_api.remove_blob( blob_name=f".pulumi/stacks/{self.project_name}/{stack_backup_name}", resource_group_name=self.cfg.context.resource_group_name, diff --git a/data_safe_haven/infrastructure/stacks/declarative_shm.py b/data_safe_haven/infrastructure/stacks/declarative_shm.py index acceb69a59..a7a93e7d91 100644 --- a/data_safe_haven/infrastructure/stacks/declarative_shm.py +++ b/data_safe_haven/infrastructure/stacks/declarative_shm.py @@ -133,7 +133,7 @@ def run(self) -> None: password_domain_searcher=data.password_domain_searcher, private_ip_address=networking.domain_controller_private_ip, subnet_identity_servers=networking.subnet_identity_servers, - subscription_name=self.cfg.subscription_name, + subscription_name=self.cfg.context.subscription_name, virtual_network_name=networking.virtual_network.name, virtual_network_resource_group_name=networking.resource_group_name, ), diff --git a/data_safe_haven/infrastructure/stacks/declarative_sre.py b/data_safe_haven/infrastructure/stacks/declarative_sre.py index f4058b160c..106166ce2d 100644 --- a/data_safe_haven/infrastructure/stacks/declarative_sre.py +++ b/data_safe_haven/infrastructure/stacks/declarative_sre.py @@ -173,7 +173,7 @@ def run(self) -> None: subnet_data_configuration=networking.subnet_data_configuration, subnet_data_private=networking.subnet_data_private, subscription_id=self.cfg.azure.subscription_id, - subscription_name=self.cfg.subscription_name, + subscription_name=self.cfg.context.subscription_name, tenant_id=self.cfg.azure.tenant_id, ), tags=self.cfg.tags.model_dump(), @@ -256,7 +256,7 @@ def run(self) -> None: storage_account_data_private_user_name=data.storage_account_data_private_user_name, storage_account_data_private_sensitive_name=data.storage_account_data_private_sensitive_name, subnet_workspaces=networking.subnet_workspaces, - subscription_name=self.cfg.subscription_name, + 
subscription_name=self.cfg.context.subscription_name, virtual_network_resource_group=networking.resource_group, virtual_network=networking.virtual_network, vm_details=list(enumerate(self.cfg.sres[self.sre_name].workspace_skus)), From adbb464a958672496e74de22ec2bc4470f814df8 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 27 Nov 2023 13:49:51 +0000 Subject: [PATCH 30/65] Add missing properties to context --- data_safe_haven/config/context_settings.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index a9f0b5f218..d37ca48613 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -56,6 +56,14 @@ def storage_account_name(self) -> str: # maximum of 24 characters allowed return f"shm{self.shm_name[:14]}context" + @property + def key_vault_name(self) -> str: + return f"shm-{self.shm_name[:9]}-kv-context" + + @property + def managed_identity_name(self) -> str: + return f"shm-{self.shm_name}-identity-reader-context" + def to_yaml(self) -> str: return yaml.dump(self.model_dump(), indent=2) From c040e47e7743b723d49e80422acfcda4f91bde6b Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 28 Nov 2023 13:42:07 +0000 Subject: [PATCH 31/65] Update minimum Python to 3.11 --- data_safe_haven/external/api/graph_api.py | 2 +- .../external/interface/azure_postgresql_database.py | 2 +- data_safe_haven/functions/miscellaneous.py | 2 +- pyproject.toml | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/data_safe_haven/external/api/graph_api.py b/data_safe_haven/external/api/graph_api.py index 1b1a00b677..d1eeddbd64 100644 --- a/data_safe_haven/external/api/graph_api.py +++ b/data_safe_haven/external/api/graph_api.py @@ -290,7 +290,7 @@ def create_application_secret( "passwordCredential": { "displayName": application_secret_name, "endDateTime": ( - datetime.datetime.now(datetime.timezone.utc) + 
datetime.datetime.now(datetime.UTC) + datetime.timedelta(weeks=520) ).strftime("%Y-%m-%dT%H:%M:%SZ"), } diff --git a/data_safe_haven/external/interface/azure_postgresql_database.py b/data_safe_haven/external/interface/azure_postgresql_database.py index ad0edc30da..3ffe5f9871 100644 --- a/data_safe_haven/external/interface/azure_postgresql_database.py +++ b/data_safe_haven/external/interface/azure_postgresql_database.py @@ -54,7 +54,7 @@ def __init__( self.port = 5432 self.resource_group_name = resource_group_name self.server_name = database_server_name - self.rule_suffix = datetime.datetime.now(tz=datetime.timezone.utc).strftime( + self.rule_suffix = datetime.datetime.now(tz=datetime.UTC).strftime( r"%Y%m%d-%H%M%S" ) diff --git a/data_safe_haven/functions/miscellaneous.py b/data_safe_haven/functions/miscellaneous.py index 8703fd39fe..00b47dcacf 100644 --- a/data_safe_haven/functions/miscellaneous.py +++ b/data_safe_haven/functions/miscellaneous.py @@ -41,7 +41,7 @@ def ordered_private_dns_zones(resource_type: str | None = None) -> list[str]: def time_as_string(hour: int, minute: int, timezone: str) -> str: """Get the next occurence of a repeating daily time as a string""" - dt = datetime.datetime.now(datetime.timezone.utc).replace( + dt = datetime.datetime.now(datetime.UTC).replace( hour=hour, minute=minute, second=0, diff --git a/pyproject.toml b/pyproject.toml index f648c705bf..01b90cc932 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ description = "An open-source framework for creating secure environments to anal authors = [ { name = "Data Safe Haven development team", email = "safehavendevs@turing.ac.uk" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" license = "BSD-3-Clause" dependencies = [ "appdirs~=1.4", @@ -100,7 +100,7 @@ pre-install-commands = ["pip install -r requirements.txt"] test = "pytest {args:-vvv tests_}" [tool.black] -target-version = ["py310", "py311"] +target-version = ["py311", "py312"] [tool.ruff] select = [ 
From ba05dd7e48be9bb6e1cc439bae7cfad681f8ce7d Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 29 Nov 2023 15:23:51 +0000 Subject: [PATCH 32/65] Correct typer validator factory --- data_safe_haven/functions/typer_validators.py | 1 + tests_/functions/test_typer_validators.py | 28 +++++++++++++++++++ tests_/functions/test_validators.py | 27 ++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 tests_/functions/test_typer_validators.py create mode 100644 tests_/functions/test_validators.py diff --git a/data_safe_haven/functions/typer_validators.py b/data_safe_haven/functions/typer_validators.py index 30d6227249..df72cd149c 100644 --- a/data_safe_haven/functions/typer_validators.py +++ b/data_safe_haven/functions/typer_validators.py @@ -16,6 +16,7 @@ def typer_validator_factory(validator: Callable[[Any], Any]) -> Callable[[Any], def typer_validator(x: Any) -> Any: try: validator(x) + return x except ValueError as exc: raise BadParameter(str(exc)) from exc diff --git a/tests_/functions/test_typer_validators.py b/tests_/functions/test_typer_validators.py new file mode 100644 index 0000000000..0c9c02ff90 --- /dev/null +++ b/tests_/functions/test_typer_validators.py @@ -0,0 +1,28 @@ +import pytest +from typer import BadParameter + +from data_safe_haven.functions.typer_validators import typer_validate_aad_guid + + +class TestTyperValidateAadGuid: + @pytest.mark.parametrize( + "guid", + [ + "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "10de18e7-b238-6f1e-a4ad-772708929203", + ] + ) + def test_typer_validate_aad_guid(self, guid): + assert typer_validate_aad_guid(guid) == guid + + @pytest.mark.parametrize( + "guid", + [ + "10de18e7_b238_6f1e_a4ad_772708929203", + "not a guid", + ] + ) + def test_typer_validate_aad_guid_fail(self, guid): + with pytest.raises(BadParameter) as exc: + typer_validate_aad_guid(guid) + assert "Expected GUID" in exc diff --git a/tests_/functions/test_validators.py b/tests_/functions/test_validators.py new file mode 100644 index 
0000000000..bcf41ca263 --- /dev/null +++ b/tests_/functions/test_validators.py @@ -0,0 +1,27 @@ +import pytest + +from data_safe_haven.functions.validators import validate_aad_guid + + +class TestValidateAadGuid: + @pytest.mark.parametrize( + "guid", + [ + "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "10de18e7-b238-6f1e-a4ad-772708929203", + ] + ) + def test_validate_aad_guid(self, guid): + assert validate_aad_guid(guid) == guid + + @pytest.mark.parametrize( + "guid", + [ + "10de18e7_b238_6f1e_a4ad_772708929203", + "not a guid", + ] + ) + def test_validate_aad_guid_fail(self, guid): + with pytest.raises(ValueError) as exc: + validate_aad_guid(guid) + assert "Expected GUID" in exc From e1a02648d96e62abb7b01af30621af1b78bc9e57 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 29 Nov 2023 15:37:44 +0000 Subject: [PATCH 33/65] Support optional args in typer validator factory --- data_safe_haven/functions/typer_validators.py | 6 ++++++ tests_/functions/test_typer_validators.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/data_safe_haven/functions/typer_validators.py b/data_safe_haven/functions/typer_validators.py index df72cd149c..91434740d9 100644 --- a/data_safe_haven/functions/typer_validators.py +++ b/data_safe_haven/functions/typer_validators.py @@ -13,7 +13,13 @@ def typer_validator_factory(validator: Callable[[Any], Any]) -> Callable[[Any], Any]: + """Factory to create validation functions for Typer from Pydantic validators""" def typer_validator(x: Any) -> Any: + # Return unused optional arguments + if x is None: + return x + + # Validate input, catching ValueError to raise Typer Exception try: validator(x) return x diff --git a/tests_/functions/test_typer_validators.py b/tests_/functions/test_typer_validators.py index 0c9c02ff90..72ae8b3f53 100644 --- a/tests_/functions/test_typer_validators.py +++ b/tests_/functions/test_typer_validators.py @@ -26,3 +26,6 @@ def test_typer_validate_aad_guid_fail(self, guid): with pytest.raises(BadParameter) as exc: 
typer_validate_aad_guid(guid) assert "Expected GUID" in exc + + def test_typer_validate_aad_guid_nonae(self): + assert typer_validate_aad_guid(None) is None From 7a0cab83744f917af83993e86010090bd2987717 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 28 Nov 2023 12:03:12 +0000 Subject: [PATCH 34/65] Move Pulumi encryption key to Config property This avoids having to write back to the configuration and the possibility of breaking a configuration through making manual changes. --- data_safe_haven/config/config.py | 24 ++++++++++-- data_safe_haven/context/context.py | 4 +- data_safe_haven/external/api/azure_api.py | 21 +++++++++++ .../infrastructure/stack_manager.py | 2 +- tests_/config/test_config.py | 34 ++++++++++++++--- tests_/external/api/azure_api.py | 37 +++++++++++++++++++ 6 files changed, 109 insertions(+), 13 deletions(-) create mode 100644 tests_/external/api/azure_api.py diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index b6d21f7519..d16474ce9f 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -2,9 +2,10 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, ClassVar import yaml +from azure.keyvault.keys import KeyVaultKey from pydantic import ( BaseModel, Field, @@ -54,10 +55,9 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]): class ConfigSectionPulumi(BaseModel, validate_assignment=True): - encryption_key_name: str = "pulumi-encryption-key" - encryption_key_version: str + storage_container_name: ClassVar[str] = "pulumi" + encryption_key_name: ClassVar[str] = "pulumi-encryption-key" stacks: dict[str, str] = Field(..., default_factory=dict[str, str]) - storage_container_name: str = "pulumi" class ConfigSectionSHM(BaseModel, validate_assignment=True): @@ -251,6 +251,8 @@ class Config(BaseModel, validate_assignment=True): ) tags: ConfigSectionTags = Field(..., exclude=True) + _pulumi_encryption_key = None + 
def __init__(self, context: Context, **kwargs: dict[Any, Any]): tags = ConfigSectionTags(context) super().__init__(context=context, tags=tags, **kwargs) @@ -259,6 +261,20 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]): def work_directory(self) -> Path: return self.context.work_directory + @property + def pulumi_encryption_key(self) -> KeyVaultKey: + if not self._pulumi_encryption_key: + azure_api = AzureApi(subscription_name=self.context.subscription_name) + self._pulumi_encryption_key = azure_api.get_keyvault_key( + key_name=self.pulumi.encryption_key_name, + key_vault_name=self.context.key_vault_name, + ) + return self._pulumi_encryption_key + + @property + def pulumi_encryption_key_version(self) -> str: + return self.pulumi_encryption_key.id.split("/")[-1] + def is_complete(self, *, require_sres: bool) -> bool: if require_sres: if not self.sres: diff --git a/data_safe_haven/context/context.py b/data_safe_haven/context/context.py index aa50969c1d..ce92423892 100644 --- a/data_safe_haven/context/context.py +++ b/data_safe_haven/context/context.py @@ -78,12 +78,10 @@ def create(self) -> None: f"Keyvault '{self.config.context.key_vault_name}' was not created." 
) raise DataSafeHavenAzureError(msg) - pulumi_encryption_key = self.azure_api.ensure_keyvault_key( + self.azure_api.ensure_keyvault_key( key_name=self.config.pulumi.encryption_key_name, key_vault_name=keyvault.name, ) - key_version = pulumi_encryption_key.id.split("/")[-1] - self.config.pulumi.encryption_key_version = key_version except Exception as exc: msg = f"Failed to create context resources.\n{exc}" raise DataSafeHavenAzureError(msg) from exc diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index 0ce8044bce..4024217912 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -615,6 +615,27 @@ def get_keyvault_certificate( msg = f"Failed to retrieve certificate {certificate_name}.\n{exc}" raise DataSafeHavenAzureError(msg) from exc + def get_keyvault_key(self, key_name: str, key_vault_name: str) -> KeyVaultKey: + """Read a key from the KeyVault + + Returns: + KeyVaultKey: The key + + Raises: + DataSafeHavenAzureError if the secret could not be read + """ + # Connect to Azure clients + key_client = KeyClient( + vault_url=f"https://{key_vault_name}.vault.azure.net", + credential=self.credential, + ) + # Ensure that certificate exists + try: + return key_client.get_key(key_name) + except Exception as exc: + msg = f"Failed to retrieve key {key_name}.\n{exc}" + raise DataSafeHavenAzureError(msg) from exc + def get_keyvault_secret(self, key_vault_name: str, secret_name: str) -> str: """Read a secret from the KeyVault diff --git a/data_safe_haven/infrastructure/stack_manager.py b/data_safe_haven/infrastructure/stack_manager.py index d8b4cc891c..0d2ad96acb 100644 --- a/data_safe_haven/infrastructure/stack_manager.py +++ b/data_safe_haven/infrastructure/stack_manager.py @@ -100,7 +100,7 @@ def stack(self) -> automation.Stack: stack_name=self.stack_name, program=self.program.run, opts=automation.LocalWorkspaceOptions( - 
secrets_provider=f"azurekeyvault://{self.cfg.context.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.pulumi.encryption_key_version}", + secrets_provider=f"azurekeyvault://{self.cfg.context.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.encryption_key_version}", work_dir=str(self.work_dir), env_vars=self.account.env, ), diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index e58c213e25..106d919b4f 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -37,12 +37,12 @@ def test_constructor(self, context): @fixture def pulumi_config(): - return ConfigSectionPulumi(encryption_key_version="lorem") + return ConfigSectionPulumi() class TestConfigSectionPulumi: def test_constructor_defaults(self): - pulumi_config = ConfigSectionPulumi(encryption_key_version="lorem") + pulumi_config = ConfigSectionPulumi() assert pulumi_config.encryption_key_name == "pulumi-encryption-key" assert pulumi_config.stacks == {} assert pulumi_config.storage_container_name == "pulumi" @@ -214,10 +214,7 @@ def config_yaml(): subscription_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd pulumi: - encryption_key_name: pulumi-encryption-key - encryption_key_version: lorem stacks: {} - storage_container_name: pulumi shm: aad_tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd admin_email_address: admin@example.com @@ -249,6 +246,20 @@ def config_yaml(): """ +@fixture +def mock_key_vault_key(monkeypatch): + class MockKeyVaultKey: + def __init__(self, key_name, key_vault_name): + self.key_name = key_name + self.key_vault_name = key_vault_name + self.id = "mock_key/version" + + def mock_get_keyvault_key(self, key_name, key_vault_name): # noqa: ARG001 + return MockKeyVaultKey(key_name, key_vault_name) + + monkeypatch.setattr(AzureApi, "get_keyvault_key", mock_get_keyvault_key) + + class TestConfig: def test_constructor_defaults(self, context): config 
= Config(context=context) @@ -268,6 +279,19 @@ def test_work_directory(self, config_sres): config = config_sres assert config.work_directory == config.context.work_directory + def test_pulumi_encryption_key( + self, config_sres, mock_key_vault_key # noqa: ARG002 + ): + key = config_sres.pulumi_encryption_key + assert key.key_name == config_sres.pulumi.encryption_key_name + assert key.key_vault_name == config_sres.context.key_vault_name + + def test_pulumi_encryption_key_version( + self, config_sres, mock_key_vault_key # noqa: ARG002 + ): + version = config_sres.pulumi_encryption_key_version + assert version == "version" + @pytest.mark.parametrize("require_sres", [False, True]) def test_is_complete_bare(self, context, require_sres): config = Config(context=context) diff --git a/tests_/external/api/azure_api.py b/tests_/external/api/azure_api.py new file mode 100644 index 0000000000..1eab649234 --- /dev/null +++ b/tests_/external/api/azure_api.py @@ -0,0 +1,37 @@ +import pytest +from pytest import fixture + +import data_safe_haven.external.api.azure_api +from data_safe_haven.exceptions import DataSafeHavenAzureError +from data_safe_haven.external.api.azure_api import AzureApi + + +@fixture +def mock_key_client(monkeypatch): + class MockKeyClient: + def __init__(self, vault_url, credential): + self.vault_url = vault_url + self.credential = credential + + def get_key(self, key_name): + if key_name == "exists": + return f"key: {key_name}" + else: + raise Exception + + monkeypatch.setattr( + data_safe_haven.external.api.azure_api, "KeyClient", MockKeyClient + ) + + +class TestAzureApi: + def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 + api = AzureApi("subscription name") + key = api.get_keyvault_key("exists", "key vault name") + assert key == "key: exists" + + def test_get_keyvault_key_missing(self, mock_key_client): # noqa: ARG002 + api = AzureApi("subscription name") + with pytest.raises(DataSafeHavenAzureError) as exc: + api.get_keyvault_key("does 
not exist", "key vault name") + assert "Failed to retrieve key does not exist" in exc From 4727c325f23970666fa7c007d3e0a3abc1805440 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 28 Nov 2023 15:45:28 +0000 Subject: [PATCH 35/65] Add template method --- data_safe_haven/config/config.py | 25 ++++++++++++++++++++++--- tests_/config/test_config.py | 21 +++++++++++---------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index d16474ce9f..6b1f9abd45 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -242,10 +242,10 @@ def __init__(self, context: Context, **kwargs: dict[Any, Any]): class Config(BaseModel, validate_assignment=True): - azure: ConfigSectionAzure | None = None + azure: ConfigSectionAzure context: Context = Field(..., exclude=True) - pulumi: ConfigSectionPulumi | None = None - shm: ConfigSectionSHM | None = None + pulumi: ConfigSectionPulumi + shm: ConfigSectionSHM sres: dict[str, ConfigSectionSRE] = Field( ..., default_factory=dict[str, ConfigSectionSRE] ) @@ -316,6 +316,25 @@ def write_stack(self, name: str, path: Path) -> None: with open(path, "w", encoding="utf-8") as f_stack: f_stack.write(pulumi_cfg) + @classmethod + def template(cls, context: Context) -> Config: + # Create object without validation to allow "replace me" prompts + return Config.model_construct( + context=context, + azure=ConfigSectionAzure.model_construct( + subscription_id="Azure subscription ID", + tenant_id="Azure tenant ID", + ), + pulumi=ConfigSectionPulumi(), + shm=ConfigSectionSHM.model_construct( + aad_tenant_id="Azure Active Directory tenant ID", + admin_email_address="Admin email address", + admin_ip_addresses=["Admin IP addresses"], + fqdn="TRE domain name", + timezone="Timezone", + ), + ) + @classmethod def from_yaml(cls, context: Context, config_yaml: str) -> Config: try: diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index 
106d919b4f..fc6670e27a 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -11,6 +11,7 @@ ConfigSectionTags, ConfigSubsectionRemoteDesktopOpts, ) +from data_safe_haven.exceptions import DataSafeHavenParameterError from data_safe_haven.external import AzureApi from data_safe_haven.utility.enums import DatabaseSystem, SoftwarePackageCategory from data_safe_haven.version import __version__ @@ -261,11 +262,6 @@ def mock_get_keyvault_key(self, key_name, key_vault_name): # noqa: ARG001 class TestConfig: - def test_constructor_defaults(self, context): - config = Config(context=context) - assert config.context == context - assert not any((config.azure, config.pulumi, config.shm, config.sres)) - def test_constructor(self, context, azure_config, pulumi_config, shm_config): config = Config( context=context, @@ -292,11 +288,6 @@ def test_pulumi_encryption_key_version( version = config_sres.pulumi_encryption_key_version assert version == "version" - @pytest.mark.parametrize("require_sres", [False, True]) - def test_is_complete_bare(self, context, require_sres): - config = Config(context=context) - assert config.is_complete(require_sres=require_sres) is False - @pytest.mark.parametrize("require_sres,expected", [(False, True), (True, False)]) def test_is_complete_no_sres(self, config_no_sres, require_sres, expected): assert config_no_sres.is_complete(require_sres=require_sres) is expected @@ -326,6 +317,16 @@ def test_remove_sre(self, config_sres): assert "sre2" in config_sres.sres.keys() assert "sre1" not in config_sres.sres.keys() + def test_template(self, context): + config = Config.template(context) + assert isinstance(config, Config) + assert config.azure.subscription_id == "Azure subscription ID" + + def test_template_validation(self, context): + config = Config.template(context) + with pytest.raises(DataSafeHavenParameterError): + Config.from_yaml(context, config.to_yaml()) + def test_from_yaml(self, config_sres, context, config_yaml): config = 
Config.from_yaml(context, config_yaml) assert config == config_sres From 36d05f35b4bcca11fdc411022a1deb249ba8b6e0 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 28 Nov 2023 15:47:14 +0000 Subject: [PATCH 36/65] Add config template command --- data_safe_haven/cli.py | 6 +++++ data_safe_haven/commands/__init__.py | 2 ++ data_safe_haven/commands/config.py | 27 ++++++++++++++++++++ tests_/commands/conftest.py | 38 ++++++++++++++++++++++++++++ tests_/commands/test_config.py | 18 +++++++++++++ tests_/commands/test_context.py | 38 ---------------------------- 6 files changed, 91 insertions(+), 38 deletions(-) create mode 100644 data_safe_haven/commands/config.py create mode 100644 tests_/commands/conftest.py create mode 100644 tests_/commands/test_config.py diff --git a/data_safe_haven/cli.py b/data_safe_haven/cli.py index a48a3e38af..4242953aa1 100644 --- a/data_safe_haven/cli.py +++ b/data_safe_haven/cli.py @@ -7,6 +7,7 @@ from data_safe_haven import __version__ from data_safe_haven.commands import ( admin_command_group, + config_command_group, context_command_group, deploy_command_group, teardown_command_group, @@ -69,6 +70,11 @@ def main() -> None: name="admin", help="Perform administrative tasks for a Data Safe Haven deployment.", ) + application.add_typer( + config_command_group, + name="config", + help="Manage Data Safe Haven configuration.", + ) application.add_typer( context_command_group, name="context", help="Manage Data Safe Haven contexts." 
) diff --git a/data_safe_haven/commands/__init__.py b/data_safe_haven/commands/__init__.py index 299c982302..a471a41b67 100644 --- a/data_safe_haven/commands/__init__.py +++ b/data_safe_haven/commands/__init__.py @@ -1,4 +1,5 @@ from .admin import admin_command_group +from .config import config_command_group from .context import context_command_group from .deploy import deploy_command_group from .teardown import teardown_command_group @@ -6,6 +7,7 @@ __all__ = [ "admin_command_group", "context_command_group", + "config_command_group", "deploy_command_group", "teardown_command_group", ] diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py new file mode 100644 index 0000000000..da11f11b9e --- /dev/null +++ b/data_safe_haven/commands/config.py @@ -0,0 +1,27 @@ +"""Command group and entrypoints for managing DSH configuration""" +from pathlib import Path +from typing import Annotated, Optional + +import typer +from rich import print + +from data_safe_haven.config import Config, ContextSettings + +config_command_group = typer.Typer() + + +@config_command_group.command() +def template( + file: Annotated[ + Optional[Path], + typer.Option(help="File path to write configuration template to.") + ] = None +) -> None: + """Write a template Data Safe Haven configuration.""" + context = ContextSettings.from_file() + config = Config.template(context) + if file: + with open(file, "w") as outfile: + outfile.write(config.to_yaml()) + else: + print(config.to_yaml()) diff --git a/tests_/commands/conftest.py b/tests_/commands/conftest.py new file mode 100644 index 0000000000..5b337ca021 --- /dev/null +++ b/tests_/commands/conftest.py @@ -0,0 +1,38 @@ +from pytest import fixture +from typer.testing import CliRunner + + +context_settings = """\ + selected: acme_deployment + contexts: + acme_deployment: + name: Acme Deployment + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_name: Data Safe Haven (Acme) + gems: + 
name: Gems + admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + location: uksouth + subscription_name: Data Safe Haven (Gems)""" + + +@fixture +def tmp_contexts(tmp_path): + config_file_path = tmp_path / "contexts.yaml" + with open(config_file_path, "w") as f: + f.write(context_settings) + return tmp_path + + +@fixture +def runner(tmp_contexts): + runner = CliRunner( + env={ + "DSH_CONFIG_DIRECTORY": str(tmp_contexts), + "COLUMNS": "500", # Set large number of columns to avoid rich wrapping text + "TERM": "dumb", # Disable colours, style and interactive rich features + }, + mix_stderr=False, + ) + return runner diff --git a/tests_/commands/test_config.py b/tests_/commands/test_config.py new file mode 100644 index 0000000000..1ee14d416a --- /dev/null +++ b/tests_/commands/test_config.py @@ -0,0 +1,18 @@ +from data_safe_haven.commands.config import config_command_group + + +class TestTemplate: + def test_template(self, runner): + result = runner.invoke(config_command_group, ["template"]) + assert result.exit_code == 0 + assert "subscription_id: Azure subscription ID" in result.stdout + assert "sres: {}" in result.stdout + + def test_template_file(self, runner, tmp_path): + template_file = (tmp_path / "template.yaml").absolute() + result = runner.invoke(config_command_group, ["template", "--file", str(template_file)]) + assert result.exit_code == 0 + with open(template_file) as f: + template_text = f.read() + assert "subscription_id: Azure subscription ID" in template_text + assert "sres: {}" in template_text diff --git a/tests_/commands/test_context.py b/tests_/commands/test_context.py index 5d8525f4c5..87a499f7b8 100644 --- a/tests_/commands/test_context.py +++ b/tests_/commands/test_context.py @@ -1,45 +1,7 @@ -from pytest import fixture -from typer.testing import CliRunner - from data_safe_haven.commands.context import context_command_group from data_safe_haven.config import Config from data_safe_haven.context import Context -context_settings = """\ - 
selected: acme_deployment - contexts: - acme_deployment: - name: Acme Deployment - admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - location: uksouth - subscription_name: Data Safe Haven (Acme) - gems: - name: Gems - admin_group_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - location: uksouth - subscription_name: Data Safe Haven (Gems)""" - - -@fixture -def tmp_contexts(tmp_path): - config_file_path = tmp_path / "contexts.yaml" - with open(config_file_path, "w") as f: - f.write(context_settings) - return tmp_path - - -@fixture -def runner(tmp_contexts): - runner = CliRunner( - env={ - "DSH_CONFIG_DIRECTORY": str(tmp_contexts), - "COLUMNS": "500", # Set large number of columns to avoid rich wrapping text - "TERM": "dumb", # Disable colours, style and interactive rich features - }, - mix_stderr=False, - ) - return runner - class TestShow: def test_show(self, runner): From 681379eb4780eb3687d26350e9886de90302ce8a Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 29 Nov 2023 10:57:58 +0000 Subject: [PATCH 37/65] Add config upload command --- data_safe_haven/commands/config.py | 12 ++++++ data_safe_haven/config/config.py | 1 - tests_/commands/test_config.py | 16 ++++++++ tests_/config/test_config.py | 51 +----------------------- tests_/conftest.py | 64 ++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 51 deletions(-) create mode 100644 tests_/conftest.py diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index da11f11b9e..c687696beb 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -25,3 +25,15 @@ def template( outfile.write(config.to_yaml()) else: print(config.to_yaml()) + + +@config_command_group.command() +def upload( + file: Annotated[Path, typer.Argument(help="Path to configuration file")] +) -> None: + """Upload a configuration to the Data Safe Haven context""" + context = ContextSettings.from_file().context + with open(file) as config_file: + config_yaml = 
config_file.read() + config = Config.from_yaml(context, config_yaml) + config.upload() diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 6b1f9abd45..d47cfa83da 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -348,7 +348,6 @@ def from_yaml(cls, context: Context, config_yaml: str) -> Config: raise DataSafeHavenConfigError(msg) # Add context for constructors that require it - # context_dict = context.model_dump() config_dict["context"] = context for section in ["azure", "shm"]: config_dict[section]["context"] = context diff --git a/tests_/commands/test_config.py b/tests_/commands/test_config.py index 1ee14d416a..7f37e461f5 100644 --- a/tests_/commands/test_config.py +++ b/tests_/commands/test_config.py @@ -16,3 +16,19 @@ def test_template_file(self, runner, tmp_path): template_text = f.read() assert "subscription_id: Azure subscription ID" in template_text assert "sres: {}" in template_text + + +class TestUpload: + def test_upload(self, runner, config_file, mock_upload_blob): + result = runner.invoke( + config_command_group, + ["upload", str(config_file)], + ) + assert result.exit_code == 0 + + def test_upload_no_file(self, runner, mock_upload_blob): + result = runner.invoke( + config_command_group, + ["upload"], + ) + assert result.exit_code == 2 diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index fc6670e27a..cf25dda4b2 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -209,44 +209,6 @@ def config_sres(context, azure_config, pulumi_config, shm_config): ) -@fixture -def config_yaml(): - return """azure: - subscription_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd -pulumi: - stacks: {} -shm: - aad_tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd - admin_email_address: admin@example.com - admin_ip_addresses: - - 0.0.0.0/32 - fqdn: shm.acme.com - timezone: UTC -sres: - sre1: - 
data_provider_ip_addresses: [] - databases: [] - index: 0 - remote_desktop: - allow_copy: false - allow_paste: false - research_user_ip_addresses: [] - software_packages: none - workspace_skus: [] - sre2: - data_provider_ip_addresses: [] - databases: [] - index: 1 - remote_desktop: - allow_copy: false - allow_paste: false - research_user_ip_addresses: [] - software_packages: none - workspace_skus: [] -""" - - @fixture def mock_key_vault_key(monkeypatch): class MockKeyVaultKey: @@ -355,16 +317,5 @@ def mock_download_blob( def test_to_yaml(self, config_sres, config_yaml): assert config_sres.to_yaml() == config_yaml - def test_upload(self, config_sres, monkeypatch): - def mock_upload_blob( - self, # noqa: ARG001 - blob_data: bytes | str, # noqa: ARG001 - blob_name: str, # noqa: ARG001 - resource_group_name: str, # noqa: ARG001 - storage_account_name: str, # noqa: ARG001 - storage_container_name: str, # noqa: ARG001 - ): - pass - - monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) + def test_upload(self, config_sres, mock_upload_blob): config_sres.upload() diff --git a/tests_/conftest.py b/tests_/conftest.py new file mode 100644 index 0000000000..3bc2dd959b --- /dev/null +++ b/tests_/conftest.py @@ -0,0 +1,64 @@ +from pytest import fixture + +from data_safe_haven.external import AzureApi + + +@fixture +def config_yaml(): + return """azure: + subscription_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd +pulumi: + stacks: {} +shm: + aad_tenant_id: d5c5c439-1115-4cb6-ab50-b8e547b6c8dd + admin_email_address: admin@example.com + admin_ip_addresses: + - 0.0.0.0/32 + fqdn: shm.acme.com + timezone: UTC +sres: + sre1: + data_provider_ip_addresses: [] + databases: [] + index: 0 + remote_desktop: + allow_copy: false + allow_paste: false + research_user_ip_addresses: [] + software_packages: none + workspace_skus: [] + sre2: + data_provider_ip_addresses: [] + databases: [] + index: 1 + remote_desktop: + allow_copy: false + 
allow_paste: false + research_user_ip_addresses: [] + software_packages: none + workspace_skus: [] +""" + + +@fixture +def config_file(config_yaml, tmp_path): + config_file_path = tmp_path / "config.yaml" + with open(config_file_path, "w") as f: + f.write(config_yaml) + return config_file_path + + +@fixture +def mock_upload_blob(monkeypatch): + def mock_upload_blob( + self, # noqa: ARG001 + blob_data: bytes | str, # noqa: ARG001 + blob_name: str, # noqa: ARG001 + resource_group_name: str, # noqa: ARG001 + storage_account_name: str, # noqa: ARG001 + storage_container_name: str, # noqa: ARG001 + ): + pass + + monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) From 82a8d172b428a33597c73b97f90d16ada6ed59ff Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 29 Nov 2023 11:18:26 +0000 Subject: [PATCH 38/65] Add config show command --- data_safe_haven/commands/config.py | 7 ++++++ tests_/commands/test_config.py | 10 ++++++++ tests_/config/conftest.py | 18 -------------- tests_/config/test_config.py | 16 +----------- tests_/conftest.py | 40 +++++++++++++++++++++++++++++- 5 files changed, 57 insertions(+), 34 deletions(-) delete mode 100644 tests_/config/conftest.py diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index c687696beb..5db33ccb73 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -37,3 +37,10 @@ def upload( config_yaml = config_file.read() config = Config.from_yaml(context, config_yaml) config.upload() + + +@config_command_group.command() +def show() -> None: + context = ContextSettings.from_file().context + config = Config.from_remote(context) + print(config.to_yaml()) diff --git a/tests_/commands/test_config.py b/tests_/commands/test_config.py index 7f37e461f5..d700976ef3 100644 --- a/tests_/commands/test_config.py +++ b/tests_/commands/test_config.py @@ -32,3 +32,13 @@ def test_upload_no_file(self, runner, mock_upload_blob): ["upload"], ) assert result.exit_code == 2 + 
+ +class TestShow: + def test_show(self, runner, config_yaml, mock_download_blob): + result = runner.invoke( + config_command_group, + ["show"] + ) + assert result.exit_code == 0 + assert config_yaml in result.stdout diff --git a/tests_/config/conftest.py b/tests_/config/conftest.py deleted file mode 100644 index 1ac39c86df..0000000000 --- a/tests_/config/conftest.py +++ /dev/null @@ -1,18 +0,0 @@ -from pytest import fixture - -from data_safe_haven.config.context_settings import Context - - -@fixture -def context_dict(): - return { - "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - "location": "uksouth", - "name": "Acme Deployment", - "subscription_name": "Data Safe Haven (Acme)", - } - - -@fixture -def context(context_dict): - return Context(**context_dict) diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index cf25dda4b2..ebb6ee9e5f 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -296,21 +296,7 @@ def test_from_yaml(self, config_sres, context, config_yaml): config.sres["sre1"].software_packages, SoftwarePackageCategory ) - def test_from_remote(self, context, config_sres, config_yaml, monkeypatch): - def mock_download_blob( - self, # noqa: ARG001 - blob_name: str, - resource_group_name: str, - storage_account_name: str, - storage_container_name: str, - ): - assert blob_name == context.config_filename - assert resource_group_name == context.resource_group_name - assert storage_account_name == context.storage_account_name - assert storage_container_name == context.storage_container_name - return config_yaml - - monkeypatch.setattr(AzureApi, "download_blob", mock_download_blob) + def test_from_remote(self, context, config_sres, mock_download_blob): config = Config.from_remote(context) assert config == config_sres diff --git a/tests_/conftest.py b/tests_/conftest.py index 3bc2dd959b..808693eb7d 100644 --- a/tests_/conftest.py +++ b/tests_/conftest.py @@ -1,8 +1,24 @@ from pytest import fixture +from 
data_safe_haven.config.context_settings import Context from data_safe_haven.external import AzureApi +@fixture +def context_dict(): + return { + "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "location": "uksouth", + "name": "Acme Deployment", + "subscription_name": "Data Safe Haven (Acme)", + } + + +@fixture +def context(context_dict): + return Context(**context_dict) + + @fixture def config_yaml(): return """azure: @@ -50,7 +66,25 @@ def config_file(config_yaml, tmp_path): @fixture -def mock_upload_blob(monkeypatch): +def mock_download_blob(monkeypatch, context, config_yaml): + def mock_download_blob( + self, # noqa: ARG001 + blob_name: str, + resource_group_name: str, + storage_account_name: str, + storage_container_name: str, + ): + assert blob_name == context.config_filename + assert resource_group_name == context.resource_group_name + assert storage_account_name == context.storage_account_name + assert storage_container_name == context.storage_container_name + return config_yaml + + monkeypatch.setattr(AzureApi, "download_blob", mock_download_blob) + + +@fixture +def mock_upload_blob(monkeypatch, context): def mock_upload_blob( self, # noqa: ARG001 blob_data: bytes | str, # noqa: ARG001 @@ -59,6 +93,10 @@ def mock_upload_blob( storage_account_name: str, # noqa: ARG001 storage_container_name: str, # noqa: ARG001 ): + assert blob_name == context.config_filename + assert resource_group_name == context.resource_group_name + assert storage_account_name == context.storage_account_name + assert storage_container_name == context.storage_container_name pass monkeypatch.setattr(AzureApi, "upload_blob", mock_upload_blob) From b25f9770af830592117bb920d6fe26c23b8dab67 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 29 Nov 2023 11:18:43 +0000 Subject: [PATCH 39/65] Run lint:fmt --- data_safe_haven/commands/config.py | 4 ++-- tests_/commands/conftest.py | 1 - tests_/commands/test_config.py | 15 +++++++-------- tests_/config/test_config.py | 6 ++++-- 
tests_/conftest.py | 8 ++++---- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index 5db33ccb73..12c591d216 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -13,8 +13,8 @@ @config_command_group.command() def template( file: Annotated[ - Optional[Path], - typer.Option(help="File path to write configuration template to.") + Optional[Path], # noqa: UP007 + typer.Option(help="File path to write configuration template to."), ] = None ) -> None: """Write a template Data Safe Haven configuration.""" diff --git a/tests_/commands/conftest.py b/tests_/commands/conftest.py index 5b337ca021..ba563e900d 100644 --- a/tests_/commands/conftest.py +++ b/tests_/commands/conftest.py @@ -1,7 +1,6 @@ from pytest import fixture from typer.testing import CliRunner - context_settings = """\ selected: acme_deployment contexts: diff --git a/tests_/commands/test_config.py b/tests_/commands/test_config.py index d700976ef3..a73fb4709d 100644 --- a/tests_/commands/test_config.py +++ b/tests_/commands/test_config.py @@ -10,7 +10,9 @@ def test_template(self, runner): def test_template_file(self, runner, tmp_path): template_file = (tmp_path / "template.yaml").absolute() - result = runner.invoke(config_command_group, ["template", "--file", str(template_file)]) + result = runner.invoke( + config_command_group, ["template", "--file", str(template_file)] + ) assert result.exit_code == 0 with open(template_file) as f: template_text = f.read() @@ -19,14 +21,14 @@ def test_template_file(self, runner, tmp_path): class TestUpload: - def test_upload(self, runner, config_file, mock_upload_blob): + def test_upload(self, runner, config_file, mock_upload_blob): # noqa: ARG002 result = runner.invoke( config_command_group, ["upload", str(config_file)], ) assert result.exit_code == 0 - def test_upload_no_file(self, runner, mock_upload_blob): + def test_upload_no_file(self, runner, 
mock_upload_blob): # noqa: ARG002 result = runner.invoke( config_command_group, ["upload"], @@ -35,10 +37,7 @@ def test_upload_no_file(self, runner, mock_upload_blob): class TestShow: - def test_show(self, runner, config_yaml, mock_download_blob): - result = runner.invoke( - config_command_group, - ["show"] - ) + def test_show(self, runner, config_yaml, mock_download_blob): # noqa: ARG002 + result = runner.invoke(config_command_group, ["show"]) assert result.exit_code == 0 assert config_yaml in result.stdout diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index ebb6ee9e5f..bfe102e2e6 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -296,12 +296,14 @@ def test_from_yaml(self, config_sres, context, config_yaml): config.sres["sre1"].software_packages, SoftwarePackageCategory ) - def test_from_remote(self, context, config_sres, mock_download_blob): + def test_from_remote( + self, context, config_sres, mock_download_blob # noqa: ARG002 + ): config = Config.from_remote(context) assert config == config_sres def test_to_yaml(self, config_sres, config_yaml): assert config_sres.to_yaml() == config_yaml - def test_upload(self, config_sres, mock_upload_blob): + def test_upload(self, config_sres, mock_upload_blob): # noqa: ARG002 config_sres.upload() diff --git a/tests_/conftest.py b/tests_/conftest.py index 808693eb7d..f97a7c2fcb 100644 --- a/tests_/conftest.py +++ b/tests_/conftest.py @@ -88,10 +88,10 @@ def mock_upload_blob(monkeypatch, context): def mock_upload_blob( self, # noqa: ARG001 blob_data: bytes | str, # noqa: ARG001 - blob_name: str, # noqa: ARG001 - resource_group_name: str, # noqa: ARG001 - storage_account_name: str, # noqa: ARG001 - storage_container_name: str, # noqa: ARG001 + blob_name: str, + resource_group_name: str, + storage_account_name: str, + storage_container_name: str, ): assert blob_name == context.config_filename assert resource_group_name == context.resource_group_name From 
3235983bce5bf80fb78664a80d1ec3e68de4d772 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Wed, 29 Nov 2023 15:46:22 +0000 Subject: [PATCH 40/65] Remove dependency on config for context --- data_safe_haven/commands/context.py | 17 ++++---- data_safe_haven/context/context.py | 43 ++++++++----------- data_safe_haven/functions/typer_validators.py | 1 + tests_/commands/test_context.py | 10 +---- tests_/functions/test_typer_validators.py | 4 +- tests_/functions/test_validators.py | 4 +- 6 files changed, 34 insertions(+), 45 deletions(-) diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index c18db89e38..7fa6eb142f 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ -4,7 +4,7 @@ import typer from rich import print -from data_safe_haven.config import Config, ContextSettings +from data_safe_haven.config import ContextSettings from data_safe_haven.config.context_settings import Context, default_config_file_path from data_safe_haven.context import Context as ContextInfra from data_safe_haven.functions.typer_validators import typer_validate_aad_guid @@ -149,7 +149,7 @@ def update( def remove( key: Annotated[str, typer.Argument(help="Name of the context to remove.")], ) -> None: - """Remove the selected context.""" + """Removes a context.""" settings = ContextSettings.from_file() settings.remove(key) settings.write() @@ -158,15 +158,14 @@ def remove( @context_command_group.command() def create() -> None: """Create Data Safe Haven context infrastructure.""" - config = Config() - context = ContextInfra(config) - context.create() - context.config.upload() + context = ContextSettings.from_file().context + context_infra = ContextInfra(context) + context_infra.create() @context_command_group.command() def teardown() -> None: """Tear down Data Safe Haven context infrastructure.""" - config = Config() - context = ContextInfra(config) - context.teardown() + context = ContextSettings.from_file().context + 
context_infra = ContextInfra(context) + context_infra.teardown() diff --git a/data_safe_haven/context/context.py b/data_safe_haven/context/context.py index ce92423892..e57fa0d5e0 100644 --- a/data_safe_haven/context/context.py +++ b/data_safe_haven/context/context.py @@ -1,4 +1,5 @@ -from data_safe_haven.config import Config +from data_safe_haven.config.config import ConfigSectionPulumi, ConfigSectionTags +from data_safe_haven.config.context_settings import Context from data_safe_haven.exceptions import DataSafeHavenAzureError from data_safe_haven.external import AzureApi @@ -6,10 +7,10 @@ class Context: """Azure resources to support Data Safe Haven context""" - def __init__(self, config: Config) -> None: + def __init__(self, context: Context) -> None: self.azure_api_: AzureApi | None = None - self.config = config - self.tags = {"component": "context"} | self.config.tags.model_dump() + self.context = context + self.tags = {"component": "context"} | ConfigSectionTags(context).model_dump() @property def azure_api(self) -> AzureApi: @@ -20,7 +21,7 @@ def azure_api(self) -> AzureApi: """ if not self.azure_api_: self.azure_api_ = AzureApi( - subscription_name=self.config.context.subscription_name, + subscription_name=self.context.subscription_name, ) return self.azure_api_ @@ -31,55 +32,51 @@ def create(self) -> None: DataSafeHavenAzureError if any resources cannot be created """ try: - self.config.azure.subscription_id = self.azure_api.subscription_id - self.config.azure.tenant_id = self.azure_api.tenant_id resource_group = self.azure_api.ensure_resource_group( - location=self.config.azure.location, - resource_group_name=self.config.context.resource_group_name, + location=self.context.location, + resource_group_name=self.context.resource_group_name, tags=self.tags, ) if not resource_group.name: - msg = f"Resource group '{self.config.context.resource_group_name}' was not created." + msg = f"Resource group '{self.context.resource_group_name}' was not created." 
raise DataSafeHavenAzureError(msg) identity = self.azure_api.ensure_managed_identity( - identity_name=self.config.context.managed_identity_name, + identity_name=self.context.managed_identity_name, location=resource_group.location, resource_group_name=resource_group.name, ) storage_account = self.azure_api.ensure_storage_account( location=resource_group.location, resource_group_name=resource_group.name, - storage_account_name=self.config.context.storage_account_name, + storage_account_name=self.context.storage_account_name, tags=self.tags, ) if not storage_account.name: - msg = f"Storage account '{self.config.context.storage_account_name}' was not created." + msg = f"Storage account '{self.context.storage_account_name}' was not created." raise DataSafeHavenAzureError(msg) _ = self.azure_api.ensure_storage_blob_container( - container_name=self.config.context.storage_container_name, + container_name=self.context.storage_container_name, resource_group_name=resource_group.name, storage_account_name=storage_account.name, ) _ = self.azure_api.ensure_storage_blob_container( - container_name=self.config.pulumi.storage_container_name, + container_name=self.pulumi.storage_container_name, resource_group_name=resource_group.name, storage_account_name=storage_account.name, ) keyvault = self.azure_api.ensure_keyvault( - admin_group_id=self.config.azure.admin_group_id, - key_vault_name=self.config.context.key_vault_name, + admin_group_id=self.context.admin_group_id, + key_vault_name=self.context.key_vault_name, location=resource_group.location, managed_identity=identity, resource_group_name=resource_group.name, tags=self.tags, ) if not keyvault.name: - msg = ( - f"Keyvault '{self.config.context.key_vault_name}' was not created." - ) + msg = f"Keyvault '{self.context.key_vault_name}' was not created." 
raise DataSafeHavenAzureError(msg) self.azure_api.ensure_keyvault_key( - key_name=self.config.pulumi.encryption_key_name, + key_name=ConfigSectionPulumi.encryption_key_name, key_vault_name=keyvault.name, ) except Exception as exc: @@ -93,9 +90,7 @@ def teardown(self) -> None: DataSafeHavenAzureError if any resources cannot be destroyed """ try: - self.azure_api.remove_resource_group( - self.config.context.resource_group_name - ) + self.azure_api.remove_resource_group(self.context.resource_group_name) except Exception as exc: msg = f"Failed to destroy context resources.\n{exc}" raise DataSafeHavenAzureError(msg) from exc diff --git a/data_safe_haven/functions/typer_validators.py b/data_safe_haven/functions/typer_validators.py index 91434740d9..dafd3b2fec 100644 --- a/data_safe_haven/functions/typer_validators.py +++ b/data_safe_haven/functions/typer_validators.py @@ -14,6 +14,7 @@ def typer_validator_factory(validator: Callable[[Any], Any]) -> Callable[[Any], Any]: """Factory to create validation functions for Typer from Pydantic validators""" + def typer_validator(x: Any) -> Any: # Return unused optional arguments if x is None: diff --git a/tests_/commands/test_context.py b/tests_/commands/test_context.py index 87a499f7b8..5af53b9350 100644 --- a/tests_/commands/test_context.py +++ b/tests_/commands/test_context.py @@ -1,5 +1,4 @@ from data_safe_haven.commands.context import context_command_group -from data_safe_haven.config import Config from data_safe_haven.context import Context @@ -163,24 +162,19 @@ def test_remove_invalid(self, runner): class TestCreate: def test_create(self, runner, monkeypatch): - def mock_create(): + def mock_create(self): # noqa: ARG001 print("mock create") # noqa: T201 - def mock_upload(): - print("mock upload") # noqa: T201 - monkeypatch.setattr(Context, "create", mock_create) - monkeypatch.setattr(Config, "upload", mock_upload) result = runner.invoke(context_command_group, ["create"]) assert "mock create" in result.stdout - assert "mock 
upload" in result.stdout assert result.exit_code == 0 class TestTeardown: def test_teardown(self, runner, monkeypatch): - def mock_teardown(): + def mock_teardown(self): # noqa: ARG001 print("mock teardown") # noqa: T201 monkeypatch.setattr(Context, "teardown", mock_teardown) diff --git a/tests_/functions/test_typer_validators.py b/tests_/functions/test_typer_validators.py index 72ae8b3f53..4cba74fc29 100644 --- a/tests_/functions/test_typer_validators.py +++ b/tests_/functions/test_typer_validators.py @@ -10,7 +10,7 @@ class TestTyperValidateAadGuid: [ "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", "10de18e7-b238-6f1e-a4ad-772708929203", - ] + ], ) def test_typer_validate_aad_guid(self, guid): assert typer_validate_aad_guid(guid) == guid @@ -20,7 +20,7 @@ def test_typer_validate_aad_guid(self, guid): [ "10de18e7_b238_6f1e_a4ad_772708929203", "not a guid", - ] + ], ) def test_typer_validate_aad_guid_fail(self, guid): with pytest.raises(BadParameter) as exc: diff --git a/tests_/functions/test_validators.py b/tests_/functions/test_validators.py index bcf41ca263..c8351fdcc7 100644 --- a/tests_/functions/test_validators.py +++ b/tests_/functions/test_validators.py @@ -9,7 +9,7 @@ class TestValidateAadGuid: [ "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", "10de18e7-b238-6f1e-a4ad-772708929203", - ] + ], ) def test_validate_aad_guid(self, guid): assert validate_aad_guid(guid) == guid @@ -19,7 +19,7 @@ def test_validate_aad_guid(self, guid): [ "10de18e7_b238_6f1e_a4ad_772708929203", "not a guid", - ] + ], ) def test_validate_aad_guid_fail(self, guid): with pytest.raises(ValueError) as exc: From db7855afc72e8ef720ab3347024f4a037ef03fa6 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Thu, 30 Nov 2023 11:54:28 +0000 Subject: [PATCH 41/65] Use new Config in deploy commands --- data_safe_haven/commands/deploy.py | 340 ++++++++++++++----------- data_safe_haven/commands/deploy_shm.py | 84 ------ data_safe_haven/commands/deploy_sre.py | 171 ------------- 3 files changed, 198 insertions(+), 397 
deletions(-) delete mode 100644 data_safe_haven/commands/deploy_shm.py delete mode 100644 data_safe_haven/commands/deploy_sre.py diff --git a/data_safe_haven/commands/deploy.py b/data_safe_haven/commands/deploy.py index 46ef50f7f4..5f2d0b12e2 100644 --- a/data_safe_haven/commands/deploy.py +++ b/data_safe_haven/commands/deploy.py @@ -3,64 +3,18 @@ import typer -from data_safe_haven.functions.typer_validators import ( - typer_validate_aad_guid, - typer_validate_azure_vm_sku, - typer_validate_email_address, - typer_validate_ip_address, - typer_validate_timezone, -) -from data_safe_haven.utility import DatabaseSystem, SoftwarePackageCategory - -from .deploy_shm import deploy_shm -from .deploy_sre import deploy_sre +from data_safe_haven.config import Config, ContextSettings +from data_safe_haven.exceptions import DataSafeHavenError +from data_safe_haven.external import GraphApi +from data_safe_haven.functions import alphanumeric, bcrypt_salt, password +from data_safe_haven.infrastructure import SHMStackManager, SREStackManager +from data_safe_haven.provisioning import SHMProvisioningManager, SREProvisioningManager deploy_command_group = typer.Typer() @deploy_command_group.command() def shm( - aad_tenant_id: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--aad-tenant-id", - "-a", - help=( - "The tenant ID for the AzureAD where users will be created," - " for example '10de18e7-b238-6f1e-a4ad-772708929203'." - ), - callback=typer_validate_aad_guid, - ), - ] = None, - admin_email_address: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--email", - "-e", - help="The email address where your system deployers and administrators can be contacted.", - callback=typer_validate_email_address, - ), - ] = None, - admin_ip_addresses: Annotated[ - Optional[list[str]], # noqa: UP007 - typer.Option( - "--ip-address", - "-i", - help=( - "An IP address or range used by your system deployers and administrators." 
- " [*may be specified several times*]" - ), - callback=lambda ips: [typer_validate_ip_address(ip) for ip in ips], - ), - ] = None, - domain: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--domain", - "-d", - help="The domain that SHM users will belong to.", - ), - ] = None, force: Annotated[ Optional[bool], # noqa: UP007 typer.Option( @@ -69,63 +23,71 @@ def shm( help="Force this operation, cancelling any others that are in progress.", ), ] = None, - timezone: Annotated[ - Optional[str], # noqa: UP007 - typer.Option( - "--timezone", - "-t", - help="The timezone that this Data Safe Haven deployment will use.", - callback=typer_validate_timezone, - ), - ] = None, ) -> None: """Deploy a Safe Haven Management component""" - deploy_shm( - aad_tenant_id=aad_tenant_id, - admin_email_address=admin_email_address, - admin_ip_addresses=admin_ip_addresses, - force=force, - fqdn=domain, - timezone=timezone, - ) + context = ContextSettings.from_file().context + config = Config.from_remote(context) + + try: + # Add the SHM domain to AzureAD as a custom domain + graph_api = GraphApi( + tenant_id=config.shm.aad_tenant_id, + default_scopes=[ + "Application.ReadWrite.All", + "Domain.ReadWrite.All", + "Group.ReadWrite.All", + ], + ) + verification_record = graph_api.add_custom_domain(config.shm.fqdn) + + # Initialise Pulumi stack + stack = SHMStackManager(config) + # Set Azure options + stack.add_option("azure-native:location", config.azure.location, replace=False) + stack.add_option( + "azure-native:subscriptionId", + config.azure.subscription_id, + replace=False, + ) + stack.add_option("azure-native:tenantId", config.azure.tenant_id, replace=False) + # Add necessary secrets + stack.add_secret("password-domain-ldap-searcher", password(20), replace=False) + stack.add_secret( + "verification-azuread-custom-domain", verification_record, replace=False + ) + + # Deploy Azure infrastructure with Pulumi + if force is None: + stack.deploy() + else: + stack.deploy(force=force) 
+ + # Add Pulumi infrastructure information to the config file + config.add_stack(stack.stack_name, stack.local_stack_path) + + # Upload config to blob storage + config.upload() + + # Add the SHM domain as a custom domain in AzureAD + graph_api.verify_custom_domain( + config.shm.fqdn, + stack.output("networking")["fqdn_nameservers"], + ) + + # Provision SHM with anything that could not be done in Pulumi + manager = SHMProvisioningManager( + subscription_name=config.context.subscription_name, + stack=stack, + ) + manager.run() + except DataSafeHavenError as exc: + msg = f"Could not deploy Data Safe Haven Management environment.\n{exc}" + raise DataSafeHavenError(msg) from exc @deploy_command_group.command() def sre( name: Annotated[str, typer.Argument(help="Name of SRE to deploy")], - allow_copy: Annotated[ - Optional[bool], # noqa: UP007 - typer.Option( - "--allow-copy", - "-c", - help="Whether to allow text to be copied out of the SRE.", - ), - ] = None, - allow_paste: Annotated[ - Optional[bool], # noqa: UP007 - typer.Option( - "--allow-paste", - "-p", - help="Whether to allow text to be pasted into the SRE.", - ), - ] = None, - data_provider_ip_addresses: Annotated[ - Optional[list[str]], # noqa: UP007 - typer.Option( - "--data-provider-ip-address", - "-d", - help="An IP address or range used by your data providers. 
[*may be specified several times*]", - callback=lambda vms: [typer_validate_ip_address(vm) for vm in vms], - ), - ] = None, - databases: Annotated[ - Optional[list[DatabaseSystem]], # noqa: UP007 - typer.Option( - "--database", - "-b", - help="Make a database of this system available to users of this SRE.", - ), - ] = None, force: Annotated[ Optional[bool], # noqa: UP007 typer.Option( @@ -134,45 +96,139 @@ def sre( help="Force this operation, cancelling any others that are in progress.", ), ] = None, - software_packages: Annotated[ - Optional[SoftwarePackageCategory], # noqa: UP007 - typer.Option( - "--software-packages", - "-s", - help="The category of package to allow users to install from enabled software repositories.", - ), - ] = None, - user_ip_addresses: Annotated[ - Optional[list[str]], # noqa: UP007 - typer.Option( - "--user-ip-address", - "-u", - help="An IP address or range used by your users. [*may be specified several times*]", - callback=lambda ips: [typer_validate_ip_address(ip) for ip in ips], - ), - ] = None, - workspace_skus: Annotated[ - Optional[list[str]], # noqa: UP007 - typer.Option( - "--workspace-sku", - "-w", - help=( - "A virtual machine SKU to make available to your users as a workspace." 
- " [*may be specified several times*]" - ), - callback=lambda ips: [typer_validate_azure_vm_sku(ip) for ip in ips], - ), - ] = None, ) -> None: """Deploy a Secure Research Environment""" - deploy_sre( - name, - allow_copy=allow_copy, - allow_paste=allow_paste, - data_provider_ip_addresses=data_provider_ip_addresses, - databases=databases, - force=force, - software_packages=software_packages, - user_ip_addresses=user_ip_addresses, - workspace_skus=workspace_skus, - ) + context = ContextSettings.from_file().context + config = Config.from_remote(context) + + try: + # Use a JSON-safe SRE name + sre_name = alphanumeric(name).lower() + + # Load GraphAPI as this may require user-interaction that is not possible as + # part of a Pulumi declarative command + graph_api = GraphApi( + tenant_id=config.shm.aad_tenant_id, + default_scopes=["Application.ReadWrite.All", "Group.ReadWrite.All"], + ) + + # Initialise Pulumi stack + shm_stack = SHMStackManager(config) + stack = SREStackManager(config, sre_name, graph_api_token=graph_api.token) + # Set Azure options + stack.add_option("azure-native:location", config.azure.location, replace=False) + stack.add_option( + "azure-native:subscriptionId", + config.azure.subscription_id, + replace=False, + ) + stack.add_option("azure-native:tenantId", config.azure.tenant_id, replace=False) + # Load SHM stack outputs + stack.add_option( + "shm-domain_controllers-domain_sid", + shm_stack.output("domain_controllers")["domain_sid"], + replace=True, + ) + stack.add_option( + "shm-domain_controllers-ldap_root_dn", + shm_stack.output("domain_controllers")["ldap_root_dn"], + replace=True, + ) + stack.add_option( + "shm-domain_controllers-ldap_server_ip", + shm_stack.output("domain_controllers")["ldap_server_ip"], + replace=True, + ) + stack.add_option( + "shm-domain_controllers-netbios_name", + shm_stack.output("domain_controllers")["netbios_name"], + replace=True, + ) + stack.add_option( + "shm-firewall-private-ip-address", + 
shm_stack.output("firewall")["private_ip_address"], + replace=True, + ) + stack.add_option( + "shm-monitoring-automation_account_name", + shm_stack.output("monitoring")["automation_account_name"], + replace=True, + ) + stack.add_option( + "shm-monitoring-log_analytics_workspace_id", + shm_stack.output("monitoring")["log_analytics_workspace_id"], + replace=True, + ) + stack.add_secret( + "shm-monitoring-log_analytics_workspace_key", + shm_stack.output("monitoring")["log_analytics_workspace_key"], + replace=True, + ) + stack.add_option( + "shm-monitoring-resource_group_name", + shm_stack.output("monitoring")["resource_group_name"], + replace=True, + ) + stack.add_option( + "shm-networking-private_dns_zone_base_id", + shm_stack.output("networking")["private_dns_zone_base_id"], + replace=True, + ) + stack.add_option( + "shm-networking-resource_group_name", + shm_stack.output("networking")["resource_group_name"], + replace=True, + ) + stack.add_option( + "shm-networking-subnet_identity_servers_prefix", + shm_stack.output("networking")["subnet_identity_servers_prefix"], + replace=True, + ) + stack.add_option( + "shm-networking-subnet_subnet_monitoring_prefix", + shm_stack.output("networking")["subnet_monitoring_prefix"], + replace=True, + ) + stack.add_option( + "shm-networking-subnet_update_servers_prefix", + shm_stack.output("networking")["subnet_update_servers_prefix"], + replace=True, + ) + stack.add_option( + "shm-networking-virtual_network_name", + shm_stack.output("networking")["virtual_network_name"], + replace=True, + ) + stack.add_option( + "shm-update_servers-ip_address_linux", + shm_stack.output("update_servers")["ip_address_linux"], + replace=True, + ) + # Add necessary secrets + stack.copy_secret("password-domain-ldap-searcher", shm_stack) + stack.add_secret("salt-dns-server-admin", bcrypt_salt(), replace=False) + + # Deploy Azure infrastructure with Pulumi + if force is None: + stack.deploy() + else: + stack.deploy(force=force) + + # Add Pulumi 
infrastructure information to the config file + config.add_stack(stack.stack_name, stack.local_stack_path) + + # Upload config to blob storage + config.upload() + + # Provision SRE with anything that could not be done in Pulumi + manager = SREProvisioningManager( + shm_stack=shm_stack, + sre_name=sre_name, + sre_stack=stack, + subscription_name=config.context.subscription_name, + timezone=config.shm.timezone, + ) + manager.run() + except DataSafeHavenError as exc: + msg = f"Could not deploy Secure Research Environment {sre_name}.\n{exc}" + raise DataSafeHavenError(msg) from exc diff --git a/data_safe_haven/commands/deploy_shm.py b/data_safe_haven/commands/deploy_shm.py deleted file mode 100644 index d6864f3511..0000000000 --- a/data_safe_haven/commands/deploy_shm.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Deploy a Safe Haven Management component""" -from data_safe_haven.config import Config -from data_safe_haven.exceptions import DataSafeHavenError -from data_safe_haven.external import GraphApi -from data_safe_haven.functions import password -from data_safe_haven.infrastructure import SHMStackManager -from data_safe_haven.provisioning import SHMProvisioningManager - - -def deploy_shm( - *, - aad_tenant_id: str | None = None, - admin_email_address: str | None = None, - admin_ip_addresses: list[str] | None = None, - force: bool | None = None, - fqdn: str | None = None, - timezone: str | None = None, -) -> None: - """Deploy a Safe Haven Management component""" - try: - # Load and validate config file - config = Config() - config.shm.update( - aad_tenant_id=aad_tenant_id, - admin_email_address=admin_email_address, - admin_ip_addresses=admin_ip_addresses, - fqdn=fqdn, - timezone=timezone, - ) - - # Add the SHM domain to AzureAD as a custom domain - graph_api = GraphApi( - tenant_id=config.shm.aad_tenant_id, - default_scopes=[ - "Application.ReadWrite.All", - "Domain.ReadWrite.All", - "Group.ReadWrite.All", - ], - ) - verification_record = 
graph_api.add_custom_domain(config.shm.fqdn) - - # Initialise Pulumi stack - stack = SHMStackManager(config) - # Set Azure options - stack.add_option("azure-native:location", config.azure.location, replace=False) - stack.add_option( - "azure-native:subscriptionId", - config.azure.subscription_id, - replace=False, - ) - stack.add_option("azure-native:tenantId", config.azure.tenant_id, replace=False) - # Add necessary secrets - stack.add_secret("password-domain-ldap-searcher", password(20), replace=False) - stack.add_secret( - "verification-azuread-custom-domain", verification_record, replace=False - ) - - # Deploy Azure infrastructure with Pulumi - if force is None: - stack.deploy() - else: - stack.deploy(force=force) - - # Add Pulumi infrastructure information to the config file - config.add_stack(stack.stack_name, stack.local_stack_path) - - # Upload config to blob storage - config.upload() - - # Add the SHM domain as a custom domain in AzureAD - graph_api.verify_custom_domain( - config.shm.fqdn, - stack.output("networking")["fqdn_nameservers"], - ) - - # Provision SHM with anything that could not be done in Pulumi - manager = SHMProvisioningManager( - subscription_name=config.context.subscription_name, - stack=stack, - ) - manager.run() - except DataSafeHavenError as exc: - msg = f"Could not deploy Data Safe Haven Management environment.\n{exc}" - raise DataSafeHavenError(msg) from exc diff --git a/data_safe_haven/commands/deploy_sre.py b/data_safe_haven/commands/deploy_sre.py deleted file mode 100644 index 5fe4b43bb8..0000000000 --- a/data_safe_haven/commands/deploy_sre.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Deploy a Secure Research Environment component""" -from data_safe_haven.config import Config -from data_safe_haven.exceptions import ( - DataSafeHavenError, -) -from data_safe_haven.external import GraphApi -from data_safe_haven.functions import alphanumeric, bcrypt_salt -from data_safe_haven.infrastructure import SHMStackManager, SREStackManager -from 
data_safe_haven.provisioning import SREProvisioningManager -from data_safe_haven.utility import DatabaseSystem, SoftwarePackageCategory - - -def deploy_sre( - name: str, - *, - allow_copy: bool | None = None, - allow_paste: bool | None = None, - data_provider_ip_addresses: list[str] | None = None, - databases: list[DatabaseSystem] | None = None, - force: bool | None = None, - workspace_skus: list[str] | None = None, - software_packages: SoftwarePackageCategory | None = None, - user_ip_addresses: list[str] | None = None, -) -> None: - """Deploy a Secure Research Environment component""" - sre_name = "UNKNOWN" - try: - # Use a JSON-safe SRE name - sre_name = alphanumeric(name).lower() - - # Load and validate config file - config = Config() - config.sre(sre_name).update( - data_provider_ip_addresses=data_provider_ip_addresses, - databases=databases, - workspace_skus=workspace_skus, - software_packages=software_packages, - user_ip_addresses=user_ip_addresses, - ) - config.sre(sre_name).remote_desktop.update( - allow_copy=allow_copy, - allow_paste=allow_paste, - ) - - # Load GraphAPI as this may require user-interaction that is not possible as - # part of a Pulumi declarative command - graph_api = GraphApi( - tenant_id=config.shm.aad_tenant_id, - default_scopes=["Application.ReadWrite.All", "Group.ReadWrite.All"], - ) - - # Initialise Pulumi stack - shm_stack = SHMStackManager(config) - stack = SREStackManager(config, sre_name, graph_api_token=graph_api.token) - # Set Azure options - stack.add_option("azure-native:location", config.azure.location, replace=False) - stack.add_option( - "azure-native:subscriptionId", - config.azure.subscription_id, - replace=False, - ) - stack.add_option("azure-native:tenantId", config.azure.tenant_id, replace=False) - # Load SHM stack outputs - stack.add_option( - "shm-domain_controllers-domain_sid", - shm_stack.output("domain_controllers")["domain_sid"], - replace=True, - ) - stack.add_option( - "shm-domain_controllers-ldap_root_dn", - 
shm_stack.output("domain_controllers")["ldap_root_dn"], - replace=True, - ) - stack.add_option( - "shm-domain_controllers-ldap_server_ip", - shm_stack.output("domain_controllers")["ldap_server_ip"], - replace=True, - ) - stack.add_option( - "shm-domain_controllers-netbios_name", - shm_stack.output("domain_controllers")["netbios_name"], - replace=True, - ) - stack.add_option( - "shm-firewall-private-ip-address", - shm_stack.output("firewall")["private_ip_address"], - replace=True, - ) - stack.add_option( - "shm-monitoring-automation_account_name", - shm_stack.output("monitoring")["automation_account_name"], - replace=True, - ) - stack.add_option( - "shm-monitoring-log_analytics_workspace_id", - shm_stack.output("monitoring")["log_analytics_workspace_id"], - replace=True, - ) - stack.add_secret( - "shm-monitoring-log_analytics_workspace_key", - shm_stack.output("monitoring")["log_analytics_workspace_key"], - replace=True, - ) - stack.add_option( - "shm-monitoring-resource_group_name", - shm_stack.output("monitoring")["resource_group_name"], - replace=True, - ) - stack.add_option( - "shm-networking-private_dns_zone_base_id", - shm_stack.output("networking")["private_dns_zone_base_id"], - replace=True, - ) - stack.add_option( - "shm-networking-resource_group_name", - shm_stack.output("networking")["resource_group_name"], - replace=True, - ) - stack.add_option( - "shm-networking-subnet_identity_servers_prefix", - shm_stack.output("networking")["subnet_identity_servers_prefix"], - replace=True, - ) - stack.add_option( - "shm-networking-subnet_subnet_monitoring_prefix", - shm_stack.output("networking")["subnet_monitoring_prefix"], - replace=True, - ) - stack.add_option( - "shm-networking-subnet_update_servers_prefix", - shm_stack.output("networking")["subnet_update_servers_prefix"], - replace=True, - ) - stack.add_option( - "shm-networking-virtual_network_name", - shm_stack.output("networking")["virtual_network_name"], - replace=True, - ) - stack.add_option( - 
"shm-update_servers-ip_address_linux", - shm_stack.output("update_servers")["ip_address_linux"], - replace=True, - ) - # Add necessary secrets - stack.copy_secret("password-domain-ldap-searcher", shm_stack) - stack.add_secret("salt-dns-server-admin", bcrypt_salt(), replace=False) - - # Deploy Azure infrastructure with Pulumi - if force is None: - stack.deploy() - else: - stack.deploy(force=force) - - # Add Pulumi infrastructure information to the config file - config.add_stack(stack.stack_name, stack.local_stack_path) - - # Upload config to blob storage - config.upload() - - # Provision SRE with anything that could not be done in Pulumi - manager = SREProvisioningManager( - shm_stack=shm_stack, - sre_name=sre_name, - sre_stack=stack, - subscription_name=config.context.subscription_name, - timezone=config.shm.timezone, - ) - manager.run() - except DataSafeHavenError as exc: - msg = f"Could not deploy Secure Research Environment {sre_name}.\n{exc}" - raise DataSafeHavenError(msg) from exc From 5fe21f83363a337ebab1744e84a7b0a9c7552eb4 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Thu, 30 Nov 2023 11:54:48 +0000 Subject: [PATCH 42/65] Update Pulumi version draft README --- data_safe_haven/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/data_safe_haven/README.md b/data_safe_haven/README.md index 03d55a48a8..13cadd9f02 100644 --- a/data_safe_haven/README.md +++ b/data_safe_haven/README.md @@ -14,6 +14,14 @@ Install the following requirements before starting > dsh context create ``` +- Create the configuration + +```console +> dsh config template --file config.yaml +> vim config.yaml +> dsh config upload config.yaml +``` + - Next deploy the Safe Haven Management (SHM) infrastructure [approx 30 minutes]: ```console From e6fb404f7b36eb67c910f54592855f1a3ab255a0 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 12 Dec 2023 10:07:49 +0000 Subject: [PATCH 43/65] Use new config in teardown commands --- data_safe_haven/commands/teardown.py | 66 
++++++++++++++++++++++-- data_safe_haven/commands/teardown_shm.py | 32 ------------ data_safe_haven/commands/teardown_sre.py | 49 ------------------ data_safe_haven/config/config.py | 3 +- 4 files changed, 64 insertions(+), 86 deletions(-) delete mode 100644 data_safe_haven/commands/teardown_shm.py delete mode 100644 data_safe_haven/commands/teardown_sre.py diff --git a/data_safe_haven/commands/teardown.py b/data_safe_haven/commands/teardown.py index 5e6e7d08ec..e77af4e8f6 100644 --- a/data_safe_haven/commands/teardown.py +++ b/data_safe_haven/commands/teardown.py @@ -3,8 +3,14 @@ import typer -from .teardown_shm import teardown_shm -from .teardown_sre import teardown_sre +from data_safe_haven.config import Config, ContextSettings +from data_safe_haven.exceptions import ( + DataSafeHavenError, + DataSafeHavenInputError, +) +from data_safe_haven.external import GraphApi +from data_safe_haven.functions import alphanumeric +from data_safe_haven.infrastructure import SHMStackManager, SREStackManager teardown_command_group = typer.Typer() @@ -13,7 +19,27 @@ help="Tear down a deployed a Safe Haven Management component." 
) def shm() -> None: - teardown_shm() + context = ContextSettings.from_file().context + config = Config.from_remote(context) + + try: + # Remove infrastructure deployed with Pulumi + try: + stack = SHMStackManager(config) + stack.teardown() + except Exception as exc: + msg = f"Unable to teardown Pulumi infrastructure.\n{exc}" + raise DataSafeHavenInputError(msg) from exc + + # Remove information from config file + if stack.stack_name in config.pulumi.stacks.keys(): + del config.pulumi.stacks[stack.stack_name] + + # Upload config to blob storage + config.upload() + except DataSafeHavenError as exc: + msg = f"Could not teardown Safe Haven Management component.\n{exc}" + raise DataSafeHavenError(msg) from exc @teardown_command_group.command( @@ -22,4 +48,36 @@ def shm() -> None: def sre( name: Annotated[str, typer.Argument(help="Name of SRE to teardown.")], ) -> None: - teardown_sre(name) + context = ContextSettings.from_file().context + config = Config.from_remote(context) + + sre_name = alphanumeric(name).lower() + try: + # Load GraphAPI as this may require user-interaction that is not possible as + # part of a Pulumi declarative command + graph_api = GraphApi( + tenant_id=config.shm.aad_tenant_id, + default_scopes=["Application.ReadWrite.All", "Group.ReadWrite.All"], + ) + + # Remove infrastructure deployed with Pulumi + try: + stack = SREStackManager(config, sre_name, graph_api_token=graph_api.token) + if stack.work_dir.exists(): + stack.teardown() + else: + msg = f"SRE {sre_name} not found - check the name is spelt correctly." 
+ raise DataSafeHavenInputError(msg) + except Exception as exc: + msg = f"Unable to teardown Pulumi infrastructure.\n{exc}" + raise DataSafeHavenInputError(msg) from exc + + # Remove information from config file + config.remove_stack(stack.stack_name) + config.remove_sre(sre_name) + + # Upload config to blob storage + config.upload() + except DataSafeHavenError as exc: + msg = f"Could not teardown Secure Research Environment '{sre_name}'.\n{exc}" + raise DataSafeHavenError(msg) from exc diff --git a/data_safe_haven/commands/teardown_shm.py b/data_safe_haven/commands/teardown_shm.py deleted file mode 100644 index c05bf81ca4..0000000000 --- a/data_safe_haven/commands/teardown_shm.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Teardown a deployed a Safe Haven Management component""" -from data_safe_haven.config import Config -from data_safe_haven.exceptions import ( - DataSafeHavenError, - DataSafeHavenInputError, -) -from data_safe_haven.infrastructure import SHMStackManager - - -def teardown_shm() -> None: - """Teardown a deployed a Safe Haven Management component""" - try: - # Load config file - config = Config() - - # Remove infrastructure deployed with Pulumi - try: - stack = SHMStackManager(config) - stack.teardown() - except Exception as exc: - msg = f"Unable to teardown Pulumi infrastructure.\n{exc}" - raise DataSafeHavenInputError(msg) from exc - - # Remove information from config file - if stack.stack_name in config.pulumi.stacks.keys(): - del config.pulumi.stacks[stack.stack_name] - - # Upload config to blob storage - config.upload() - except DataSafeHavenError as exc: - msg = f"Could not teardown Safe Haven Management component.\n{exc}" - raise DataSafeHavenError(msg) from exc diff --git a/data_safe_haven/commands/teardown_sre.py b/data_safe_haven/commands/teardown_sre.py deleted file mode 100644 index 6c51ca856a..0000000000 --- a/data_safe_haven/commands/teardown_sre.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Teardown a deployed Secure Research Environment""" -from 
data_safe_haven.config import Config -from data_safe_haven.exceptions import ( - DataSafeHavenError, - DataSafeHavenInputError, -) -from data_safe_haven.external import GraphApi -from data_safe_haven.functions import alphanumeric -from data_safe_haven.infrastructure import SREStackManager - - -def teardown_sre(name: str) -> None: - """Teardown a deployed Secure Research Environment""" - sre_name = "UNKNOWN" - try: - # Use a JSON-safe SRE name - sre_name = alphanumeric(name).lower() - - # Load config file - config = Config() - - # Load GraphAPI as this may require user-interaction that is not possible as - # part of a Pulumi declarative command - graph_api = GraphApi( - tenant_id=config.shm.aad_tenant_id, - default_scopes=["Application.ReadWrite.All", "Group.ReadWrite.All"], - ) - - # Remove infrastructure deployed with Pulumi - try: - stack = SREStackManager(config, sre_name, graph_api_token=graph_api.token) - if stack.work_dir.exists(): - stack.teardown() - else: - msg = f"SRE {sre_name} not found - check the name is spelt correctly." 
- raise DataSafeHavenInputError(msg) - except Exception as exc: - msg = f"Unable to teardown Pulumi infrastructure.\n{exc}" - raise DataSafeHavenInputError(msg) from exc - - # Remove information from config file - config.remove_stack(stack.stack_name) - config.remove_sre(sre_name) - - # Upload config to blob storage - config.upload() - except DataSafeHavenError as exc: - msg = f"Could not teardown Secure Research Environment '{sre_name}'.\n{exc}" - raise DataSafeHavenError(msg) from exc diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index d47cfa83da..4faf134abd 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -273,7 +273,8 @@ def pulumi_encryption_key(self) -> KeyVaultKey: @property def pulumi_encryption_key_version(self) -> str: - return self.pulumi_encryption_key.id.split("/")[-1] + key_id: str = self.pulumi_encryption_key.id + return key_id.split("/")[-1] def is_complete(self, *, require_sres: bool) -> bool: if require_sres: From 56ec25795fc80f7f1859823e73cea38881631193 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 12 Dec 2023 10:18:00 +0000 Subject: [PATCH 44/65] Rename context infrastructure class --- data_safe_haven/commands/context.py | 2 +- data_safe_haven/context/__init__.py | 4 ++-- data_safe_haven/context/context.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index 7fa6eb142f..10a49334bd 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ -6,7 +6,7 @@ from data_safe_haven.config import ContextSettings from data_safe_haven.config.context_settings import Context, default_config_file_path -from data_safe_haven.context import Context as ContextInfra +from data_safe_haven.context import ContextInfra from data_safe_haven.functions.typer_validators import typer_validate_aad_guid context_command_group = typer.Typer() diff --git 
a/data_safe_haven/context/__init__.py b/data_safe_haven/context/__init__.py index 94370d9ba1..3fb25abbf5 100644 --- a/data_safe_haven/context/__init__.py +++ b/data_safe_haven/context/__init__.py @@ -1,5 +1,5 @@ -from .context import Context +from .context import ContextInfra __all__ = [ - "Context", + "ContextInfra", ] diff --git a/data_safe_haven/context/context.py b/data_safe_haven/context/context.py index e57fa0d5e0..a7a798d3c3 100644 --- a/data_safe_haven/context/context.py +++ b/data_safe_haven/context/context.py @@ -4,7 +4,7 @@ from data_safe_haven.external import AzureApi -class Context: +class ContextInfra: """Azure resources to support Data Safe Haven context""" def __init__(self, context: Context) -> None: From 9fd210c8edd5b0018e742eb0dbbd3573cacaf0c6 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 12 Dec 2023 10:25:49 +0000 Subject: [PATCH 45/65] Correct references --- data_safe_haven/commands/config.py | 2 +- data_safe_haven/context/context.py | 2 +- data_safe_haven/infrastructure/stack_manager.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index 12c591d216..a7271f4123 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -18,7 +18,7 @@ def template( ] = None ) -> None: """Write a template Data Safe Haven configuration.""" - context = ContextSettings.from_file() + context = ContextSettings.from_file().context config = Config.template(context) if file: with open(file, "w") as outfile: diff --git a/data_safe_haven/context/context.py b/data_safe_haven/context/context.py index a7a798d3c3..22d8df3757 100644 --- a/data_safe_haven/context/context.py +++ b/data_safe_haven/context/context.py @@ -60,7 +60,7 @@ def create(self) -> None: storage_account_name=storage_account.name, ) _ = self.azure_api.ensure_storage_blob_container( - container_name=self.pulumi.storage_container_name, + 
container_name=self.context.storage_container_name, resource_group_name=resource_group.name, storage_account_name=storage_account.name, ) diff --git a/data_safe_haven/infrastructure/stack_manager.py b/data_safe_haven/infrastructure/stack_manager.py index 0d2ad96acb..1c4fbdafee 100644 --- a/data_safe_haven/infrastructure/stack_manager.py +++ b/data_safe_haven/infrastructure/stack_manager.py @@ -100,7 +100,7 @@ def stack(self) -> automation.Stack: stack_name=self.stack_name, program=self.program.run, opts=automation.LocalWorkspaceOptions( - secrets_provider=f"azurekeyvault://{self.cfg.context.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.encryption_key_version}", + secrets_provider=f"azurekeyvault://{self.cfg.context.key_vault_name}.vault.azure.net/keys/{self.cfg.pulumi.encryption_key_name}/{self.cfg.pulumi_encryption_key_version}", work_dir=str(self.work_dir), env_vars=self.account.env, ), From 829d7acaf784456601af0c50522e4b4e98884bd5 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 12 Dec 2023 10:33:56 +0000 Subject: [PATCH 46/65] Use new Config in admin commands --- data_safe_haven/commands/admin_add_users.py | 12 ++++++------ data_safe_haven/commands/admin_list_users.py | 12 ++++++------ data_safe_haven/commands/admin_register_users.py | 16 +++++++--------- data_safe_haven/commands/admin_remove_users.py | 12 ++++++------ .../commands/admin_unregister_users.py | 16 +++++++--------- 5 files changed, 32 insertions(+), 36 deletions(-) diff --git a/data_safe_haven/commands/admin_add_users.py b/data_safe_haven/commands/admin_add_users.py index c6ca9bb24e..acb5631365 100644 --- a/data_safe_haven/commands/admin_add_users.py +++ b/data_safe_haven/commands/admin_add_users.py @@ -2,19 +2,19 @@ import pathlib from data_safe_haven.administration.users import UserHandler -from data_safe_haven.config import Config +from data_safe_haven.config import Config, ContextSettings from data_safe_haven.exceptions import DataSafeHavenError from 
data_safe_haven.external import GraphApi def admin_add_users(csv_path: pathlib.Path) -> None: """Add users to a deployed Data Safe Haven""" - shm_name = "UNKNOWN" - try: - # Load config file - config = Config() - shm_name = config.context.name + context = ContextSettings.from_file().context + config = Config.from_remote(context) + + shm_name = context.shm_name + try: # Load GraphAPI as this may require user-interaction that is not # possible as part of a Pulumi declarative command graph_api = GraphApi( diff --git a/data_safe_haven/commands/admin_list_users.py b/data_safe_haven/commands/admin_list_users.py index 0e69eede16..1bf7e522d3 100644 --- a/data_safe_haven/commands/admin_list_users.py +++ b/data_safe_haven/commands/admin_list_users.py @@ -1,18 +1,18 @@ """List users from a deployed Data Safe Haven""" from data_safe_haven.administration.users import UserHandler -from data_safe_haven.config import Config +from data_safe_haven.config import Config, ContextSettings from data_safe_haven.exceptions import DataSafeHavenError from data_safe_haven.external import GraphApi def admin_list_users() -> None: """List users from a deployed Data Safe Haven""" - shm_name = "UNKNOWN" - try: - # Load config file - config = Config() - shm_name = config.context.name + context = ContextSettings.from_file().context + config = Config.from_remote(context) + + shm_name = context.shm_name + try: # Load GraphAPI as this may require user-interaction that is not # possible as part of a Pulumi declarative command graph_api = GraphApi( diff --git a/data_safe_haven/commands/admin_register_users.py b/data_safe_haven/commands/admin_register_users.py index af6b028dc0..fcc4e68a4c 100644 --- a/data_safe_haven/commands/admin_register_users.py +++ b/data_safe_haven/commands/admin_register_users.py @@ -1,6 +1,6 @@ """Register existing users with a deployed SRE""" from data_safe_haven.administration.users import UserHandler -from data_safe_haven.config import Config +from data_safe_haven.config import 
Config, ContextSettings from data_safe_haven.exceptions import DataSafeHavenError from data_safe_haven.external import GraphApi from data_safe_haven.functions import alphanumeric @@ -12,16 +12,14 @@ def admin_register_users( sre: str, ) -> None: """Register existing users with a deployed SRE""" - shm_name = "UNKNOWN" - sre_name = "UNKNOWN" - try: - # Use a JSON-safe SRE name - sre_name = alphanumeric(sre).lower() + context = ContextSettings.from_file().context + config = Config.from_remote(context) - # Load config file - config = Config() - shm_name = config.context.name + shm_name = context.shm_name + # Use a JSON-safe SRE name + sre_name = alphanumeric(sre).lower() + try: # Check that SRE option has been provided if not sre_name: msg = "SRE name must be specified." diff --git a/data_safe_haven/commands/admin_remove_users.py b/data_safe_haven/commands/admin_remove_users.py index 5f3a464c25..19d9f420ff 100644 --- a/data_safe_haven/commands/admin_remove_users.py +++ b/data_safe_haven/commands/admin_remove_users.py @@ -1,6 +1,6 @@ """Remove existing users from a deployed Data Safe Haven""" from data_safe_haven.administration.users import UserHandler -from data_safe_haven.config import Config +from data_safe_haven.config import Config, ContextSettings from data_safe_haven.exceptions import DataSafeHavenError from data_safe_haven.external import GraphApi @@ -9,12 +9,12 @@ def admin_remove_users( usernames: list[str], ) -> None: """Remove existing users from a deployed Data Safe Haven""" - shm_name = "UNKNOWN" - try: - # Load config file - config = Config() - shm_name = config.context.name + context = ContextSettings.from_file().context + config = Config.from_remote(context) + + shm_name = context.shm_name + try: # Load GraphAPI as this may require user-interaction that is not # possible as part of a Pulumi declarative command graph_api = GraphApi( diff --git a/data_safe_haven/commands/admin_unregister_users.py b/data_safe_haven/commands/admin_unregister_users.py index 
e2fbcb86b0..8c98dfb27c 100644 --- a/data_safe_haven/commands/admin_unregister_users.py +++ b/data_safe_haven/commands/admin_unregister_users.py @@ -1,6 +1,6 @@ """Unregister existing users from a deployed SRE""" from data_safe_haven.administration.users import UserHandler -from data_safe_haven.config import Config +from data_safe_haven.config import Config, ContextSettings from data_safe_haven.exceptions import DataSafeHavenError from data_safe_haven.external import GraphApi from data_safe_haven.functions import alphanumeric @@ -12,16 +12,14 @@ def admin_unregister_users( sre: str, ) -> None: """Unregister existing users from a deployed SRE""" - shm_name = "UNKNOWN" - sre_name = "UNKNOWN" - try: - # Use a JSON-safe SRE name - sre_name = alphanumeric(sre).lower() + context = ContextSettings.from_file().context + config = Config.from_remote(context) - # Load config file - config = Config() - shm_name = config.context.name + shm_name = context.shm_name + # Use a JSON-safe SRE name + sre_name = alphanumeric(sre).lower() + try: # Check that SRE option has been provided if not sre_name: msg = "SRE name must be specified." 
From 16e73ffd5868f7b70e828fa013dc5ba00b9c7dc7 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Thu, 14 Dec 2023 12:09:02 +0000 Subject: [PATCH 47/65] Add help text for config show command --- data_safe_haven/commands/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index a7271f4123..4fc7822528 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -41,6 +41,7 @@ def upload( @config_command_group.command() def show() -> None: + """Print the configuration for the selected Data Safe Haven context""" context = ContextSettings.from_file().context config = Config.from_remote(context) print(config.to_yaml()) From c9ceb39c9acf671263765e6427cd6fe95de0892b Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 14:54:51 +0000 Subject: [PATCH 48/65] Fix import --- data_safe_haven/config/__init__.py | 3 ++- tests_/commands/test_context.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/data_safe_haven/config/__init__.py b/data_safe_haven/config/__init__.py index 4723c0bf08..23e3c4641d 100644 --- a/data_safe_haven/config/__init__.py +++ b/data_safe_haven/config/__init__.py @@ -1,7 +1,8 @@ from .config import Config -from .context_settings import ContextSettings +from .context_settings import Context, ContextSettings __all__ = [ "Config", + "Context", "ContextSettings", ] diff --git a/tests_/commands/test_context.py b/tests_/commands/test_context.py index 5af53b9350..8fe3e23a7d 100644 --- a/tests_/commands/test_context.py +++ b/tests_/commands/test_context.py @@ -1,5 +1,5 @@ from data_safe_haven.commands.context import context_command_group -from data_safe_haven.context import Context +from data_safe_haven.context import ContextInfra class TestShow: @@ -165,7 +165,7 @@ def test_create(self, runner, monkeypatch): def mock_create(self): # noqa: ARG001 print("mock create") # noqa: T201 - monkeypatch.setattr(Context, "create", 
mock_create) + monkeypatch.setattr(ContextInfra, "create", mock_create) result = runner.invoke(context_command_group, ["create"]) assert "mock create" in result.stdout @@ -177,7 +177,7 @@ def test_teardown(self, runner, monkeypatch): def mock_teardown(self): # noqa: ARG001 print("mock teardown") # noqa: T201 - monkeypatch.setattr(Context, "teardown", mock_teardown) + monkeypatch.setattr(ContextInfra, "teardown", mock_teardown) result = runner.invoke(context_command_group, ["teardown"]) assert "mock teardown" in result.stdout From 9b91daf3a53247148bd58f061e09770e05f8d06c Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 14:35:24 +0000 Subject: [PATCH 49/65] Allow no selected context --- data_safe_haven/commands/admin_add_users.py | 2 +- data_safe_haven/commands/admin_list_users.py | 2 +- .../commands/admin_register_users.py | 2 +- .../commands/admin_remove_users.py | 2 +- .../commands/admin_unregister_users.py | 2 +- data_safe_haven/commands/config.py | 6 +-- data_safe_haven/commands/context.py | 20 +++++----- data_safe_haven/commands/deploy.py | 4 +- data_safe_haven/commands/teardown.py | 4 +- data_safe_haven/config/context_settings.py | 38 +++++++++++++------ 10 files changed, 50 insertions(+), 32 deletions(-) diff --git a/data_safe_haven/commands/admin_add_users.py b/data_safe_haven/commands/admin_add_users.py index acb5631365..bf8b7ab446 100644 --- a/data_safe_haven/commands/admin_add_users.py +++ b/data_safe_haven/commands/admin_add_users.py @@ -9,7 +9,7 @@ def admin_add_users(csv_path: pathlib.Path) -> None: """Add users to a deployed Data Safe Haven""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) shm_name = context.shm_name diff --git a/data_safe_haven/commands/admin_list_users.py b/data_safe_haven/commands/admin_list_users.py index 1bf7e522d3..25bf5f11ce 100644 --- a/data_safe_haven/commands/admin_list_users.py +++ 
b/data_safe_haven/commands/admin_list_users.py @@ -7,7 +7,7 @@ def admin_list_users() -> None: """List users from a deployed Data Safe Haven""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) shm_name = context.shm_name diff --git a/data_safe_haven/commands/admin_register_users.py b/data_safe_haven/commands/admin_register_users.py index fcc4e68a4c..18fca79b58 100644 --- a/data_safe_haven/commands/admin_register_users.py +++ b/data_safe_haven/commands/admin_register_users.py @@ -12,7 +12,7 @@ def admin_register_users( sre: str, ) -> None: """Register existing users with a deployed SRE""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) shm_name = context.shm_name diff --git a/data_safe_haven/commands/admin_remove_users.py b/data_safe_haven/commands/admin_remove_users.py index 19d9f420ff..8718e891dd 100644 --- a/data_safe_haven/commands/admin_remove_users.py +++ b/data_safe_haven/commands/admin_remove_users.py @@ -9,7 +9,7 @@ def admin_remove_users( usernames: list[str], ) -> None: """Remove existing users from a deployed Data Safe Haven""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) shm_name = context.shm_name diff --git a/data_safe_haven/commands/admin_unregister_users.py b/data_safe_haven/commands/admin_unregister_users.py index 8c98dfb27c..e7620949c7 100644 --- a/data_safe_haven/commands/admin_unregister_users.py +++ b/data_safe_haven/commands/admin_unregister_users.py @@ -12,7 +12,7 @@ def admin_unregister_users( sre: str, ) -> None: """Unregister existing users from a deployed SRE""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) shm_name = context.shm_name diff --git 
a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index 4fc7822528..9873cdf6ee 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -18,7 +18,7 @@ def template( ] = None ) -> None: """Write a template Data Safe Haven configuration.""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.template(context) if file: with open(file, "w") as outfile: @@ -32,7 +32,7 @@ def upload( file: Annotated[Path, typer.Argument(help="Path to configuration file")] ) -> None: """Upload a configuration to the Data Safe Haven context""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() with open(file) as config_file: config_yaml = config_file.read() config = Config.from_yaml(context, config_yaml) @@ -42,6 +42,6 @@ def upload( @config_command_group.command() def show() -> None: """Print the configuration for the selected Data Safe Haven context""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) print(config.to_yaml()) diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index 10a49334bd..0c345879cb 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ -18,13 +18,14 @@ def show() -> None: settings = ContextSettings.from_file() current_context_key = settings.selected - current_context = settings.context + current_context = settings.assert_context() print(f"Current context: [green]{current_context_key}") - print(f"\tName: {current_context.name}") - print(f"\tAdmin Group ID: {current_context.admin_group_id}") - print(f"\tSubscription name: {current_context.subscription_name}") - print(f"\tLocation: {current_context.location}") + if current_context is not None: + print(f"\tName: {current_context.name}") + print(f"\tAdmin Group ID: 
{current_context.admin_group_id}") + print(f"\tSubscription name: {current_context.subscription_name}") + print(f"\tLocation: {current_context.location}") @context_command_group.command() @@ -35,8 +36,9 @@ def available() -> None: current_context_key = settings.selected available = settings.available - available.remove(current_context_key) - available = [f"[green]{current_context_key}*[/]", *available] + if current_context_key is not None: + available.remove(current_context_key) + available = [f"[green]{current_context_key}*[/]", *available] print("\n".join(available)) @@ -158,7 +160,7 @@ def remove( @context_command_group.command() def create() -> None: """Create Data Safe Haven context infrastructure.""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() context_infra = ContextInfra(context) context_infra.create() @@ -166,6 +168,6 @@ def create() -> None: @context_command_group.command() def teardown() -> None: """Tear down Data Safe Haven context infrastructure.""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() context_infra = ContextInfra(context) context_infra.teardown() diff --git a/data_safe_haven/commands/deploy.py b/data_safe_haven/commands/deploy.py index 5f2d0b12e2..89b4c2d827 100644 --- a/data_safe_haven/commands/deploy.py +++ b/data_safe_haven/commands/deploy.py @@ -25,7 +25,7 @@ def shm( ] = None, ) -> None: """Deploy a Safe Haven Management component""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) try: @@ -98,7 +98,7 @@ def sre( ] = None, ) -> None: """Deploy a Secure Research Environment""" - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) try: diff --git a/data_safe_haven/commands/teardown.py b/data_safe_haven/commands/teardown.py index 
e77af4e8f6..33158357dd 100644 --- a/data_safe_haven/commands/teardown.py +++ b/data_safe_haven/commands/teardown.py @@ -19,7 +19,7 @@ help="Tear down a deployed a Safe Haven Management component." ) def shm() -> None: - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) try: @@ -48,7 +48,7 @@ def shm() -> None: def sre( name: Annotated[str, typer.Argument(help="Name of SRE to teardown.")], ) -> None: - context = ContextSettings.from_file().context + context = ContextSettings.from_file().assert_context() config = Config.from_remote(context) sre_name = alphanumeric(name).lower() diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index d37ca48613..1c48165dbc 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -81,24 +81,25 @@ class ContextSettings(BaseModel, validate_assignment=True): ... """ - selected_: str = Field(..., alias="selected") + selected_: str | None = Field(..., alias="selected") contexts: dict[str, Context] logger: ClassVar[LoggingSingleton] = LoggingSingleton() @model_validator(mode="after") def ensure_selected_is_valid(self) -> ContextSettings: - if self.selected not in self.available: - msg = f"Selected context '{self.selected}' is not defined." - raise ValueError(msg) + if self.selected is not None: + if self.selected not in self.available: + msg = f"Selected context '{self.selected}' is not defined." 
+ raise ValueError(msg) return self @property - def selected(self) -> str: - return str(self.selected_) + def selected(self) -> str | None: + return self.selected_ @selected.setter - def selected(self, context_name: str) -> None: - if context_name in self.available: + def selected(self, context_name: str | None) -> None: + if context_name in self.available or context_name is None: self.selected_ = context_name self.logger.info(f"Switched context to '{context_name}'.") else: @@ -106,8 +107,18 @@ def selected(self, context_name: str) -> None: raise DataSafeHavenParameterError(msg) @property - def context(self) -> Context: - return self.contexts[self.selected] + def context(self) -> Context | None: + if self.selected is None: + return None + else: + return self.contexts[self.selected] + + def assert_context(self) -> Context: + if context := self.context: + return context + else: + msg = "No context selected" + raise DataSafeHavenConfigError(msg) @property def available(self) -> list[str]: @@ -121,7 +132,7 @@ def update( name: str | None = None, subscription_name: str | None = None, ) -> None: - context = self.contexts[self.selected] + context = self.assert_context() if admin_group_id: self.logger.debug( @@ -165,8 +176,13 @@ def remove(self, key: str) -> None: if key not in self.available: msg = f"No context with key '{key}'." 
raise DataSafeHavenParameterError(msg) + del self.contexts[key] + # Prevent having a deleted context selected + if key == self.selected: + self.selected = None + @classmethod def from_yaml(cls, settings_yaml: str) -> ContextSettings: try: From 9cea20ece0fafdcd3118fc5c46c7eb0fced4c6fd Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 14:59:39 +0000 Subject: [PATCH 50/65] Add test for constructor with no selected context --- tests_/config/test_context_settings.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 41901c4c7e..9fb31cb9fd 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -96,6 +96,16 @@ def test_constructor(self): ) assert isinstance(settings, ContextSettings) + def test_null_selected(self, context_yaml): + context_yaml = context_yaml.replace("selected: acme_deployment", "selected: null") + + settings = ContextSettings.from_yaml(context_yaml) + assert settings.selected is None + assert settings.context is None + with pytest.raises(DataSafeHavenConfigError) as exc: + settings.assert_context() + assert "No context selected" in exc + def test_missing_selected(self, context_yaml): context_yaml = "\n".join(context_yaml.splitlines()[1:]) From c2f3de59c69fce3fd389034e53067cf8a3fd4b15 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 15:04:07 +0000 Subject: [PATCH 51/65] Add assert context tests --- tests_/config/test_context_settings.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 9fb31cb9fd..a558ba6131 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -169,6 +169,16 @@ def test_set_context(self, context_yaml, context_settings): for item in yaml_dict["contexts"]["gems"].keys() ) + def test_assert_context(self, context_settings): + context 
= context_settings.assert_context() + assert context.name == "Acme Deployment" + + def test_assert_context_none(self, context_settings): + context_settings.selected = None + with pytest.raises(DataSafeHavenConfigError) as exc: + context_settings.assert_context() + assert "No context selected" in exc + def test_available(self, context_settings): available = context_settings.available assert isinstance(available, list) From 4d291a115d4c9eb0bca7df40a7a520829e84f3fd Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 15:07:27 +0000 Subject: [PATCH 52/65] Add test for selecting no context --- tests_/config/test_context_settings.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index a558ba6131..ba7745a4c8 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -169,6 +169,11 @@ def test_set_context(self, context_yaml, context_settings): for item in yaml_dict["contexts"]["gems"].keys() ) + def test_set_context_none(self, context_settings): + context_settings.selected = None + assert context_settings.selected is None + assert context_settings.context is None + def test_assert_context(self, context_settings): context = context_settings.assert_context() assert context.name == "Acme Deployment" From 9321eb88e03fdbfc10b1bd67c9df5c76e2d437b1 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 15:13:39 +0000 Subject: [PATCH 53/65] Add update test for no selected context --- tests_/config/test_context_settings.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index ba7745a4c8..ffda53bf0b 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -201,6 +201,12 @@ def test_set_update(self, context_settings): context_settings.update(name="replaced") assert context_settings.context.name == "replaced" 
+ def test_update_none(self, context_settings): + context_settings.selected = None + with pytest.raises(DataSafeHavenConfigError) as exc: + context_settings.update(name="replaced") + assert "No context selected" in exc + def test_add(self, context_settings): context_settings.add( key="example", From 8576f50c5172adeb7dd495cce02d353ccd3656a1 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 15:16:03 +0000 Subject: [PATCH 54/65] Add test for removing selected context --- tests_/config/test_context_settings.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index ffda53bf0b..c0fc025d02 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -232,14 +232,20 @@ def test_invalid_add(self, context_settings): assert "A context with key 'acme' is already defined." in exc def test_remove(self, context_settings): - context_settings.remove("acme_deployment") - assert "acme_deployment" not in context_settings.available + context_settings.remove("gems") + assert "gems" not in context_settings.available + assert context_settings.selected == "acme_deployment" def test_invalid_remove(self, context_settings): with pytest.raises(DataSafeHavenParameterError) as exc: context_settings.remove("invalid") assert "No context with key 'invalid'." 
in exc + def test_remove_selected(self, context_settings): + context_settings.remove("acme_deployment") + assert "acme_deployment" not in context_settings.available + assert context_settings.selected is None + def test_from_file(self, tmp_path, context_yaml): config_file_path = tmp_path / "config.yaml" with open(config_file_path, "w") as f: From 3486091862c928cf99b48724655a922c24fd4f28 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 15:44:27 +0000 Subject: [PATCH 55/65] Allow showing no selected context --- data_safe_haven/commands/context.py | 2 +- tests_/commands/test_context.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/data_safe_haven/commands/context.py b/data_safe_haven/commands/context.py index 0c345879cb..5774670465 100644 --- a/data_safe_haven/commands/context.py +++ b/data_safe_haven/commands/context.py @@ -18,7 +18,7 @@ def show() -> None: settings = ContextSettings.from_file() current_context_key = settings.selected - current_context = settings.assert_context() + current_context = settings.context print(f"Current context: [green]{current_context_key}") if current_context is not None: diff --git a/tests_/commands/test_context.py b/tests_/commands/test_context.py index 8fe3e23a7d..66010b0164 100644 --- a/tests_/commands/test_context.py +++ b/tests_/commands/test_context.py @@ -9,6 +9,13 @@ def test_show(self, runner): assert "Current context: acme_deployment" in result.stdout assert "Name: Acme Deployment" in result.stdout + def test_show_none(self, runner): + result = runner.invoke(context_command_group, ["remove", "acme_deployment"]) + assert result.exit_code == 0 + result = runner.invoke(context_command_group, ["show"]) + assert result.exit_code == 0 + assert "Current context: None" in result.stdout + class TestAvailable: def test_available(self, runner): From 512f386518799377a7e07e024e5c0d2134c3e69c Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 15 Dec 2023 15:52:51 +0000 Subject: [PATCH 56/65] Add 
test for available with no selected context --- tests_/commands/conftest.py | 29 +++++++++++++++++++++++++++-- tests_/commands/test_context.py | 12 ++++++++---- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/tests_/commands/conftest.py b/tests_/commands/conftest.py index ba563e900d..9fe7e6f88c 100644 --- a/tests_/commands/conftest.py +++ b/tests_/commands/conftest.py @@ -1,7 +1,10 @@ from pytest import fixture from typer.testing import CliRunner -context_settings = """\ + +@fixture +def context_settings(): + return """\ selected: acme_deployment contexts: acme_deployment: @@ -17,7 +20,16 @@ @fixture -def tmp_contexts(tmp_path): +def tmp_contexts(tmp_path, context_settings): + config_file_path = tmp_path / "contexts.yaml" + with open(config_file_path, "w") as f: + f.write(context_settings) + return tmp_path + + +@fixture +def tmp_contexts_none(tmp_path, context_settings): + context_settings = context_settings.replace("selected: acme_deployment", "selected: null") config_file_path = tmp_path / "contexts.yaml" with open(config_file_path, "w") as f: f.write(context_settings) @@ -35,3 +47,16 @@ def runner(tmp_contexts): mix_stderr=False, ) return runner + + +@fixture +def runner_none(tmp_contexts_none): + runner = CliRunner( + env={ + "DSH_CONFIG_DIRECTORY": str(tmp_contexts_none), + "COLUMNS": "500", # Set large number of columns to avoid rich wrapping text + "TERM": "dumb", # Disable colours, style and interactive rich features + }, + mix_stderr=False, + ) + return runner diff --git a/tests_/commands/test_context.py b/tests_/commands/test_context.py index 66010b0164..919daddc9d 100644 --- a/tests_/commands/test_context.py +++ b/tests_/commands/test_context.py @@ -9,10 +9,8 @@ def test_show(self, runner): assert "Current context: acme_deployment" in result.stdout assert "Name: Acme Deployment" in result.stdout - def test_show_none(self, runner): - result = runner.invoke(context_command_group, ["remove", "acme_deployment"]) - assert result.exit_code == 0 
- result = runner.invoke(context_command_group, ["show"]) + def test_show_none(self, runner_none): + result = runner_none.invoke(context_command_group, ["show"]) assert result.exit_code == 0 assert "Current context: None" in result.stdout @@ -24,6 +22,12 @@ def test_available(self, runner): assert "acme_deployment*" in result.stdout assert "gems" in result.stdout + def test_available_none(self, runner_none): + result = runner_none.invoke(context_command_group, ["available"]) + assert result.exit_code == 0 + assert "acme_deployment" in result.stdout + assert "gems" in result.stdout + class TestSwitch: def test_switch(self, runner): From 430926b031718c1418d35b3e7b00db3672186608 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 15 Jan 2024 14:51:52 +0000 Subject: [PATCH 57/65] Add fqdn validation --- data_safe_haven/config/config.py | 3 ++- data_safe_haven/functions/validators.py | 9 +++++++ data_safe_haven/utility/annotated_types.py | 4 ++- pyproject.toml | 1 + tests_/functions/test_validators.py | 30 +++++++++++++++++++++- 5 files changed, 44 insertions(+), 3 deletions(-) diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 4faf134abd..864b37478f 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -36,6 +36,7 @@ AzureLocation, AzureVmSku, EmailAdress, + Fqdn, Guid, IpAddress, TimeZone, @@ -64,7 +65,7 @@ class ConfigSectionSHM(BaseModel, validate_assignment=True): aad_tenant_id: Guid admin_email_address: EmailAdress admin_ip_addresses: list[IpAddress] - fqdn: str + fqdn: Fqdn name: str = Field(..., exclude=True) timezone: TimeZone diff --git a/data_safe_haven/functions/validators.py b/data_safe_haven/functions/validators.py index a1526acd83..d1d8a67eae 100644 --- a/data_safe_haven/functions/validators.py +++ b/data_safe_haven/functions/validators.py @@ -1,6 +1,7 @@ import ipaddress import re +import fqdn import pytz @@ -28,6 +29,14 @@ def validate_azure_vm_sku(azure_vm_sku: str) -> str: return 
azure_vm_sku +def validate_fqdn(domain: str) -> str: + trial_fqdn = fqdn.FQDN(domain) + if not trial_fqdn.is_valid: + msg = "Expected valid fully qualified domain name, for example 'example.com'." + raise ValueError(msg) + return domain + + def validate_email_address(email_address: str) -> str: if not re.match(r"^\S+@\S+$", email_address): msg = "Expected valid email address, for example 'sherlock@holmes.com'." diff --git a/data_safe_haven/utility/annotated_types.py b/data_safe_haven/utility/annotated_types.py index 1db246dabf..68c1c124c6 100644 --- a/data_safe_haven/utility/annotated_types.py +++ b/data_safe_haven/utility/annotated_types.py @@ -3,11 +3,12 @@ from pydantic import Field from pydantic.functional_validators import AfterValidator -from data_safe_haven.functions import ( +from data_safe_haven.functions.validators import ( validate_aad_guid, validate_azure_location, validate_azure_vm_sku, validate_email_address, + validate_fqdn, validate_ip_address, validate_timezone, ) @@ -17,6 +18,7 @@ AzureLocation = Annotated[str, AfterValidator(validate_azure_location)] AzureVmSku = Annotated[str, AfterValidator(validate_azure_vm_sku)] EmailAdress = Annotated[str, AfterValidator(validate_email_address)] +Fqdn = Annotated[str, AfterValidator(validate_fqdn)] Guid = Annotated[str, AfterValidator(validate_aad_guid)] IpAddress = Annotated[str, AfterValidator(validate_ip_address)] TimeZone = Annotated[str, AfterValidator(validate_timezone)] diff --git a/pyproject.toml b/pyproject.toml index 01b90cc932..2246f75b88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "chili~=2.1", "cryptography~=41.0", "dnspython~=2.3", + "fqdn~=1.5", "msal~=1.21", "psycopg~=3.1", "pulumi~=3.80", diff --git a/tests_/functions/test_validators.py b/tests_/functions/test_validators.py index c8351fdcc7..2588f2d24a 100644 --- a/tests_/functions/test_validators.py +++ b/tests_/functions/test_validators.py @@ -1,6 +1,6 @@ import pytest -from 
data_safe_haven.functions.validators import validate_aad_guid +from data_safe_haven.functions.validators import validate_aad_guid, validate_fqdn class TestValidateAadGuid: @@ -25,3 +25,31 @@ def test_validate_aad_guid_fail(self, guid): with pytest.raises(ValueError) as exc: validate_aad_guid(guid) assert "Expected GUID" in exc + + +class TestValidateFqdn: + @pytest.mark.parametrize( + "fqdn", + [ + "shm.acme.com", + "example.com", + "a.b.c.com.", + "a-b-c.com", + ], + ) + def test_validate_fqdn(self, fqdn): + assert validate_fqdn(fqdn) == fqdn + + @pytest.mark.parametrize( + "fqdn", + [ + "invalid", + "%example.com", + "a b c.com", + "a_b_c.com", + ], + ) + def test_validate_fqdn_fail(self, fqdn): + with pytest.raises(ValueError) as exc: + validate_fqdn(fqdn) + assert "Expected valid fully qualified domain name" in exc From 7d6c50d07e7b6d1018c20db90113b4af27a1253b Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 15 Jan 2024 15:18:22 +0000 Subject: [PATCH 58/65] Fix invalid fqdn in tests --- tests_/config/test_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests_/config/test_config.py b/tests_/config/test_config.py index bfe102e2e6..fcb4b3ea9b 100644 --- a/tests_/config/test_config.py +++ b/tests_/config/test_config.py @@ -75,8 +75,8 @@ def test_constructor(self, context): def test_update(self, shm_config): assert shm_config.fqdn == "shm.acme.com" - shm_config.update(fqdn="modified") - assert shm_config.fqdn == "modified" + shm_config.update(fqdn="shm.example.com") + assert shm_config.fqdn == "shm.example.com" def test_update_validation(self, shm_config): with pytest.raises(ValidationError) as exc: From 3990fdc587e046d156659a09578b0a4dd21185fb Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 15 Jan 2024 15:21:04 +0000 Subject: [PATCH 59/65] Fix test failing due to env variable --- tests_/config/test_context_settings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests_/config/test_context_settings.py 
b/tests_/config/test_context_settings.py index 41901c4c7e..7732c6ff86 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -40,7 +40,8 @@ def test_invalid_subscription_name(self, context_dict): def test_shm_name(self, context): assert context.shm_name == "acmedeployment" - def test_work_directory(self, context): + def test_work_directory(self, context, monkeypatch): + monkeypatch.delenv("DSH_CONFIG_DIRECTORY", raising=False) assert "data_safe_haven/acmedeployment" in str(context.work_directory) def test_config_filename(self, context): From 1493dcbb87baae801520914f3fff0ffed2731039 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 15 Jan 2024 15:39:25 +0000 Subject: [PATCH 60/65] Run lint:fmt --- tests_/commands/conftest.py | 4 +++- tests_/config/test_context_settings.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests_/commands/conftest.py b/tests_/commands/conftest.py index 9fe7e6f88c..61e7b2fe26 100644 --- a/tests_/commands/conftest.py +++ b/tests_/commands/conftest.py @@ -29,7 +29,9 @@ def tmp_contexts(tmp_path, context_settings): @fixture def tmp_contexts_none(tmp_path, context_settings): - context_settings = context_settings.replace("selected: acme_deployment", "selected: null") + context_settings = context_settings.replace( + "selected: acme_deployment", "selected: null" + ) config_file_path = tmp_path / "contexts.yaml" with open(config_file_path, "w") as f: f.write(context_settings) diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index 4a6a17d74d..fe08df07bf 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -98,7 +98,9 @@ def test_constructor(self): assert isinstance(settings, ContextSettings) def test_null_selected(self, context_yaml): - context_yaml = context_yaml.replace("selected: acme_deployment", "selected: null") + context_yaml = context_yaml.replace( + "selected: acme_deployment", 
"selected: null" + ) settings = ContextSettings.from_yaml(context_yaml) assert settings.selected is None From 50d83a939ae0c5177147b8dd6333872d7915719c Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 15 Jan 2024 16:02:01 +0000 Subject: [PATCH 61/65] Print Ruff version in CI --- .github/workflows/lint_code.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/lint_code.yaml b/.github/workflows/lint_code.yaml index 84de63b302..79dbac6bdc 100644 --- a/.github/workflows/lint_code.yaml +++ b/.github/workflows/lint_code.yaml @@ -59,6 +59,8 @@ jobs: python-version: 3.11 - name: Install hatch run: pip install hatch + - name: Print Ruff version + run: ruff --version - name: Lint Python run: hatch run lint:all From 672d7df570cdceffdf082cf9f6f46b5615ce1e2d Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Mon, 15 Jan 2024 16:07:11 +0000 Subject: [PATCH 62/65] Fix linting errors --- .github/workflows/lint_code.yaml | 2 +- .../administration/users/active_directory_users.py | 4 ++-- data_safe_haven/administration/users/azure_ad_users.py | 4 ++-- data_safe_haven/administration/users/guacamole_users.py | 2 +- data_safe_haven/administration/users/user_handler.py | 4 ++-- pyproject.toml | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/lint_code.yaml b/.github/workflows/lint_code.yaml index 79dbac6bdc..56ca0d6b7a 100644 --- a/.github/workflows/lint_code.yaml +++ b/.github/workflows/lint_code.yaml @@ -60,7 +60,7 @@ jobs: - name: Install hatch run: pip install hatch - name: Print Ruff version - run: ruff --version + run: hatch run lint:ruff --version - name: Lint Python run: hatch run lint:all diff --git a/data_safe_haven/administration/users/active_directory_users.py b/data_safe_haven/administration/users/active_directory_users.py index 192cef3549..8a60fe9b40 100644 --- a/data_safe_haven/administration/users/active_directory_users.py +++ b/data_safe_haven/administration/users/active_directory_users.py @@ -77,7 +77,7 @@ def 
add(self, new_users: Sequence[ResearchUser]) -> None: for line in output.split("\n"): self.logger.parse(line) - def list(self, sre_name: str | None = None) -> Sequence[ResearchUser]: # noqa: A003 + def list(self, sre_name: str | None = None) -> Sequence[ResearchUser]: """List users in a local Active Directory""" list_users_script = FileReader( self.resources_path / "active_directory" / "list_users.ps1" @@ -142,7 +142,7 @@ def remove(self, users: Sequence[ResearchUser]) -> None: for line in output.split("\n"): self.logger.parse(line) - def set(self, users: Sequence[ResearchUser]) -> None: # noqa: A003 + def set(self, users: Sequence[ResearchUser]) -> None: """Set local Active Directory users to specified list""" users_to_remove = [user for user in self.list() if user not in users] self.remove(users_to_remove) diff --git a/data_safe_haven/administration/users/azure_ad_users.py b/data_safe_haven/administration/users/azure_ad_users.py index 15bf9a58c8..33141e9923 100644 --- a/data_safe_haven/administration/users/azure_ad_users.py +++ b/data_safe_haven/administration/users/azure_ad_users.py @@ -56,7 +56,7 @@ def add(self, new_users: Sequence[ResearchUser]) -> None: # # Also add the user to the research users group # self.graph_api.add_user_to_group(user.username, self.researchers_group_name) - def list(self) -> Sequence[ResearchUser]: # noqa: A003 + def list(self) -> Sequence[ResearchUser]: user_list = self.graph_api.read_users() return [ ResearchUser( @@ -105,7 +105,7 @@ def remove(self, users: Sequence[ResearchUser]) -> None: # ) pass - def set(self, users: Sequence[ResearchUser]) -> None: # noqa: A003 + def set(self, users: Sequence[ResearchUser]) -> None: """Set Guacamole users to specified list""" users_to_remove = [user for user in self.list() if user not in users] self.remove(users_to_remove) diff --git a/data_safe_haven/administration/users/guacamole_users.py b/data_safe_haven/administration/users/guacamole_users.py index 6b05a47c54..3df0f0a89f 100644 --- 
a/data_safe_haven/administration/users/guacamole_users.py +++ b/data_safe_haven/administration/users/guacamole_users.py @@ -30,7 +30,7 @@ def __init__(self, config: Config, sre_name: str, *args: Any, **kwargs: Any): self.sre_name = sre_name self.group_name = f"Data Safe Haven SRE {sre_name} Users" - def list(self) -> Sequence[ResearchUser]: # noqa: A003 + def list(self) -> Sequence[ResearchUser]: """List all Guacamole users""" if self.users_ is None: # Allow for the possibility of an empty list of users postgres_output = self.postgres_provisioner.execute_scripts( diff --git a/data_safe_haven/administration/users/user_handler.py b/data_safe_haven/administration/users/user_handler.py index e51238d160..561b799d36 100644 --- a/data_safe_haven/administration/users/user_handler.py +++ b/data_safe_haven/administration/users/user_handler.py @@ -97,7 +97,7 @@ def get_usernames_guacamole(self, sre_name: str) -> list[str]: self.logger.error(f"Could not load users for SRE '{sre_name}'.") return [] - def list(self) -> None: # noqa: A003 + def list(self) -> None: """List Active Directory, AzureAD and Guacamole users Raises: @@ -157,7 +157,7 @@ def remove(self, user_names: Sequence[str]) -> None: msg = f"Could not remove users: {user_names}.\n{exc}" raise DataSafeHavenUserHandlingError(msg) from exc - def set(self, users_csv_path: str) -> None: # noqa: A003 + def set(self, users_csv_path: str) -> None: """Set AzureAD and Guacamole users Raises: diff --git a/pyproject.toml b/pyproject.toml index 2246f75b88..eba6fee70f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ dependencies = [ "black>=23.1.0", "mypy>=1.0.0", "pydantic>=2.4", - "ruff>=0.0.243", + "ruff>=0.1.0", "types-appdirs>=1.4.3.5", "types-chevron>=0.14.2.5", "types-pytz>=2023.3.0.0", From 3b5d55221f9b384bcbeb96d00f1245a4941e0d07 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 16 Jan 2024 11:05:30 +0000 Subject: [PATCH 63/65] Add fqdn typings --- typings/fqdn/__init__.pyi | 7 +++++++ 1 file changed, 
7 insertions(+) create mode 100644 typings/fqdn/__init__.pyi diff --git a/typings/fqdn/__init__.pyi b/typings/fqdn/__init__.pyi new file mode 100644 index 0000000000..61cf4a40d7 --- /dev/null +++ b/typings/fqdn/__init__.pyi @@ -0,0 +1,7 @@ +from typing import Any + + +class FQDN: + def __init__(self, fqdn: Any, *nothing: list[Any], **kwargs: dict[Any, Any]) -> None: ... + @property + def is_valid(self) -> bool: ... From 2b9b268c7a33b17c27e82e2541f92f08d7cfb5d8 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Tue, 16 Jan 2024 11:21:04 +0000 Subject: [PATCH 64/65] Add typing hint --- data_safe_haven/external/api/azure_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_safe_haven/external/api/azure_api.py b/data_safe_haven/external/api/azure_api.py index 4024217912..fc36d9469b 100644 --- a/data_safe_haven/external/api/azure_api.py +++ b/data_safe_haven/external/api/azure_api.py @@ -712,7 +712,7 @@ def get_storage_account_keys( if not isinstance(storage_keys, StorageAccountListKeysResult): msg = f"Could not connect to {msg_sa} in {msg_rg}." raise DataSafeHavenAzureError(msg) - keys = storage_keys.keys + keys: list[StorageAccountKey] = storage_keys.keys if not keys or not isinstance(keys, list) or len(keys) == 0: msg = f"No keys were retrieved for {msg_sa} in {msg_rg}." raise DataSafeHavenAzureError(msg) From ed4ec999217b1435908cb10a66d8ede5073eb3b6 Mon Sep 17 00:00:00 2001 From: Jim Madge Date: Fri, 19 Jan 2024 09:00:36 +0000 Subject: [PATCH 65/65] Update data_safe_haven/README.md Co-authored-by: Matt Craddock --- data_safe_haven/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/data_safe_haven/README.md b/data_safe_haven/README.md index 13cadd9f02..07c414dcad 100644 --- a/data_safe_haven/README.md +++ b/data_safe_haven/README.md @@ -11,6 +11,7 @@ Install the following requirements before starting ```console > dsh context add ... +> dsh context switch ... > dsh context create ```