diff --git a/data_safe_haven/config/config.py b/data_safe_haven/config/config.py index 543f49812e..b8cf991b10 100644 --- a/data_safe_haven/config/config.py +++ b/data_safe_haven/config/config.py @@ -8,17 +8,13 @@ from azure.keyvault.keys import KeyVaultKey from pydantic import ( BaseModel, - ConfigDict, Field, - FieldSerializationInfo, ValidationError, - field_serializer, field_validator, ) from yaml import YAMLError from data_safe_haven import __version__ -from data_safe_haven.config.context_settings import Context from data_safe_haven.exceptions import ( DataSafeHavenConfigError, DataSafeHavenParameterError, @@ -43,6 +39,8 @@ TimeZone, ) +from .context_settings import Context + class ConfigSectionAzure(BaseModel, validate_assignment=True): admin_group_id: Guid = Field(..., exclude=True) @@ -150,7 +148,6 @@ def update( class ConfigSectionSRE(BaseModel, validate_assignment=True): - model_config = ConfigDict(use_enum_values=True) databases: list[DatabaseSystem] = Field(..., default_factory=list[DatabaseSystem]) data_provider_ip_addresses: list[IpAddress] = Field( ..., default_factory=list[IpAddress] @@ -168,22 +165,13 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True): @field_validator("databases") @classmethod def all_databases_must_be_unique( - cls, v: list[DatabaseSystem] + cls, v: list[DatabaseSystem | str] ) -> list[DatabaseSystem]: - if len(v) != len(set(v)): + v_ = [DatabaseSystem(d) for d in v] + if len(v_) != len(set(v_)): msg = "all databases must be unique" raise ValueError(msg) - return v - - @field_serializer("software_packages") - def software_packages_serializer( - self, - packages: SoftwarePackageCategory | str, - info: FieldSerializationInfo, # noqa: ARG002 - ) -> str: - if isinstance(packages, str): - packages = SoftwarePackageCategory(packages) - return packages.value + return v_ def update( self, @@ -389,7 +377,7 @@ def from_yaml(cls, context: Context, config_yaml: str) -> Config: config_dict[section]["context"] = context try: - return Config.model_validate(config_dict) + return Config.model_validate(config_dict, strict=True) except ValidationError as exc: msg = f"Could not load configuration.\n{exc}" raise DataSafeHavenParameterError(msg) from exc @@ -406,7 +394,7 @@ def from_remote(cls, context: Context) -> Config: return Config.from_yaml(context, config_yaml) def to_yaml(self) -> str: - return yaml.dump(self.model_dump(), indent=2) + return yaml.dump(self.model_dump(mode="json"), indent=2) def upload(self) -> None: """Upload config to Azure storage"""