-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Introduce annotated types for validation
- Loading branch information
Showing
5 changed files
with
86 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,117 +1,50 @@ | ||
import ipaddress | ||
import re | ||
from collections.abc import Callable | ||
from typing import Any | ||
|
||
import pytz | ||
import typer | ||
|
||
|
||
def validate_aad_guid(aad_guid: str | None) -> str | None: | ||
if aad_guid is not None: | ||
if not re.match( | ||
r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$", | ||
aad_guid, | ||
): | ||
msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'." | ||
raise typer.BadParameter(msg) | ||
def validate_aad_guid(aad_guid: str) -> str: | ||
if not re.match( | ||
r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$", | ||
aad_guid, | ||
): | ||
msg = "Expected GUID, for example '10de18e7-b238-6f1e-a4ad-772708929203'." | ||
raise ValueError(msg) | ||
return aad_guid | ||
|
||
|
||
def validate_azure_location(azure_location: str | None) -> str | None: | ||
if azure_location is not None: | ||
if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location): | ||
msg = "Expected valid Azure location, for example 'uksouth'." | ||
raise typer.BadParameter(msg) | ||
def validate_azure_location(azure_location: str) -> str: | ||
if not re.match(r"^[a-z]+[0-9]?[a-z]*$", azure_location): | ||
msg = "Expected valid Azure location, for example 'uksouth'." | ||
raise ValueError(msg) | ||
return azure_location | ||
|
||
|
||
def validate_azure_vm_sku(azure_vm_sku: str | None) -> str | None: | ||
if azure_vm_sku is not None: | ||
if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku): | ||
msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'." | ||
raise typer.BadParameter(msg) | ||
def validate_azure_vm_sku(azure_vm_sku: str) -> str: | ||
if not re.match(r"^(Standard|Basic)_\w+$", azure_vm_sku): | ||
msg = "Expected valid Azure VM SKU, for example 'Standard_D2s_v4'." | ||
raise ValueError(msg) | ||
return azure_vm_sku | ||
|
||
|
||
def validate_email_address(email_address: str | None) -> str | None: | ||
if email_address is not None: | ||
if not re.match(r"^\S+@\S+$", email_address): | ||
msg = "Expected valid email address, for example '[email protected]'." | ||
raise typer.BadParameter(msg) | ||
def validate_email_address(email_address: str) -> str: | ||
if not re.match(r"^\S+@\S+$", email_address): | ||
msg = "Expected valid email address, for example '[email protected]'." | ||
raise ValueError(msg) | ||
return email_address | ||
|
||
|
||
def validate_ip_address( | ||
ip_address: str | None, | ||
) -> str | None: | ||
def validate_ip_address(ip_address: str) -> str: | ||
try: | ||
if ip_address: | ||
return str(ipaddress.ip_network(ip_address)) | ||
return None | ||
return str(ipaddress.ip_network(ip_address)) | ||
except Exception as exc: | ||
msg = "Expected valid IPv4 address, for example '1.1.1.1'." | ||
raise typer.BadParameter(msg) from exc | ||
raise ValueError(msg) from exc | ||
|
||
|
||
def validate_list( | ||
value: list[Any], | ||
validator: Callable[[Any], Any] | None = None, | ||
*, | ||
allow_empty: bool = False, | ||
) -> list[Any]: | ||
try: | ||
if not allow_empty: | ||
validate_non_empty_list(value) | ||
if validator: | ||
for element in value: | ||
validator(element) | ||
return value | ||
except Exception as exc: | ||
msg = f"Expected valid list.\n{exc}" | ||
raise typer.BadParameter(msg) from exc | ||
|
||
|
||
def validate_non_empty_list(value: list[Any]) -> list[Any]: | ||
if len(value) == 0: | ||
msg = "Expected non-empty list." | ||
raise typer.BadParameter(msg) | ||
return value | ||
|
||
|
||
def validate_non_empty_string(value: Any) -> str: | ||
try: | ||
return validate_string_length(value, min_length=1) | ||
except Exception as exc: | ||
msg = "Expected non-empty string." | ||
raise typer.BadParameter(msg) from exc | ||
|
||
|
||
def validate_string_length( | ||
value: Any, min_length: int | None = None, max_length: int | None = None | ||
) -> str: | ||
if isinstance(value, str): | ||
if min_length and len(value) < min_length: | ||
msg = f"Expected string with minimum length {min_length}." | ||
raise typer.BadParameter(msg) | ||
if max_length and len(value) > max_length: | ||
msg = f"Expected string with maximum length {max_length}." | ||
raise typer.BadParameter(msg) | ||
return str(value) | ||
msg = "Expected string." | ||
raise typer.BadParameter(msg) | ||
|
||
|
||
def validate_timezone(timezone: str | None) -> str | None: | ||
if timezone is not None: | ||
if timezone not in pytz.all_timezones: | ||
msg = "Expected valid timezone, for example 'Europe/London'." | ||
raise typer.BadParameter(msg) | ||
def validate_timezone(timezone: str) -> str: | ||
if timezone not in pytz.all_timezones: | ||
msg = "Expected valid timezone, for example 'Europe/London'." | ||
raise ValueError(msg) | ||
return timezone | ||
|
||
|
||
def validate_type(value: Any, type_: type) -> Any: | ||
if not isinstance(value, type_): | ||
msg = f"Expected type '{type_.__name__}' but received '{type(value).__name__}'." | ||
raise typer.BadParameter(msg) | ||
return value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import Annotated | ||
|
||
from pydantic import Field | ||
from pydantic.functional_validators import AfterValidator | ||
|
||
from data_safe_haven.functions import ( | ||
validate_aad_guid, | ||
validate_azure_location, | ||
validate_azure_vm_sku, | ||
validate_email_address, | ||
validate_ip_address, | ||
validate_timezone, | ||
) | ||
|
||
AzureShortName = Annotated[str, Field(min_length=1, max_length=24)] | ||
AzureLongName = Annotated[str, Field(min_length=1, max_length=64)] | ||
AzureLocation = Annotated[str, AfterValidator(validate_azure_location)] | ||
AzureVmSku = Annotated[str, AfterValidator(validate_azure_vm_sku)] | ||
EmailAdress = Annotated[str, AfterValidator(validate_email_address)] | ||
Guid = Annotated[str, AfterValidator(validate_aad_guid)] | ||
IpAddress = Annotated[str, AfterValidator(validate_ip_address)] | ||
TimeZone = Annotated[str, AfterValidator(validate_timezone)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters