Skip to content

Commit

Permalink
API(fields): simplify column definition for fields (#94)
Browse files Browse the repository at this point in the history
This simplifies the column definition in `Field` subclasses so that they
only need to provide a list of required and optional column names. All
checking is done in the base class.

Closes: #91
  • Loading branch information
ntessore authored Jan 2, 2024
1 parent ef16490 commit a1f3eba
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 48 deletions.
4 changes: 2 additions & 2 deletions examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,13 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:247: UserWarning: position and visibility maps have \n",
"different NSIDE\n",
" warnings.warn(\"position and visibility maps have different NSIDE\")\n",
"</pre>\n"
],
"text/plain": [
"/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:227: UserWarning: position and visibility maps have \n",
"/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:247: UserWarning: position and visibility maps have \n",
"different NSIDE\n",
" warnings.warn(\"position and visibility maps have different NSIDE\")\n"
]
Expand Down
86 changes: 44 additions & 42 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .core import update_metadata

if TYPE_CHECKING:
from collections.abc import AsyncIterable, Mapping
from collections.abc import AsyncIterable, Mapping, Sequence
from typing import Any

from numpy.typing import ArrayLike
Expand All @@ -53,38 +53,63 @@ class Field(metaclass=ABCMeta):
"""

# column names: "col1", "col2", "[optional]"
uses: Sequence[str] | str | None = None

# every field subclass has a static spin weight attribute, which can be
# overwritten by the class (or even an individual instance)
__spin: int | None = None

# definition of required and optional columns
__ncol: tuple[int, int]

def __init_subclass__(cls, *, spin: int | None = None) -> None:
"""Initialise spin weight of field subclasses."""
super().__init_subclass__()
if spin is not None:
cls.__spin = spin
uses = cls.uses
if uses is None:
uses = ()
elif isinstance(uses, str):
uses = (uses,)
ncol = len(uses)
nopt = 0
for u in uses[::-1]:
if u.startswith("[") and u.endswith("]"):
nopt += 1
else:
break
cls.__ncol = (ncol - nopt, ncol)

def __init__(self, *columns: str) -> None:
"""Initialise the field."""
super().__init__()
self.__columns: Columns | None
if columns:
try:
self.__columns = self._init_columns(*columns)
except TypeError as exc:
msg = str(exc).replace("_init_columns", "__init__")
raise TypeError(msg) from None
else:
self.__columns = None
self.__columns = self._init_columns(*columns) if columns else None
self._metadata: dict[str, Any] = {}
if (spin := self.__spin) is not None:
self._metadata["spin"] = spin

@staticmethod
@abstractmethod
def _init_columns(*columns: str) -> Columns:
@classmethod
def _init_columns(cls, *columns: str) -> Columns:
"""Initialise the given set of columns for a specific field
subclass."""
...
nmin, nmax = cls.__ncol
if not nmin <= len(columns) <= nmax:
uses = cls.uses
if uses is None:
uses = ()
if isinstance(uses, str):
uses = (uses,)
count = f"{nmin}"
if nmax != nmin:
count += f" to {nmax}"
msg = f"field of type '{cls.__name__}' accepts {count} columns"
if uses:
msg += " (" + ", ".join(uses) + ")"
msg += f", received {len(columns)}"
raise ValueError(msg)
return columns + (None,) * (nmax - len(columns))

@property
def columns(self) -> Columns | None:
Expand Down Expand Up @@ -155,6 +180,8 @@ class Positions(Field, spin=0):
"""

uses = "longitude", "latitude"

def __init__(
self,
*columns: str,
Expand All @@ -166,10 +193,6 @@ def __init__(
self.__overdensity = overdensity
self.__nbar = nbar

@staticmethod
def _init_columns(lon: str, lat: str) -> Columns:
return lon, lat

@property
def overdensity(self) -> bool:
"""Flag to create overdensity maps."""
Expand Down Expand Up @@ -269,14 +292,7 @@ async def __call__(
class ScalarField(Field, spin=0):
"""Field of real scalar values in a catalogue."""

@staticmethod
def _init_columns(
lon: str,
lat: str,
value: str,
weight: str | None = None,
) -> Columns:
return lon, lat, value, weight
uses = "longitude", "latitude", "value", "[weight]"

async def __call__(
self,
Expand Down Expand Up @@ -354,15 +370,7 @@ class ComplexField(Field, spin=0):
"""

@staticmethod
def _init_columns(
lon: str,
lat: str,
real: str,
imag: str,
weight: str | None = None,
) -> Columns:
return lon, lat, real, imag, weight
uses = "longitude", "latitude", "real", "imag", "[weight]"

async def __call__(
self,
Expand Down Expand Up @@ -435,10 +443,6 @@ async def __call__(
class Visibility(Field, spin=0):
"""Copy visibility map from catalogue at given resolution."""

@staticmethod
def _init_columns() -> Columns:
return ()

async def __call__(
self,
catalog: Catalog,
Expand Down Expand Up @@ -475,9 +479,7 @@ async def __call__(
class Weights(Field, spin=0):
"""Field of weight values from a catalogue."""

@staticmethod
def _init_columns(lon: str, lat: str, weight: str | None = None) -> Columns:
return lon, lat, weight
uses = "longitude", "latitude", "[weight]"

async def __call__(
self,
Expand Down
6 changes: 2 additions & 4 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ async def __call__(self):
f.spin

class TestField(Field, spin=0):
@staticmethod
def _init_columns(lon, lat, weight=None) -> Columns:
return lon, lat, weight
uses = "lon", "lat", "[weight]"

async def __call__(self):
pass
Expand All @@ -114,7 +112,7 @@ async def __call__(self):
with pytest.raises(ValueError):
f.columns_or_error

with pytest.raises(TypeError, match=r"__init__\(\) missing 1 required"):
with pytest.raises(ValueError, match="accepts 2 to 3 columns"):
TestField("lon")

f = TestField("lon", "lat")
Expand Down

0 comments on commit a1f3eba

Please sign in to comment.