Skip to content

Commit

Permalink
WIP: Introduce annotated types for validation
Browse files Browse the repository at this point in the history
  • Loading branch information
JimMadge committed Nov 15, 2023
1 parent 9093b97 commit d374b1c
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 113 deletions.
11 changes: 8 additions & 3 deletions data_safe_haven/config/context_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@
DataSafeHavenParameterError,
)
from data_safe_haven.utility import LoggingSingleton, config_dir
from data_safe_haven.utility.annotated_types import (
AzureLocation,
AzureLongName,
Guid,
)


def default_config_file_path() -> Path:
return config_dir() / "contexts.yaml"


class Context(BaseModel):
admin_group_id: str
location: str
admin_group_id: Guid
location: AzureLocation
name: str
subscription_name: str
subscription_name: AzureLongName


class ContextSettings(BaseModel):
Expand Down
10 changes: 0 additions & 10 deletions data_safe_haven/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@
validate_azure_vm_sku,
validate_email_address,
validate_ip_address,
validate_list,
validate_non_empty_list,
validate_non_empty_string,
validate_string_length,
validate_timezone,
validate_type,
)

__all__ = [
Expand All @@ -54,10 +49,5 @@
"validate_azure_vm_sku",
"validate_email_address",
"validate_ip_address",
"validate_list",
"validate_non_empty_list",
"validate_non_empty_string",
"validate_string_length",
"validate_timezone",
"validate_type",
]
119 changes: 26 additions & 93 deletions data_safe_haven/functions/validators.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,50 @@
import ipaddress
import re
from collections.abc import Callable
from typing import Any

import pytz
import typer


def validate_aad_guid(aad_guid: str | None) -> str | None:
if aad_guid is not None:
if not re.match(
r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$",
aad_guid,
):
msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'."
raise typer.BadParameter(msg)
def validate_aad_guid(aad_guid: str) -> str:
if not re.match(
r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$",
aad_guid,
):
msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'."
raise ValueError(msg)
return aad_guid


def validate_azure_location(azure_location: str | None) -> str | None:
if azure_location is not None:
if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location):
msg = "Expected valid Azure location, for example 'uksouth'."
raise typer.BadParameter(msg)
def validate_azure_location(azure_location: str) -> str:
if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location):
msg = "Expected valid Azure location, for example 'uksouth'."
raise ValueError(msg)
return azure_location


def validate_azure_vm_sku(azure_vm_sku: str | None) -> str | None:
if azure_vm_sku is not None:
if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku):
msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'."
raise typer.BadParameter(msg)
def validate_azure_vm_sku(azure_vm_sku: str) -> str:
if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku):
msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'."
raise ValueError(msg)
return azure_vm_sku


def validate_email_address(email_address: str | None) -> str | None:
if email_address is not None:
if not re.match(r"^\S+@\S+$", email_address):
msg = "Expected valid email address, for example '[email protected]'."
raise typer.BadParameter(msg)
def validate_email_address(email_address: str) -> str:
if not re.match(r"^\S+@\S+$", email_address):
msg = "Expected valid email address, for example '[email protected]'."
raise ValueError(msg)
return email_address


def validate_ip_address(
ip_address: str | None,
) -> str | None:
def validate_ip_address(ip_address: str) -> str:
try:
if ip_address:
return str(ipaddress.ip_network(ip_address))
return None
return str(ipaddress.ip_network(ip_address))
except Exception as exc:
msg = "Expected valid IPv4 address, for example '1.1.1.1'."
raise typer.BadParameter(msg) from exc
raise ValueError(msg) from exc


def validate_list(
value: list[Any],
validator: Callable[[Any], Any] | None = None,
*,
allow_empty: bool = False,
) -> list[Any]:
try:
if not allow_empty:
validate_non_empty_list(value)
if validator:
for element in value:
validator(element)
return value
except Exception as exc:
msg = f"Expected valid list.\n{exc}"
raise typer.BadParameter(msg) from exc


def validate_non_empty_list(value: list[Any]) -> list[Any]:
if len(value) == 0:
msg = "Expected non-empty list."
raise typer.BadParameter(msg)
return value


def validate_non_empty_string(value: Any) -> str:
try:
return validate_string_length(value, min_length=1)
except Exception as exc:
msg = "Expected non-empty string."
raise typer.BadParameter(msg) from exc


def validate_string_length(
value: Any, min_length: int | None = None, max_length: int | None = None
) -> str:
if isinstance(value, str):
if min_length and len(value) < min_length:
msg = f"Expected string with minimum length {min_length}."
raise typer.BadParameter(msg)
if max_length and len(value) > max_length:
msg = f"Expected string with maximum length {max_length}."
raise typer.BadParameter(msg)
return str(value)
msg = "Expected string."
raise typer.BadParameter(msg)


def validate_timezone(timezone: str | None) -> str | None:
if timezone is not None:
if timezone not in pytz.all_timezones:
msg = "Expected valid timezone, for example 'Europe/London'."
raise typer.BadParameter(msg)
def validate_timezone(timezone: str) -> str:
if timezone not in pytz.all_timezones:
msg = "Expected valid timezone, for example 'Europe/London'."
raise ValueError(msg)
return timezone


def validate_type(value: Any, type_: type) -> Any:
if not isinstance(value, type_):
msg = f"Expected type '{type_.__name__}' but received '{type(value).__name__}'."
raise typer.BadParameter(msg)
return value
22 changes: 22 additions & 0 deletions data_safe_haven/utility/annotated_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Annotated

from pydantic import Field
from pydantic.functional_validators import AfterValidator

from data_safe_haven.functions import (
validate_aad_guid,
validate_azure_location,
validate_azure_vm_sku,
validate_email_address,
validate_ip_address,
validate_timezone,
)

AzureShortName = Annotated[str, Field(min_length=1, max_length=24)]
AzureLongName = Annotated[str, Field(min_length=1, max_length=64)]
AzureLocation = Annotated[str, AfterValidator(validate_azure_location)]
AzureVmSku = Annotated[str, AfterValidator(validate_azure_vm_sku)]
EmailAdress = Annotated[str, AfterValidator(validate_email_address)]
Guid = Annotated[str, AfterValidator(validate_aad_guid)]
IpAddress = Annotated[str, AfterValidator(validate_ip_address)]
TimeZone = Annotated[str, AfterValidator(validate_timezone)]
37 changes: 30 additions & 7 deletions tests_/config/test_context_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,46 @@

import pytest
import yaml
from pydantic import ValidationError
from pytest import fixture


@fixture
def context_dict():
return {
"admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd",
"location": "uksouth",
"name": "Acme Deployment",
"subscription_name": "Data Safe Haven (Acme)"
}


class TestContext:
def test_constructor(self):
context_dict = {
"admin_group_id": "d5c5c439-1115-4cb6-ab50-b8e547b6c8dd",
"location": "uksouth",
"name": "Acme Deployment",
"subscription_name": "Data Safe Haven (Acme)"
}
def test_constructor(self, context_dict):
context = Context(**context_dict)
assert isinstance(context, Context)
assert all([
getattr(context, item) == context_dict[item] for item in context_dict.keys()
])

def test_invalid_guid(self, context_dict):
context_dict["admin_group_id"] = "not a guid"
with pytest.raises(ValidationError) as exc:
Context(**context_dict)
assert "Value error, Expected GUID, for example" in exc

def test_invalid_location(self, context_dict):
context_dict["location"] = "not_a_location"
with pytest.raises(ValidationError) as exc:
Context(**context_dict)
assert "Value error, Expected valid Azure location" in exc

def test_invalid_subscription_name(self, context_dict):
context_dict["subscription_name"] = "very "*12 + "long name"
with pytest.raises(ValidationError) as exc:
Context(**context_dict)
assert "String should have at most 64 characters" in exc


@fixture
def context_yaml():
Expand Down

0 comments on commit d374b1c

Please sign in to comment.