Skip to content

Commit

Permalink
turn spin weight into a class property
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Jan 1, 2024
1 parent 536b9d0 commit 2259ca1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 35 deletions.
4 changes: 2 additions & 2 deletions examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,13 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:214: UserWarning: position and visibility maps have \n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n",
"different NSIDE\n",
" warnings.warn(\"position and visibility maps have different NSIDE\")\n",
"</pre>\n"
],
"text/plain": [
"/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:214: UserWarning: position and visibility maps have \n",
"/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n",
"different NSIDE\n",
" warnings.warn(\"position and visibility maps have different NSIDE\")\n"
]
Expand Down
60 changes: 30 additions & 30 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from types import MappingProxyType
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -54,10 +53,19 @@ class Field(metaclass=ABCMeta):
"""

# every field subclass has a static spin weight attribute, which can be
# overwritten by the class (or even an individual instance)
__spin: int | None = None

def __init_subclass__(cls, spin: int | None = None) -> None:
"""Initialise spin weight of field subclasses."""
super().__init_subclass__()
if spin is not None:
cls.__spin = spin

def __init__(
self,
*columns: str,
spin: int = 0,
) -> None:
"""Initialise the field."""
super().__init__()
Expand All @@ -70,9 +78,9 @@ def __init__(
raise TypeError(msg) from None
else:
self.__columns = None
self._metadata: dict[str, Any] = {
"spin": spin,
}
self._metadata: dict[str, Any] = {}
if (spin := self.__spin) is not None:
self._metadata["spin"] = spin

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -103,7 +111,12 @@ def metadata(self) -> Mapping[str, Any]:
@property
def spin(self) -> int:
"""Spin weight of field."""
return self._metadata["spin"]
spin = self.__spin
if spin is None:
clsname = self.__class__.__name__
msg = f"field of type '{clsname}' has undefined spin weight"
raise ValueError(msg)
return spin

@abstractmethod
async def __call__(
Expand Down Expand Up @@ -137,7 +150,7 @@ async def _pages(
await coroutines.sleep()


class Positions(Field):
class Positions(Field, spin=0):
"""Field of positions in a catalogue.
Can produce both overdensity maps and number count maps, depending
Expand All @@ -152,7 +165,7 @@ def __init__(
nbar: float | None = None,
) -> None:
"""Create a position field."""
super().__init__(*columns, spin=0)
super().__init__(*columns)
self.__overdensity = overdensity
self.__nbar = nbar

Expand Down Expand Up @@ -256,13 +269,9 @@ async def __call__(
return pos


class ScalarField(Field):
class ScalarField(Field, spin=0):
"""Field of real scalar values in a catalogue."""

def __init__(self, *columns: str) -> None:
"""Create a scalar field."""
super().__init__(*columns, spin=0)

@staticmethod
def _init_columns(
lon: str,
Expand Down Expand Up @@ -343,15 +352,11 @@ async def __call__(
class ComplexField(Field):
"""Field of complex values in a catalogue.
Complex fields can have non-zero spin weight, set using the
``spin=`` parameter.
The :class:`ComplexField` subclasses such as :class:`Spin2Field`
have non-zero spin weight.
"""

def __init__(self, *columns: str, spin: int = 0) -> None:
"""Create a complex field."""
super().__init__(*columns, spin=spin)

@staticmethod
def _init_columns(
lon: str,
Expand Down Expand Up @@ -430,13 +435,9 @@ async def __call__(
return val


class Visibility(Field):
class Visibility(Field, spin=0):
"""Copy visibility map from catalogue at given resolution."""

def __init__(self) -> None:
"""Create a visibility field."""
super().__init__(spin=0)

@staticmethod
def _init_columns() -> Columns:
return ()
Expand Down Expand Up @@ -474,13 +475,9 @@ async def __call__(
return vmap


class Weights(Field):
class Weights(Field, spin=0):
"""Field of weight values from a catalogue."""

def __init__(self, *columns: str) -> None:
"""Create a weight field."""
super().__init__(*columns, spin=0)

@staticmethod
def _init_columns(lon: str, lat: str, weight: str | None = None) -> Columns:
return lon, lat, weight
Expand Down Expand Up @@ -528,6 +525,9 @@ async def __call__(
return wht


Spin2Field = partial(ComplexField, spin=2)
class Spin2Field(ComplexField, spin=2):
"""Spin-2 complex field."""


Shears = Spin2Field
Ellipticities = Spin2Field
18 changes: 15 additions & 3 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,19 @@ def test_field_abc():
with pytest.raises(TypeError):
Field()

class TestField(Field):
class SpinLessField(Field):
def _init_columns(self, *columns: str) -> Columns:
return columns

async def __call__(self):
pass

f = SpinLessField()

with pytest.raises(ValueError, match="undefined spin weight"):
f.spin

class TestField(Field, spin=0):
@staticmethod
def _init_columns(lon, lat, weight=None) -> Columns:
return lon, lat, weight
Expand Down Expand Up @@ -269,11 +281,11 @@ def test_scalar_field(mapper, catalog):


def test_complex_field(mapper, catalog):
from heracles.fields import ComplexField
from heracles.fields import Spin2Field

npix = 12 * mapper.nside**2

f = ComplexField("ra", "dec", "g1", "g2", "w", spin=2)
f = Spin2Field("ra", "dec", "g1", "g2", "w")
m = coroutines.run(f(catalog, mapper))

w = next(iter(catalog))["w"]
Expand Down

0 comments on commit 2259ca1

Please sign in to comment.