diff --git a/data_safe_haven/config/context_settings.py b/data_safe_haven/config/context_settings.py index 795bfa1d72..10dab336c0 100644 --- a/data_safe_haven/config/context_settings.py +++ b/data_safe_haven/config/context_settings.py @@ -16,6 +16,11 @@ DataSafeHavenParameterError, ) from data_safe_haven.utility import LoggingSingleton, config_dir +from data_safe_haven.utility.annotated_types import ( + AzureLocation, + AzureLongName, + Guid, +) def default_config_file_path() -> Path: @@ -23,10 +28,10 @@ def default_config_file_path() -> Path: class Context(BaseModel): - admin_group_id: str - location: str + admin_group_id: Guid + location: AzureLocation name: str - subscription_name: str + subscription_name: AzureLongName class ContextSettings(BaseModel): diff --git a/data_safe_haven/functions/__init__.py b/data_safe_haven/functions/__init__.py index fcdb02ce72..6b444728ff 100644 --- a/data_safe_haven/functions/__init__.py +++ b/data_safe_haven/functions/__init__.py @@ -24,12 +24,7 @@ validate_azure_vm_sku, validate_email_address, validate_ip_address, - validate_list, - validate_non_empty_list, - validate_non_empty_string, - validate_string_length, validate_timezone, - validate_type, ) __all__ = [ @@ -54,10 +49,5 @@ "validate_azure_vm_sku", "validate_email_address", "validate_ip_address", - "validate_list", - "validate_non_empty_list", - "validate_non_empty_string", - "validate_string_length", "validate_timezone", - "validate_type", ] diff --git a/data_safe_haven/functions/validators.py b/data_safe_haven/functions/validators.py index da41d3db24..a1526acd83 100644 --- a/data_safe_haven/functions/validators.py +++ b/data_safe_haven/functions/validators.py @@ -1,117 +1,50 @@ import ipaddress import re -from collections.abc import Callable -from typing import Any import pytz -import typer -def validate_aad_guid(aad_guid: str | None) -> str | None: - if aad_guid is not None: - if not re.match( - r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$", - aad_guid, - ): - msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'." - raise typer.BadParameter(msg) +def validate_aad_guid(aad_guid: str) -> str: + if not re.match( + r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$", + aad_guid, + ): + msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'." + raise ValueError(msg) return aad_guid -def validate_azure_location(azure_location: str | None) -> str | None: - if azure_location is not None: - if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location): - msg = "Expected valid Azure location, for example 'uksouth'." - raise typer.BadParameter(msg) +def validate_azure_location(azure_location: str) -> str: + if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location): + msg = "Expected valid Azure location, for example 'uksouth'." + raise ValueError(msg) return azure_location -def validate_azure_vm_sku(azure_vm_sku: str | None) -> str | None: - if azure_vm_sku is not None: - if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku): - msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'." - raise typer.BadParameter(msg) +def validate_azure_vm_sku(azure_vm_sku: str) -> str: + if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku): + msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'." + raise ValueError(msg) return azure_vm_sku -def validate_email_address(email_address: str | None) -> str | None: - if email_address is not None: - if not re.match(r"^\S+@\S+$", email_address): - msg = "Expected valid email address, for example 'sherlock@holmes.com'." - raise typer.BadParameter(msg) +def validate_email_address(email_address: str) -> str: + if not re.match(r"^\S+@\S+$", email_address): + msg = "Expected valid email address, for example 'sherlock@holmes.com'." + raise ValueError(msg) return email_address -def validate_ip_address( - ip_address: str | None, -) -> str | None: +def validate_ip_address(ip_address: str) -> str: try: - if ip_address: - return str(ipaddress.ip_network(ip_address)) - return None + return str(ipaddress.ip_network(ip_address)) except Exception as exc: msg = "Expected valid IPv4 address, for example '1.1.1.1'." - raise typer.BadParameter(msg) from exc + raise ValueError(msg) from exc -def validate_list( - value: list[Any], - validator: Callable[[Any], Any] | None = None, - *, - allow_empty: bool = False, -) -> list[Any]: - try: - if not allow_empty: - validate_non_empty_list(value) - if validator: - for element in value: - validator(element) - return value - except Exception as exc: - msg = f"Expected valid list.\n{exc}" - raise typer.BadParameter(msg) from exc - - -def validate_non_empty_list(value: list[Any]) -> list[Any]: - if len(value) == 0: - msg = "Expected non-empty list." - raise typer.BadParameter(msg) - return value - - -def validate_non_empty_string(value: Any) -> str: - try: - return validate_string_length(value, min_length=1) - except Exception as exc: - msg = "Expected non-empty string." - raise typer.BadParameter(msg) from exc - - -def validate_string_length( - value: Any, min_length: int | None = None, max_length: int | None = None -) -> str: - if isinstance(value, str): - if min_length and len(value) < min_length: - msg = f"Expected string with minimum length {min_length}." - raise typer.BadParameter(msg) - if max_length and len(value) > max_length: - msg = f"Expected string with maximum length {max_length}." - raise typer.BadParameter(msg) - return str(value) - msg = "Expected string." - raise typer.BadParameter(msg) - - -def validate_timezone(timezone: str | None) -> str | None: - if timezone is not None: - if timezone not in pytz.all_timezones: - msg = "Expected valid timezone, for example 'Europe/London'." - raise typer.BadParameter(msg) +def validate_timezone(timezone: str) -> str: + if timezone not in pytz.all_timezones: + msg = "Expected valid timezone, for example 'Europe/London'." + raise ValueError(msg) return timezone - - -def validate_type(value: Any, type_: type) -> Any: - if not isinstance(value, type_): - msg = f"Expected type '{type_.__name__}' but received '{type(value).__name__}'." - raise typer.BadParameter(msg) - return value diff --git a/data_safe_haven/utility/annotated_types.py b/data_safe_haven/utility/annotated_types.py new file mode 100644 index 0000000000..1db246dabf --- /dev/null +++ b/data_safe_haven/utility/annotated_types.py @@ -0,0 +1,22 @@ +from typing import Annotated + +from pydantic import Field +from pydantic.functional_validators import AfterValidator + +from data_safe_haven.functions import ( + validate_aad_guid, + validate_azure_location, + validate_azure_vm_sku, + validate_email_address, + validate_ip_address, + validate_timezone, +) + +AzureShortName = Annotated[str, Field(min_length=1, max_length=24)] +AzureLongName = Annotated[str, Field(min_length=1, max_length=64)] +AzureLocation = Annotated[str, AfterValidator(validate_azure_location)] +AzureVmSku = Annotated[str, AfterValidator(validate_azure_vm_sku)] +EmailAdress = Annotated[str, AfterValidator(validate_email_address)] +Guid = Annotated[str, AfterValidator(validate_aad_guid)] +IpAddress = Annotated[str, AfterValidator(validate_ip_address)] +TimeZone = Annotated[str, AfterValidator(validate_timezone)] diff --git a/tests_/config/test_context_settings.py b/tests_/config/test_context_settings.py index e1cc3cfac5..8f04ed3fc1 100644 --- a/tests_/config/test_context_settings.py +++ b/tests_/config/test_context_settings.py @@ -3,23 +3,46 @@ import pytest import yaml +from pydantic import ValidationError from pytest import fixture +@fixture +def context_dict(): + return { + "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", + "location": "uksouth", + "name": "Acme Deployment", + "subscription_name": "Data Safe Haven (Acme)" + } + + class TestContext: - def test_constructor(self): - context_dict = { - "admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd", - "location": "uksouth", - "name": "Acme Deployment", - "subscription_name": "Data Safe Haven (Acme)" - } + def test_constructor(self, context_dict): context = Context(**context_dict) assert isinstance(context, Context) assert all([ getattr(context, item) == context_dict[item] for item in context_dict.keys() ]) + def test_invalid_guid(self, context_dict): + context_dict["admin_group_id"] = "not a guid" + with pytest.raises(ValidationError) as exc: + Context(**context_dict) + assert "Value error, Expected GUID, for example" in exc + + def test_invalid_location(self, context_dict): + context_dict["location"] = "not_a_location" + with pytest.raises(ValidationError) as exc: + Context(**context_dict) + assert "Value error, Expected valid Azure location" in exc + + def test_invalid_subscription_name(self, context_dict): + context_dict["subscription_name"] = "very "*12 + "long name" + with pytest.raises(ValidationError) as exc: + Context(**context_dict) + assert "String should have at most 64 characters" in exc + @fixture def context_yaml():