Skip to content

Commit

Permalink
refactor indexer to not allow possibility of misinterpreting None
Browse files Browse the repository at this point in the history
failing 40 test wip commit

all unit tests passing

changes found during manual testing of workflows

fix remaining unit tests
  • Loading branch information
walshmm committed Nov 25, 2024
1 parent bb8cf23 commit 1b109e3
Show file tree
Hide file tree
Showing 40 changed files with 582 additions and 640 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ markers = [
"integration: mark a test as an integration test",
"mount_snap: mark a test as using /SNS/SNAP/ data mount",
"golden_data(*, path=None, short_name=None, date=None): mark golden data to use with a test",
"datarepo: mark a test as using snapred-data repo"
"datarepo: mark a test as using snapred-data repo",
"ui: mark a test as a UI test",
]
# The following will be overridden by the commandline option "-m integration"
addopts = "-m 'not (integration or datarepo)'"
Expand Down
91 changes: 30 additions & 61 deletions src/snapred/backend/dao/indexing/Versioning.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,47 @@
from typing import Any, Optional

from numpy import integer
from pydantic import BaseModel, computed_field, field_serializer
from pydantic import BaseModel, ConfigDict, field_validator
from snapred.meta.Config import Config
from snapred.meta.Enum import StrEnum

VERSION_START = Config["version.start"]
VERSION_NONE_NAME = Config["version.friendlyName.error"]
VERSION_DEFAULT_NAME = Config["version.friendlyName.default"]

# VERSION_DEFAULT is a SNAPRed-internal "magic" integer:
# * it is implicitely set during `Config` initialization.
VERSION_DEFAULT = Config["version.default"]

class VersionState(StrEnum):
DEFAULT = Config["version.friendlyName.default"]
LATEST = "latest"
NEXT = "next"


Version = int | VersionState


class VersionedObject(BaseModel):
# Base class for all versioned DAO

# In pydantic, a leading double underscore activates
# the `__pydantic_private__` feature, which limits the visibility
# of the attribute to the interior scope of its own class.
__version: Optional[int] = None
version: Version

@classmethod
def parseVersion(cls, version, *, exclude_none: bool = False, exclude_default: bool = False) -> int | None:
v: int | None
# handle two special cases
if (not exclude_none) and (version is None or version == VERSION_NONE_NAME):
v = None
elif (not exclude_default) and (version == VERSION_DEFAULT_NAME or version == VERSION_DEFAULT):
v = VERSION_DEFAULT
# parse integers
elif isinstance(version, int | integer):
if int(version) >= VERSION_START:
v = int(version)
else:
raise ValueError(f"Given version {version} is smaller than start version {VERSION_START}")
# otherwise this is an error
else:
raise ValueError(f"Cannot initialize version as {version}")
return v
@field_validator("version", mode="before")
def validate_version(cls, value: Version) -> Version:
if value in VersionState.values():
return value

@classmethod
def writeVersion(cls, version) -> int | str:
v: int | str
if version is None:
v = VERSION_NONE_NAME
elif version == VERSION_DEFAULT:
v = VERSION_DEFAULT_NAME
elif isinstance(version, int | integer):
v = int(version)
else:
raise ValueError("Version is not valid")
return v
if isinstance(value, str):
raise ValueError(f"Version must be an int or {VersionState.values()}")

def __init__(self, **kwargs):
version = kwargs.pop("version", None)
super().__init__(**kwargs)
self.__version = self.parseVersion(version)
if value is None:
raise ValueError("Version must be specified")

@field_serializer("version", check_fields=False, when_used="json")
def write_user_defaults(self, value: Any): # noqa ARG002
return self.writeVersion(self.__version)
if value < VERSION_START:
raise ValueError(f"Version must be greater than {VERSION_START}")

# NOTE some serialization still using the dict() method
def dict(self, **kwargs):
res = super().dict(**kwargs)
res["version"] = self.writeVersion(res["version"])
return res
return value

@computed_field
@property
def version(self) -> int:
return self.__version
# NOTE: This approach was taken because 'field_serializer' was checking against the
# INITIAL value of version for some reason. This is a workaround.
#
def model_dump_json(self, *args, **kwargs): # noqa ARG002
if self.version in VersionState.values():
raise ValueError(f"Version {self.version} must be flattened to an int before writing to JSON")
return super().model_dump_json(*args, **kwargs)

@version.setter
def version(self, v):
self.__version = self.parseVersion(v, exclude_none=True)
model_config = ConfigDict(use_enum_values=True, validate_assignment=True)
12 changes: 4 additions & 8 deletions src/snapred/backend/dao/normalization/NormalizationRecord.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, List

from pydantic import field_serializer, field_validator
from pydantic import field_validator

from snapred.backend.dao.indexing.Record import Record
from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT, VersionedObject
from snapred.backend.dao.indexing.Versioning import VERSION_START, Version, VersionedObject
from snapred.backend.dao.Limit import Limit
from snapred.backend.dao.normalization.Normalization import Normalization
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName
Expand Down Expand Up @@ -31,7 +31,7 @@ class NormalizationRecord(Record, extra="ignore"):
smoothingParameter: float
# detectorPeaks: List[DetectorPeak] # TODO: need to save this for reference during reduction
workspaceNames: List[WorkspaceName] = []
calibrationVersionUsed: int = VERSION_DEFAULT
calibrationVersionUsed: Version = VERSION_START
crystalDBounds: Limit[float]
normalizationCalibrantSamplePath: str

Expand All @@ -44,8 +44,4 @@ def validate_backgroundRunNumber(cls, v: Any) -> Any:
@field_validator("calibrationVersionUsed", mode="before")
@classmethod
def version_is_integer(cls, v: Any) -> Any:
return VersionedObject.parseVersion(v)

@field_serializer("calibrationVersionUsed", when_used="json")
def write_user_defaults(self, value: Any): # noqa ARG002
return VersionedObject.writeVersion(self.calibrationVersionUsed)
return isinstance(VersionedObject(version=v).version, int)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.calibration.FocusGroupMetric import FocusGroupMetric
from snapred.backend.dao.CrystallographicInfo import CrystallographicInfo
from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.state.PixelGroup import PixelGroup
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName, WorkspaceType

Expand All @@ -18,7 +19,7 @@ class CreateCalibrationRecordRequest(BaseModel, extra="forbid"):

runNumber: str
useLiteMode: bool
version: Optional[int] = None
version: Version = VersionState.NEXT
calculationParameters: Calibration
crystalInfo: CrystallographicInfo
pixelGroups: Optional[List[PixelGroup]] = None
Expand Down
6 changes: 3 additions & 3 deletions src/snapred/backend/dao/request/CreateIndexEntryRequest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from pydantic import BaseModel

from snapred.backend.dao.indexing.Versioning import Version, VersionState


class CreateIndexEntryRequest(BaseModel):
"""
Expand All @@ -10,7 +10,7 @@ class CreateIndexEntryRequest(BaseModel):

runNumber: str
useLiteMode: bool
version: Optional[int] = None
version: Version = VersionState.NEXT
comments: str
author: str
appliesTo: str
11 changes: 8 additions & 3 deletions src/snapred/backend/dao/request/FarmFreshIngredients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pydantic import BaseModel, ConfigDict, ValidationError, field_validator, model_validator

from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.Limit import Limit, Pair
from snapred.backend.dao.state import FocusGroup
from snapred.meta.Config import Config
from snapred.meta.mantid.AllowedPeakTypes import SymmetricPeakEnum

# TODO: this declaration is duplicated in `ReductionRequest`.
Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])])
Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)])


class FarmFreshIngredients(BaseModel):
Expand All @@ -21,7 +22,7 @@ class FarmFreshIngredients(BaseModel):

runNumber: str

versions: Versions = Versions(None, None)
versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST)

# allow 'versions' to be accessed as a single version,
# or, to be accessed ambiguously
Expand Down Expand Up @@ -83,6 +84,10 @@ def focusGroup(self, fg: FocusGroup):
def validate_versions(cls, v) -> Versions:
if not isinstance(v, Versions):
v = Versions(v)
if v.calibration is None:
raise ValueError("Calibration version must be specified")
if v.normalization is None:
raise ValueError("Normalization version must be specified")
return v

@field_validator("crystalDBounds", mode="before")
Expand Down Expand Up @@ -119,4 +124,4 @@ def validate_focusGroups(cls, v: Any):
del v["focusGroup"]
return v

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="forbid", validate_assignment=True)
9 changes: 7 additions & 2 deletions src/snapred/backend/dao/request/ReductionRequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pydantic import BaseModel, ConfigDict, field_validator

from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.ingredients import ArtificialNormalizationIngredients
from snapred.backend.dao.state.FocusGroup import FocusGroup
from snapred.backend.error.ContinueWarning import ContinueWarning
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceNameGenerator as wng

Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])])
Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)])


class ReductionRequest(BaseModel):
Expand All @@ -22,7 +23,7 @@ class ReductionRequest(BaseModel):

# Calibration and normalization versions:
# `None` => <use latest version>
versions: Versions = Versions(None, None)
versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST)

pixelMasks: List[WorkspaceName] = []
artificialNormalizationIngredients: Optional[ArtificialNormalizationIngredients] = None
Expand All @@ -37,6 +38,10 @@ def validate_versions(cls, v) -> Versions:
if not isinstance(v, Tuple):
raise ValueError("'versions' must be a tuple: '(<calibration version>, <normalization version>)'")
v = Versions(v)
if v.calibration is None:
raise ValueError("Calibration version must be specified")
if v.normalization is None:
raise ValueError("Normalization version must be specified")
return v

model_config = ConfigDict(
Expand Down
10 changes: 5 additions & 5 deletions src/snapred/backend/data/DataExportService.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from pathlib import Path
from typing import Tuple
from typing import Optional, Tuple

from pydantic import validate_call

Expand Down Expand Up @@ -64,11 +64,11 @@ def exportCalibrationIndexEntry(self, entry: IndexEntry):
"""
self.dataService.writeCalibrationIndexEntry(entry)

def exportCalibrationRecord(self, record: CalibrationRecord):
def exportCalibrationRecord(self, record: CalibrationRecord, entry: Optional[IndexEntry] = None):
"""
Record must have correct version set.
"""
self.dataService.writeCalibrationRecord(record)
self.dataService.writeCalibrationRecord(record, entry)

def exportCalibrationWorkspaces(self, record: CalibrationRecord):
"""
Expand All @@ -94,11 +94,11 @@ def exportNormalizationIndexEntry(self, entry: IndexEntry):
"""
self.dataService.writeNormalizationIndexEntry(entry)

def exportNormalizationRecord(self, record: NormalizationRecord):
def exportNormalizationRecord(self, record: NormalizationRecord, entry: Optional[IndexEntry] = None):
"""
Record must have correct version set.
"""
self.dataService.writeNormalizationRecord(record)
self.dataService.writeNormalizationRecord(record, entry)

def exportNormalizationWorkspaces(self, record: NormalizationRecord):
"""
Expand Down
Loading

0 comments on commit 1b109e3

Please sign in to comment.