Skip to content

Commit

Permalink
🐛 Fix enum serialisation by using model_dump('json') and dropping exp…
Browse files Browse the repository at this point in the history
…licit field serialisation
  • Loading branch information
jemrobinson committed Jan 25, 2024
1 parent f7cc4f6 commit 1566f2d
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions data_safe_haven/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@
from azure.keyvault.keys import KeyVaultKey
from pydantic import (
BaseModel,
ConfigDict,
Field,
FieldSerializationInfo,
ValidationError,
field_serializer,
field_validator,
)
from yaml import YAMLError

from data_safe_haven import __version__
from data_safe_haven.config.context_settings import Context
from data_safe_haven.exceptions import (
DataSafeHavenConfigError,
DataSafeHavenParameterError,
Expand All @@ -43,6 +39,8 @@
TimeZone,
)

from .context_settings import Context


class ConfigSectionAzure(BaseModel, validate_assignment=True):
admin_group_id: Guid = Field(..., exclude=True)
Expand Down Expand Up @@ -150,7 +148,6 @@ def update(


class ConfigSectionSRE(BaseModel, validate_assignment=True):
model_config = ConfigDict(use_enum_values=True)
databases: list[DatabaseSystem] = Field(..., default_factory=list[DatabaseSystem])
data_provider_ip_addresses: list[IpAddress] = Field(
..., default_factory=list[IpAddress]
Expand All @@ -168,22 +165,13 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
@field_validator("databases")
@classmethod
def all_databases_must_be_unique(
cls, v: list[DatabaseSystem]
cls, v: list[DatabaseSystem | str]
) -> list[DatabaseSystem]:
if len(v) != len(set(v)):
v_ = [DatabaseSystem(d) for d in v]
if len(v_) != len(set(v_)):
msg = "all databases must be unique"
raise ValueError(msg)
return v

@field_serializer("software_packages")
def software_packages_serializer(
self,
packages: SoftwarePackageCategory | str,
info: FieldSerializationInfo, # noqa: ARG002
) -> str:
if isinstance(packages, str):
packages = SoftwarePackageCategory(packages)
return packages.value
return v_

def update(
self,
Expand Down Expand Up @@ -389,7 +377,7 @@ def from_yaml(cls, context: Context, config_yaml: str) -> Config:
config_dict[section]["context"] = context

try:
return Config.model_validate(config_dict)
return Config.model_validate(config_dict, strict=True)
except ValidationError as exc:
msg = f"Could not load configuration.\n{exc}"
raise DataSafeHavenParameterError(msg) from exc
Expand All @@ -406,7 +394,7 @@ def from_remote(cls, context: Context) -> Config:
return Config.from_yaml(context, config_yaml)

def to_yaml(self) -> str:
return yaml.dump(self.model_dump(), indent=2)
return yaml.dump(self.model_dump(mode="json"), indent=2)

def upload(self) -> None:
"""Upload config to Azure storage"""
Expand Down

0 comments on commit 1566f2d

Please sign in to comment.