From a1f3ebae1df71dfc2f9401e5c083a2b92827cd3b Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Tue, 2 Jan 2024 10:55:09 +0000 Subject: [PATCH] API(fields): simplify column definition for fields (#94) This simplifies the column definition in `Field` subclasses so that they only need to provide a list of required and optional column names. All checking is done in the base class. Closes: #91 --- examples/example.ipynb | 4 +- heracles/fields.py | 86 +++++++++++++++++++++--------------------- tests/test_fields.py | 6 +-- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/examples/example.ipynb b/examples/example.ipynb index 6af497b..1c9cd13 100644 --- a/examples/example.ipynb +++ b/examples/example.ipynb @@ -534,13 +534,13 @@ { "data": { "text/html": [ - "
/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n",
+       "
/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:247: UserWarning: position and visibility maps have \n",
        "different NSIDE\n",
        "  warnings.warn(\"position and visibility maps have different NSIDE\")\n",
        "
\n" ], "text/plain": [ - "/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n", + "/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:247: UserWarning: position and visibility maps have \n", "different NSIDE\n", " warnings.warn(\"position and visibility maps have different NSIDE\")\n" ] diff --git a/heracles/fields.py b/heracles/fields.py index 650547f..1ca8053 100644 --- a/heracles/fields.py +++ b/heracles/fields.py @@ -31,7 +31,7 @@ from .core import update_metadata if TYPE_CHECKING: - from collections.abc import AsyncIterable, Mapping + from collections.abc import AsyncIterable, Mapping, Sequence from typing import Any from numpy.typing import ArrayLike @@ -53,38 +53,63 @@ class Field(metaclass=ABCMeta): """ + # column names: "col1", "col2", "[optional]" + uses: Sequence[str] | str | None = None + # every field subclass has a static spin weight attribute, which can be # overwritten by the class (or even an individual instance) __spin: int | None = None + # definition of required and optional columns + __ncol: tuple[int, int] + def __init_subclass__(cls, *, spin: int | None = None) -> None: """Initialise spin weight of field subclasses.""" super().__init_subclass__() if spin is not None: cls.__spin = spin + uses = cls.uses + if uses is None: + uses = () + elif isinstance(uses, str): + uses = (uses,) + ncol = len(uses) + nopt = 0 + for u in uses[::-1]: + if u.startswith("[") and u.endswith("]"): + nopt += 1 + else: + break + cls.__ncol = (ncol - nopt, ncol) def __init__(self, *columns: str) -> None: """Initialise the field.""" super().__init__() - self.__columns: Columns | None - if columns: - try: - self.__columns = self._init_columns(*columns) - except TypeError as exc: - msg = str(exc).replace("_init_columns", "__init__") - raise TypeError(msg) from None - else: - self.__columns = None + self.__columns = self._init_columns(*columns) if columns else None self._metadata: dict[str, Any] = {} if (spin := self.__spin) is not None: self._metadata["spin"] = spin - @staticmethod - @abstractmethod - def _init_columns(*columns: str) -> Columns: + @classmethod + def _init_columns(cls, *columns: str) -> Columns: """Initialise the given set of columns for a specific field subclass.""" - ... + nmin, nmax = cls.__ncol + if not nmin <= len(columns) <= nmax: + uses = cls.uses + if uses is None: + uses = () + if isinstance(uses, str): + uses = (uses,) + count = f"{nmin}" + if nmax != nmin: + count += f" to {nmax}" + msg = f"field of type '{cls.__name__}' accepts {count} columns" + if uses: + msg += " (" + ", ".join(uses) + ")" + msg += f", received {len(columns)}" + raise ValueError(msg) + return columns + (None,) * (nmax - len(columns)) @property def columns(self) -> Columns | None: @@ -155,6 +180,8 @@ class Positions(Field, spin=0): """ + uses = "longitude", "latitude" + def __init__( self, *columns: str, @@ -166,10 +193,6 @@ def __init__( self.__overdensity = overdensity self.__nbar = nbar - @staticmethod - def _init_columns(lon: str, lat: str) -> Columns: - return lon, lat - @property def overdensity(self) -> bool: """Flag to create overdensity maps.""" @@ -269,14 +292,7 @@ async def __call__( class ScalarField(Field, spin=0): """Field of real scalar values in a catalogue.""" - @staticmethod - def _init_columns( - lon: str, - lat: str, - value: str, - weight: str | None = None, - ) -> Columns: - return lon, lat, value, weight + uses = "longitude", "latitude", "value", "[weight]" async def __call__( self, @@ -354,15 +370,7 @@ class ComplexField(Field, spin=0): """ - @staticmethod - def _init_columns( - lon: str, - lat: str, - real: str, - imag: str, - weight: str | None = None, - ) -> Columns: - return lon, lat, real, imag, weight + uses = "longitude", "latitude", "real", "imag", "[weight]" async def __call__( self, @@ -435,10 +443,6 @@ async def __call__( class Visibility(Field, spin=0): """Copy visibility map from catalogue at given resolution.""" - @staticmethod - def _init_columns() -> Columns: - return () - async def __call__( self, catalog: Catalog, @@ -475,9 +479,7 @@ async def __call__( class Weights(Field, spin=0): """Field of weight values from a catalogue.""" - @staticmethod - def _init_columns(lon: str, lat: str, weight: str | None = None) -> Columns: - return lon, lat, weight + uses = "longitude", "latitude", "[weight]" async def __call__( self, diff --git a/tests/test_fields.py b/tests/test_fields.py index 0c876d9..1a465d4 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -95,9 +95,7 @@ async def __call__(self): f.spin class TestField(Field, spin=0): - @staticmethod - def _init_columns(lon, lat, weight=None) -> Columns: - return lon, lat, weight + uses = "lon", "lat", "[weight]" async def __call__(self): pass @@ -114,7 +112,7 @@ async def __call__(self): with pytest.raises(ValueError): f.columns_or_error - with pytest.raises(TypeError, match=r"__init__\(\) missing 1 required"): + with pytest.raises(ValueError, match="accepts 2 to 3 columns"): TestField("lon") f = TestField("lon", "lat")