From 6fc50b7a5f8d179d2fb82a1401af79eb1da5845e Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Thu, 21 Dec 2023 11:21:01 +0000 Subject: [PATCH] API(maps): spherical harmonic transforms from mappers (#89) Move spherical harmonic transforms into mapper classes, since those are the objects that know how to carry out a transform of their maps. Closes: #88 --- examples/example.ipynb | 13 ++-- heracles/core.py | 43 +++++++++----- heracles/fields.py | 8 +-- heracles/maps/_healpix.py | 51 +++++++++++++++- heracles/maps/_mapper.py | 12 +++- heracles/maps/_mapping.py | 59 +++++++----------- tests/test_core.py | 34 ++++++++--- tests/test_maps.py | 122 ++++++++++++++++++++------------------ 8 files changed, 208 insertions(+), 134 deletions(-) diff --git a/examples/example.ipynb b/examples/example.ipynb index 0c56cf9..dc0b428 100644 --- a/examples/example.ipynb +++ b/examples/example.ipynb @@ -520,7 +520,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2b4e7378bbce4d9aad245cc95935bdc1", + "model_id": "b9e7f4072a8349bfbe5a45433b429037", "version_major": 2, "version_minor": 0 }, @@ -690,7 +690,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "44ae51e18b0444cc917716e2eff2569e", + "model_id": "86b6b81c8732497c824cf663f0f24a6b", "version_major": 2, "version_minor": 0 }, @@ -726,7 +726,7 @@ } ], "source": [ - "alms = transform_maps(maps, lmax=lmax, use_pixel_weights=True, progress=True)" + "alms = transform_maps(mapper, maps, lmax=lmax, progress=True)" ] }, { @@ -923,7 +923,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cc2eb3fd9fe5451f87bb5129d9ced2cc", + "model_id": "28f4e1aa37f046d98d38751faf16fcac", "version_major": 2, "version_minor": 0 }, @@ -1027,7 +1027,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "662e8b4142e34c27ad62d2b88a795240", + "model_id": "efab2494bdf444119ce4f008cf2232bb", "version_major": 2, "version_minor": 0 }, @@ -1063,8 +1063,7 @@ } ], "source": [ - "alms_mm = transform_maps(maps_mm, progress=True,\n", - " lmax=lmax_mm, use_pixel_weights=True)" + "alms_mm = transform_maps(mapper_mm, maps_mm, progress=True, lmax=lmax_mm)" ] }, { diff --git a/heracles/core.py b/heracles/core.py index b9f414d..72900db 100644 --- a/heracles/core.py +++ b/heracles/core.py @@ -18,11 +18,16 @@ # License along with Heracles. If not, see . """module for common core functionality""" +from __future__ import annotations + from collections import UserDict from collections.abc import Mapping, Sequence +from typing import Any, Callable, TypeVar import numpy as np +T = TypeVar("T") + def toc_match(key, include=None, exclude=None): """return whether a tocdict entry matches include/exclude criteria""" @@ -49,21 +54,31 @@ def toc_filter(obj, include=None, exclude=None): raise TypeError(msg) -def toc_nearest(obj, key): - """return the closest match to *key* in *obj*.""" - if isinstance(key, Sequence): - t = tuple(key) +def multi_value_getter(obj: T | Mapping[Any, T]) -> Callable[[Any], T]: + """Return a getter for values or mappings.""" + if isinstance(obj, Mapping): + + def getter(key: Any) -> T: + if isinstance(key, Sequence): + t = tuple(key) + else: + t = (key,) + while t: + if t in obj: + return obj[t] + if len(t) == 1 and t[0] in obj: + return obj[t[0]] + t = t[:-1] + if t in obj: + return obj[t] + raise KeyError(key) + else: - t = (key,) - while t: - if t in obj: - return obj[t] - if len(t) == 1 and t[0] in obj: - return obj[t[0]] - t = t[:-1] - if t in obj: - return obj[t] - raise KeyError(key) + + def getter(key: Any) -> T: + return obj + + return getter # subclassing UserDict here since that returns the correct type from methods diff --git a/heracles/fields.py b/heracles/fields.py index 06d8db0..3196eff 100644 --- a/heracles/fields.py +++ b/heracles/fields.py @@ -196,7 +196,7 @@ async def __call__( # map catalogue data asynchronously async for page in _pages(catalog, progress): lon, lat = page.get(*col) - mapper(lon, lat, [pos]) + mapper.map_values(lon, lat, [pos]) ngal += page.size @@ -300,7 +300,7 @@ async def __call__( lon, lat, v = page.get(*col) w = page.get(wcol) if wcol is not None else None - mapper(lon, lat, [val], [v], w) + mapper.map_values(lon, lat, [val], [v], w) ngal += page.size if w is None: @@ -391,7 +391,7 @@ async def __call__( w = page.get(wcol) if wcol is not None else None - mapper(lon, lat, [val[0], val[1]], [re, im], w) + mapper.map_values(lon, lat, [val[0], val[1]], [re, im], w) ngal += page.size if w is None: @@ -509,7 +509,7 @@ async def __call__( else: w = page.get(wcol) - mapper(lon, lat, [wht], None, w) + mapper.map_values(lon, lat, [wht], None, w) del page, lon, lat, w diff --git a/heracles/maps/_healpix.py b/heracles/maps/_healpix.py index 7ab219a..9f23145 100644 --- a/heracles/maps/_healpix.py +++ b/heracles/maps/_healpix.py @@ -29,6 +29,8 @@ import numpy as np from numba import njit +from heracles.core import update_metadata + from ._mapper import Mapper if TYPE_CHECKING: @@ -118,7 +120,13 @@ class Healpix(Mapper, kernel="healpix"): available as the *nside* property. """ - def __init__(self, nside: int, dtype: DTypeLike = np.float64) -> None: + def __init__( + self, + nside: int, + dtype: DTypeLike = np.float64, + *, + datapath: str | None = None, + ) -> None: """ Mapper for HEALPix maps with the given *nside* parameter. """ @@ -126,6 +134,7 @@ def __init__(self, nside: int, dtype: DTypeLike = np.float64) -> None: self.__nside = nside self.__npix = hp.nside2npix(nside) self.__dtype = np.dtype(dtype) + self.__datapath = datapath self._metadata |= { "nside": nside, } @@ -158,7 +167,7 @@ def area(self) -> float: """ return hp.nside2pixarea(self.__nside) - def __call__( + def map_values( self, lon: ArrayLike, lat: ArrayLike, @@ -186,3 +195,41 @@ def __call__( _map(ipix, maps, values) else: _mapw(ipix, maps, values, weight) + + def transform( + self, + maps: ArrayLike, + lmax: int | None = None, + ) -> ArrayLike | tuple[ArrayLike, ArrayLike]: + """ + Spherical harmonic transform of HEALPix maps. + """ + + md = maps.dtype.metadata or {} + spin = md.get("spin", 0) + + if spin == 0: + pol = False + elif spin == 2: + maps = [np.zeros(self.__npix), maps[0], maps[1]] + pol = True + else: + msg = f"spin-{spin} maps not yet supported" + raise NotImplementedError(msg) + + alms = hp.map2alm( + maps, + lmax=lmax, + pol=pol, + use_pixel_weights=True, + datapath=self.__datapath, + ) + + if spin == 0: + update_metadata(alms, **md) + else: + alms = (alms[1], alms[2]) + update_metadata(alms[0], **md) + update_metadata(alms[1], **md) + + return alms diff --git a/heracles/maps/_mapper.py b/heracles/maps/_mapper.py index cedf517..bf29dc5 100644 --- a/heracles/maps/_mapper.py +++ b/heracles/maps/_mapper.py @@ -105,7 +105,7 @@ def area(self) -> float: """ @abstractmethod - def __call__( + def map_values( self, lon: ArrayLike, lat: ArrayLike, @@ -116,3 +116,13 @@ def __call__( """ Add values to maps. """ + + @abstractmethod + def transform( + self, + maps: ArrayLike, + lmax: int | None = None, + ) -> ArrayLike | tuple[ArrayLike, ArrayLike]: + """ + The spherical harmonic transform for this mapper. + """ diff --git a/heracles/maps/_mapping.py b/heracles/maps/_mapping.py index 1b0402a..7f8b67a 100644 --- a/heracles/maps/_mapping.py +++ b/heracles/maps/_mapping.py @@ -22,16 +22,16 @@ from __future__ import annotations -from collections.abc import Mapping, MutableMapping, Sequence from contextlib import nullcontext from typing import TYPE_CHECKING, Any import coroutines -import numpy as np -from heracles.core import TocDict, toc_match, toc_nearest, update_metadata +from heracles.core import TocDict, multi_value_getter, toc_match if TYPE_CHECKING: + from collections.abc import Mapping, MutableMapping, Sequence + from numpy.typing import NDArray from heracles.catalog import Catalog @@ -85,6 +85,9 @@ def map_catalogs( if out is None: out = TocDict() + # getter for mapper value or dict + mappergetter = multi_value_getter(mapper) + # collect groups of items to go through # items are tuples of (key, field, catalog) groups = [ @@ -115,11 +118,7 @@ def map_catalogs( keys, coros = [], [] for key, field, catalog in items: if toc_match(key, include, exclude): - # find the mapper for this key - if isinstance(mapper, Mapping): - _mapper = toc_nearest(mapper, key) - else: - _mapper = mapper + _mapper = mappergetter(key) coro = _map_progress(key, field, catalog, _mapper, prog) @@ -144,6 +143,7 @@ def map_catalogs( def transform_maps( + mapper: Mapper | Mapping[Any, Mapper], maps: Mapping[tuple[Any, Any], NDArray], *, lmax: int | Mapping[Any, int] | None = None, @@ -153,12 +153,14 @@ def transform_maps( ) -> MutableMapping[tuple[Any, Any], NDArray]: """transform a set of maps to alms""" - import healpy as hp - # the output toc dict if out is None: out = TocDict() + # getter for values or dicts + mappergetter = multi_value_getter(mapper) + lmaxgetter = multi_value_getter(lmax) + # display a progress bar if asked to progressbar: Progress | nullcontext if progress: @@ -172,14 +174,6 @@ def transform_maps( # convert maps to alms, taking care of complex and spin-weighted maps with progressbar as prog: for (k, i), m in maps.items(): - if isinstance(lmax, Mapping): - _lmax = toc_nearest(lmax, (k, i)) - else: - _lmax = lmax - - md = m.dtype.metadata or {} - spin = md.get("spin", 0) - if progress: subtask = prog.task( f"[{k}, {i}]", @@ -188,29 +182,18 @@ def transform_maps( total=None, ) - tfm: NDArray | list[NDArray] - if spin == 0: - tfm = m - pol = False - elif spin == 2: - tfm = [np.zeros(np.shape(m)[-1]), m[0], m[1]] - pol = True - else: - msg = f"spin-{spin} maps not yet supported" - raise NotImplementedError(msg) + _mapper = mappergetter((k, i)) + _lmax = lmaxgetter((k, i)) - alms = hp.map2alm(tfm, lmax=_lmax, pol=pol, **kwargs) + alms = _mapper.transform(m, _lmax) - if spin == 0: - alms = {(k, i): alms} - elif spin == 2: - alms = {(f"{k}_E", i): alms[1], (f"{k}_B", i): alms[2]} - - for ki, alm in alms.items(): - update_metadata(alm, **md) - out[ki] = alm + if isinstance(alms, tuple): + out[f"{k}_E", i] = alms[0] + out[f"{k}_B", i] = alms[1] + else: + out[k, i] = alms - del m, tfm, alms, alm + del m, alms if progress: subtask.remove() diff --git a/tests/test_core.py b/tests/test_core.py index 4825312..b7efb8f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -38,15 +38,31 @@ def test_toc_filter(): toc_filter(object()) -def test_toc_nearest(): - from heracles.core import toc_nearest - - assert toc_nearest({(1,): "x"}, 1) == "x" - assert toc_nearest({(1,): "x", (2,): "y", (): "z"}, 2) == "y" - assert toc_nearest({(1,): "x", (2,): "y", (): "z"}, (2, 0)) == "y" - assert toc_nearest({(1,): "x", (2,): "y", (): "z"}, (3, 0)) == "z" - assert toc_nearest({(0, 1): "x", (0, 2): "y", (0,): "z"}, (0, 2)) == "y" - assert toc_nearest({0: "x", 1: "y", (): "z"}, (0, 2)) == "x" +def test_multi_value_getter(): + from heracles.core import multi_value_getter + + getter = multi_value_getter( + { + (1,): "x", + (2,): "y", + (): "z", + (2, 1): "w", + 4: "v", + }, + ) + + assert getter(1) == "x" + assert getter(2) == "y" + assert getter((2, 0)) == "y" + assert getter((3, 0)) == "z" + assert getter((2, 1)) == "w" + assert getter((4, 0)) == "v" + + getter = multi_value_getter("x") + + assert getter(1) == "x" + assert getter(2) == "x" + assert getter((2, 0)) == "x" def test_tocdict(): diff --git a/tests/test_maps.py b/tests/test_maps.py index a6b038b..acb070d 100644 --- a/tests/test_maps.py +++ b/tests/test_maps.py @@ -47,7 +47,7 @@ def test_healpix_maps(rng): # map positions m = np.zeros(mapper.size, mapper.dtype) - mapper(lon, lat, [m]) + mapper.map_values(lon, lat, [m]) expected = np.zeros(npix) np.add.at(expected, ipix, 1) @@ -57,7 +57,7 @@ def test_healpix_maps(rng): # map positions with weights m = np.zeros(mapper.size, mapper.dtype) - mapper(lon, lat, [m], None, w) + mapper.map_values(lon, lat, [m], None, w) expected = np.zeros(npix) np.add.at(expected, ipix, w) @@ -67,7 +67,7 @@ def test_healpix_maps(rng): # map one set of values m = np.zeros(mapper.size, mapper.dtype) - mapper(lon, lat, [m], [x]) + mapper.map_values(lon, lat, [m], [x]) expected = np.zeros(npix) np.add.at(expected, ipix, x) @@ -77,7 +77,7 @@ def test_healpix_maps(rng): # map two sets of values m = np.zeros((2, mapper.size), mapper.dtype) - mapper(lon, lat, [m[0], m[1]], [x, y]) + mapper.map_values(lon, lat, [m[0], m[1]], [x, y]) expected = np.zeros((2, npix)) np.add.at(expected[0], ipix, x) @@ -88,7 +88,7 @@ def test_healpix_maps(rng): # map one set of values with weights m = np.zeros(mapper.size, mapper.dtype) - mapper(lon, lat, [m], [x], w) + mapper.map_values(lon, lat, [m], [x], w) expected = np.zeros(npix) np.add.at(expected, ipix, w * x) @@ -98,7 +98,7 @@ def test_healpix_maps(rng): # map two sets of values with weights m = np.zeros((2, mapper.size), mapper.dtype) - mapper(lon, lat, [m[0], m[1]], [x, y], w) + mapper.map_values(lon, lat, [m[0], m[1]], [x, y], w) expected = np.zeros((2, npix)) np.add.at(expected[0], ipix, w * x) @@ -107,6 +107,52 @@ def test_healpix_maps(rng): npt.assert_array_equal(m, expected) +@unittest.mock.patch("healpy.map2alm") +def test_healpix_transform(mock_map2alm, rng): + from heracles.core import update_metadata + from heracles.maps import Healpix + + nside = 32 + npix = 12 * nside**2 + + mapper = Healpix(nside) + + # single scalar map + m = rng.standard_normal(npix) + update_metadata(m, spin=0, nside=nside, a=1) + + mock_map2alm.return_value = np.empty(0, dtype=complex) + + alms = mapper.transform(m) + + assert alms is mock_map2alm.return_value + assert alms.dtype.metadata["spin"] == 0 + assert alms.dtype.metadata["a"] == 1 + assert alms.dtype.metadata["nside"] == nside + + # polarisation map + m = rng.standard_normal((2, npix)) + update_metadata(m, spin=2, nside=nside, b=2) + + mock_map2alm.return_value = ( + np.empty(0, dtype=complex), + np.empty(0, dtype=complex), + np.empty(0, dtype=complex), + ) + + alms = mapper.transform(m) + + assert len(alms) == 2 + assert alms[0] is mock_map2alm.return_value[1] + assert alms[1] is mock_map2alm.return_value[2] + assert alms[0].dtype.metadata["spin"] == 2 + assert alms[1].dtype.metadata["spin"] == 2 + assert alms[0].dtype.metadata["b"] == 2 + assert alms[1].dtype.metadata["b"] == 2 + assert alms[0].dtype.metadata["nside"] == nside + assert alms[1].dtype.metadata["nside"] == nside + + class MockField: def __init__(self): self.args = [] @@ -166,63 +212,21 @@ def test_map_catalogs_match(): def test_transform_maps(rng): - from heracles.core import update_metadata from heracles.maps import transform_maps - nside = 32 - npix = 12 * nside**2 - - t = rng.standard_normal(npix) - update_metadata(t, spin=0, nside=nside, a=1) - p = rng.standard_normal((2, npix)) - update_metadata(p, spin=2, nside=nside, b=2) + alms_x = unittest.mock.Mock() + alms_ye = unittest.mock.Mock() + alms_yb = unittest.mock.Mock() - # single scalar map - maps = {("T", 0): t} - alms = transform_maps(maps) - - assert len(alms) == 1 - assert alms.keys() == maps.keys() - assert alms["T", 0].dtype.metadata["spin"] == 0 - assert alms["T", 0].dtype.metadata["a"] == 1 - assert alms["T", 0].dtype.metadata["nside"] == nside - - # polarisation map - maps = {("P", 0): p} - alms = transform_maps(maps) + mapper = unittest.mock.Mock() + mapper.transform.side_effect = (alms_x, (alms_ye, alms_yb)) - assert len(alms) == 2 - assert alms.keys() == {("P_E", 0), ("P_B", 0)} - assert alms["P_E", 0].dtype.metadata["spin"] == 2 - assert alms["P_B", 0].dtype.metadata["spin"] == 2 - assert alms["P_E", 0].dtype.metadata["b"] == 2 - assert alms["P_B", 0].dtype.metadata["b"] == 2 - assert alms["P_E", 0].dtype.metadata["nside"] == nside - assert alms["P_B", 0].dtype.metadata["nside"] == nside - - # mixed - maps = {("T", 0): t, ("P", 1): p} - alms = transform_maps(maps) + maps = {("X", 0): unittest.mock.Mock(), ("Y", 1): unittest.mock.Mock()} - assert len(alms) == 3 - assert alms.keys() == {("T", 0), ("P_E", 1), ("P_B", 1)} - assert alms["T", 0].dtype.metadata["spin"] == 0 - assert alms["P_E", 1].dtype.metadata["spin"] == 2 - assert alms["P_B", 1].dtype.metadata["spin"] == 2 - assert alms["T", 0].dtype.metadata["a"] == 1 - assert alms["P_E", 1].dtype.metadata["b"] == 2 - assert alms["P_B", 1].dtype.metadata["b"] == 2 - assert alms["T", 0].dtype.metadata["nside"] == nside - assert alms["P_E", 1].dtype.metadata["nside"] == nside - assert alms["P_B", 1].dtype.metadata["nside"] == nside - - # explicit lmax per map - maps = {("T", 0): t, ("P", 1): p} - lmax = {"T": 10, "P": 20} - alms = transform_maps(maps, lmax=lmax) + alms = transform_maps(mapper, maps) assert len(alms) == 3 - assert alms.keys() == {("T", 0), ("P_E", 1), ("P_B", 1)} - assert alms["T", 0].size == (lmax["T"] + 1) * (lmax["T"] + 2) // 2 - assert alms["P_E", 1].size == (lmax["P"] + 1) * (lmax["P"] + 2) // 2 - assert alms["P_B", 1].size == (lmax["P"] + 1) * (lmax["P"] + 2) // 2 + assert alms.keys() == {("X", 0), ("Y_E", 1), ("Y_B", 1)} + assert alms["X", 0] is alms_x + assert alms["Y_E", 1] is alms_ye + assert alms["Y_B", 1] is alms_yb