diff --git a/data_safe_haven/commands/config.py b/data_safe_haven/commands/config.py index a774868516..4d096f45ac 100644 --- a/data_safe_haven/commands/config.py +++ b/data_safe_haven/commands/config.py @@ -107,7 +107,10 @@ def available() -> None: @config_command_group.command() def show( - name: Annotated[str, typer.Argument(help="Name of SRE to show")], + name: Annotated[ + str, + typer.Argument(help="Name of SRE to show"), + ], file: Annotated[ Optional[Path], # noqa: UP007 typer.Option(help="File path to write configuration template to."), diff --git a/data_safe_haven/config/sre_config.py b/data_safe_haven/config/sre_config.py index 9fba89e12f..53adb673e0 100644 --- a/data_safe_haven/config/sre_config.py +++ b/data_safe_haven/config/sre_config.py @@ -4,9 +4,8 @@ from typing import ClassVar, Self -from data_safe_haven.functions import json_safe from data_safe_haven.serialisers import AzureSerialisableModel, ContextBase -from data_safe_haven.types import SafeString, SoftwarePackageCategory +from data_safe_haven.types import SafeSreName, SoftwarePackageCategory from .config_sections import ( ConfigSectionAzure, @@ -18,8 +17,8 @@ def sre_config_name(sre_name: str) -> str: - """Construct a safe YAML filename given an input SRE name.""" - return f"sre-{json_safe(sre_name)}.yaml" + """Construct a YAML filename given an input SRE name.""" + return f"sre-{sre_name}.yaml" class SREConfig(AzureSerialisableModel): @@ -31,7 +30,7 @@ class SREConfig(AzureSerialisableModel): azure: ConfigSectionAzure description: str dockerhub: ConfigSectionDockerHub - name: SafeString + name: SafeSreName sre: ConfigSectionSRE @property diff --git a/data_safe_haven/functions/__init__.py b/data_safe_haven/functions/__init__.py index e11b326135..4a83a76463 100644 --- a/data_safe_haven/functions/__init__.py +++ b/data_safe_haven/functions/__init__.py @@ -3,7 +3,6 @@ alphanumeric, b64encode, get_key_vault_name, - json_safe, next_occurrence, password, replace_separators, @@ -18,7 +17,6 @@ "current_ip_address", "get_key_vault_name", "ip_address_in_list", - "json_safe", "next_occurrence", "password", "replace_separators", diff --git a/data_safe_haven/functions/strings.py b/data_safe_haven/functions/strings.py index 0d5b06b33e..bf229c4f5e 100644 --- a/data_safe_haven/functions/strings.py +++ b/data_safe_haven/functions/strings.py @@ -27,11 +27,6 @@ def get_key_vault_name(stack_name: str) -> str: return f"{''.join(truncate_tokens(stack_name.split('-'), 17))}secrets" -def json_safe(input_string: str) -> str: - """Construct a JSON-safe version of an input string""" - return alphanumeric(input_string).lower() - - def next_occurrence( hour: int, minute: int, timezone: str, *, time_format: str = "iso" ) -> str: diff --git a/data_safe_haven/types/__init__.py b/data_safe_haven/types/__init__.py index e12c76bdc6..728df06c19 100644 --- a/data_safe_haven/types/__init__.py +++ b/data_safe_haven/types/__init__.py @@ -8,6 +8,7 @@ Fqdn, Guid, IpAddress, + SafeSreName, SafeString, TimeZone, UniqueList, @@ -52,6 +53,7 @@ "PathType", "PermittedDomains", "Ports", + "SafeSreName", "SafeString", "SoftwarePackageCategory", "TimeZone", diff --git a/data_safe_haven/types/annotated_types.py b/data_safe_haven/types/annotated_types.py index 639bf03129..d6258b0e7a 100644 --- a/data_safe_haven/types/annotated_types.py +++ b/data_safe_haven/types/annotated_types.py @@ -21,6 +21,7 @@ Fqdn = Annotated[str, AfterValidator(validators.fqdn)] Guid = Annotated[str, AfterValidator(validators.aad_guid)] IpAddress = Annotated[str, AfterValidator(validators.ip_address)] +SafeSreName = Annotated[str, AfterValidator(validators.safe_sre_name)] SafeString = Annotated[str, AfterValidator(validators.safe_string)] TimeZone = Annotated[str, AfterValidator(validators.timezone)] TH = TypeVar("TH", bound=Hashable) diff --git a/data_safe_haven/validators/__init__.py b/data_safe_haven/validators/__init__.py index 849b199857..30316e0834 100644 --- a/data_safe_haven/validators/__init__.py +++ b/data_safe_haven/validators/__init__.py @@ -6,6 +6,7 @@ typer_entra_group_name, typer_fqdn, typer_ip_address, + typer_safe_sre_name, typer_safe_string, typer_timezone, ) @@ -18,6 +19,7 @@ entra_group_name, fqdn, ip_address, + safe_sre_name, safe_string, timezone, unique_list, @@ -32,6 +34,7 @@ "entra_group_name", "fqdn", "ip_address", + "safe_sre_name", "safe_string", "timezone", "typer_aad_guid", @@ -41,6 +44,7 @@ "typer_entra_group_name", "typer_fqdn", "typer_ip_address", + "typer_safe_sre_name", "typer_safe_string", "typer_timezone", "unique_list", diff --git a/data_safe_haven/validators/typer.py b/data_safe_haven/validators/typer.py index f1c8239ecc..fd50774290 100644 --- a/data_safe_haven/validators/typer.py +++ b/data_safe_haven/validators/typer.py @@ -33,5 +33,6 @@ def typer_validator(x: Any) -> Any: typer_entra_group_name = typer_validator_factory(validators.entra_group_name) typer_fqdn = typer_validator_factory(validators.fqdn) typer_ip_address = typer_validator_factory(validators.ip_address) +typer_safe_sre_name = typer_validator_factory(validators.safe_sre_name) typer_safe_string = typer_validator_factory(validators.safe_string) typer_timezone = typer_validator_factory(validators.timezone) diff --git a/data_safe_haven/validators/validators.py b/data_safe_haven/validators/validators.py index dd4458ec57..c9cea495c9 100644 --- a/data_safe_haven/validators/validators.py +++ b/data_safe_haven/validators/validators.py @@ -129,12 +129,19 @@ def ip_address(ip_address: str) -> str: def safe_string(safe_string: str) -> str: - if not re.match(r"^[a-zA-Z0-9_-]*$", safe_string) or not safe_string: + if not re.match(r"^[a-zA-Z0-9_-]+$", safe_string) or not safe_string: msg = "Expected valid string containing only letters, numbers, hyphens and underscores." raise ValueError(msg) return safe_string +def safe_sre_name(safe_sre_name: str) -> str: + if not re.match(r"^[a-z0-9_-]+$", safe_sre_name) or not safe_sre_name: + msg = "Expected valid string containing only lowercase letters, numbers, hyphens and underscores." + raise ValueError(msg) + return safe_sre_name + + def timezone(timezone: str) -> str: if timezone not in pytz.all_timezones: msg = "Expected valid timezone, for example 'Europe/London'." diff --git a/tests/commands/test_config_sre.py b/tests/commands/test_config_sre.py index 7460a908eb..263def9236 100644 --- a/tests/commands/test_config_sre.py +++ b/tests/commands/test_config_sre.py @@ -167,7 +167,7 @@ class TestUploadSRE: def test_upload_new( self, mocker, context, runner, sre_config_yaml, sre_config_file ): - sre_name = "SandBox" + sre_name = "sandbox" sre_filename = sre_config_name(sre_name) mock_exists = mocker.patch.object( SREConfig, "remote_exists", return_value=False @@ -191,7 +191,7 @@ def test_upload_new( def test_upload_no_changes( self, mocker, context, runner, sre_config, sre_config_file ): - sre_name = "SandBox" + sre_name = "sandbox" sre_filename = sre_config_name(sre_name) mock_exists = mocker.patch.object(SREConfig, "remote_exists", return_value=True) mock_from_remote = mocker.patch.object( @@ -249,7 +249,7 @@ def test_upload_changes( def test_upload_changes_n( self, mocker, context, runner, sre_config_alternate, sre_config_file ): - sre_name = "SandBox" + sre_name = "sandbox" sre_filename = sre_config_name(sre_name) mock_exists = mocker.patch.object(SREConfig, "remote_exists", return_value=True) mock_from_remote = mocker.patch.object( @@ -287,7 +287,7 @@ def test_upload_file_does_not_exist(self, mocker, runner): def test_upload_invalid_config( self, mocker, runner, context, sre_config_file, sre_config_yaml ): - sre_name = "SandBox" + sre_name = "sandbox" sre_filename = sre_config_name(sre_name) mock_exists = mocker.patch.object(SREConfig, "remote_exists", return_value=True) @@ -310,7 +310,7 @@ def test_upload_invalid_config( def test_upload_invalid_config_force( self, mocker, runner, context, sre_config_file, sre_config_yaml ): - sre_name = "SandBox" + sre_name = "sandbox" sre_filename = sre_config_name(sre_name) mocker.patch.object( diff --git a/tests/config/test_sre_config.py b/tests/config/test_sre_config.py index 66bd50a40d..7ac6d61981 100644 --- a/tests/config/test_sre_config.py +++ b/tests/config/test_sre_config.py @@ -7,7 +7,6 @@ ConfigSectionDockerHub, ConfigSectionSRE, ) -from data_safe_haven.config.sre_config import sre_config_name from data_safe_haven.exceptions import ( DataSafeHavenTypeError, ) @@ -126,14 +125,5 @@ def test_upload(self, mocker, context, sre_config) -> None: context.storage_container_name, ) - -@pytest.mark.parametrize( - "value,expected", - [ - (r"Test SRE", "sre-testsre.yaml"), - (r"*a^b$c", "sre-abc.yaml"), - (r";'@-", "sre-.yaml"), - ], -) -def test_sre_config_name(value, expected): - assert sre_config_name(value) == expected + def test_sre_config_yaml_name(self, sre_config: SREConfig) -> None: + assert sre_config.filename == "sre-sandbox.yaml" diff --git a/tests/functions/test_strings.py b/tests/functions/test_strings.py index 3e57965d98..7bd12d9490 100644 --- a/tests/functions/test_strings.py +++ b/tests/functions/test_strings.py @@ -4,7 +4,6 @@ from data_safe_haven.exceptions import DataSafeHavenValueError from data_safe_haven.functions import ( get_key_vault_name, - json_safe, next_occurrence, ) @@ -70,11 +69,3 @@ def test_invalid_timeformat(self): ) def test_get_key_vault_name(value, expected): assert get_key_vault_name(value) == expected - - -@pytest.mark.parametrize( - "value,expected", - [(r"Test SRE", "testsre"), (r"%*aBc", "abc"), (r"MY_SRE", "mysre")], -) -def test_json_safe(value, expected): - assert json_safe(value) == expected