diff --git a/examples/example.ipynb b/examples/example.ipynb index 9bda2ff..6af497b 100644 --- a/examples/example.ipynb +++ b/examples/example.ipynb @@ -534,13 +534,13 @@ { "data": { "text/html": [ - "
/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:214: UserWarning: position and visibility maps have \n", + "/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n", "different NSIDE\n", " warnings.warn(\"position and visibility maps have different NSIDE\")\n", "\n" ], "text/plain": [ - "/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:214: UserWarning: position and visibility maps have \n", + "/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n", "different NSIDE\n", " warnings.warn(\"position and visibility maps have different NSIDE\")\n" ] diff --git a/heracles/fields.py b/heracles/fields.py index 3196eff..f3cd285 100644 --- a/heracles/fields.py +++ b/heracles/fields.py @@ -22,7 +22,6 @@ import warnings from abc import ABCMeta, abstractmethod -from functools import partial from types import MappingProxyType from typing import TYPE_CHECKING @@ -54,10 +53,19 @@ class Field(metaclass=ABCMeta): """ + # every field subclass has a static spin weight attribute, which can be + # overwritten by the class (or even an individual instance) + __spin: int | None = None + + def __init_subclass__(cls, spin: int | None = None) -> None: + """Initialise spin weight of field subclasses.""" + super().__init_subclass__() + if spin is not None: + cls.__spin = spin + def __init__( self, *columns: str, - spin: int = 0, ) -> None: """Initialise the field.""" super().__init__() @@ -70,9 +78,9 @@ def __init__( raise TypeError(msg) from None else: self.__columns = None - self._metadata: dict[str, Any] = { - "spin": spin, - } + self._metadata: dict[str, Any] = {} + if (spin := self.__spin) is not None: + self._metadata["spin"] = spin @staticmethod @abstractmethod @@ -103,7 +111,12 @@ def metadata(self) -> Mapping[str, Any]: @property def spin(self) -> int: """Spin weight of field.""" - return self._metadata["spin"] + spin = self.__spin + if spin is None: + clsname = self.__class__.__name__ + msg = f"field of type '{clsname}' has undefined spin weight" + raise ValueError(msg) + return spin @abstractmethod async def __call__( @@ -137,7 +150,7 @@ async def _pages( await coroutines.sleep() -class Positions(Field): +class Positions(Field, spin=0): """Field of positions in a catalogue. Can produce both overdensity maps and number count maps, depending @@ -152,7 +165,7 @@ def __init__( nbar: float | None = None, ) -> None: """Create a position field.""" - super().__init__(*columns, spin=0) + super().__init__(*columns) self.__overdensity = overdensity self.__nbar = nbar @@ -256,13 +269,9 @@ async def __call__( return pos -class ScalarField(Field): +class ScalarField(Field, spin=0): """Field of real scalar values in a catalogue.""" - def __init__(self, *columns: str) -> None: - """Create a scalar field.""" - super().__init__(*columns, spin=0) - @staticmethod def _init_columns( lon: str, @@ -343,15 +352,11 @@ async def __call__( class ComplexField(Field): """Field of complex values in a catalogue. - Complex fields can have non-zero spin weight, set using the - ``spin=`` parameter. + The :class:`ComplexField` subclasses such as :class:`Spin2Field` + have non-zero spin weight. """ - def __init__(self, *columns: str, spin: int = 0) -> None: - """Create a complex field.""" - super().__init__(*columns, spin=spin) - @staticmethod def _init_columns( lon: str, @@ -430,13 +435,9 @@ async def __call__( return val -class Visibility(Field): +class Visibility(Field, spin=0): """Copy visibility map from catalogue at given resolution.""" - def __init__(self) -> None: - """Create a visibility field.""" - super().__init__(spin=0) - @staticmethod def _init_columns() -> Columns: return () @@ -474,13 +475,9 @@ async def __call__( return vmap -class Weights(Field): +class Weights(Field, spin=0): """Field of weight values from a catalogue.""" - def __init__(self, *columns: str) -> None: - """Create a weight field.""" - super().__init__(*columns, spin=0) - @staticmethod def _init_columns(lon: str, lat: str, weight: str | None = None) -> Columns: return lon, lat, weight @@ -528,6 +525,9 @@ async def __call__( return wht -Spin2Field = partial(ComplexField, spin=2) +class Spin2Field(ComplexField, spin=2): + """Spin-2 complex field.""" + + Shears = Spin2Field Ellipticities = Spin2Field diff --git a/tests/test_fields.py b/tests/test_fields.py index 33b6126..0c876d9 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -82,7 +82,19 @@ def test_field_abc(): with pytest.raises(TypeError): Field() - class TestField(Field): + class SpinLessField(Field): + def _init_columns(self, *columns: str) -> Columns: + return columns + + async def __call__(self): + pass + + f = SpinLessField() + + with pytest.raises(ValueError, match="undefined spin weight"): + f.spin + + class TestField(Field, spin=0): @staticmethod def _init_columns(lon, lat, weight=None) -> Columns: return lon, lat, weight @@ -269,11 +281,11 @@ def test_scalar_field(mapper, catalog): def test_complex_field(mapper, catalog): - from heracles.fields import ComplexField + from heracles.fields import Spin2Field npix = 12 * mapper.nside**2 - f = ComplexField("ra", "dec", "g1", "g2", "w", spin=2) + f = Spin2Field("ra", "dec", "g1", "g2", "w") m = coroutines.run(f(catalog, mapper)) w = next(iter(catalog))["w"]