Skip to content

Commit

Permalink
API(maps): spherical harmonic transforms from mappers (#89)
Browse files Browse the repository at this point in the history
Move spherical harmonic transforms into mapper classes, since those are
the objects that know how to carry out a transform of their maps.

Closes: #88
  • Loading branch information
ntessore authored Dec 21, 2023
1 parent 17c7039 commit 6fc50b7
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 134 deletions.
13 changes: 6 additions & 7 deletions examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2b4e7378bbce4d9aad245cc95935bdc1",
"model_id": "b9e7f4072a8349bfbe5a45433b429037",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -690,7 +690,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "44ae51e18b0444cc917716e2eff2569e",
"model_id": "86b6b81c8732497c824cf663f0f24a6b",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -726,7 +726,7 @@
}
],
"source": [
"alms = transform_maps(maps, lmax=lmax, use_pixel_weights=True, progress=True)"
"alms = transform_maps(mapper, maps, lmax=lmax, progress=True)"
]
},
{
Expand Down Expand Up @@ -923,7 +923,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc2eb3fd9fe5451f87bb5129d9ced2cc",
"model_id": "28f4e1aa37f046d98d38751faf16fcac",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -1027,7 +1027,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "662e8b4142e34c27ad62d2b88a795240",
"model_id": "efab2494bdf444119ce4f008cf2232bb",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -1063,8 +1063,7 @@
}
],
"source": [
"alms_mm = transform_maps(maps_mm, progress=True,\n",
" lmax=lmax_mm, use_pixel_weights=True)"
"alms_mm = transform_maps(mapper_mm, maps_mm, progress=True, lmax=lmax_mm)"
]
},
{
Expand Down
43 changes: 29 additions & 14 deletions heracles/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
# License along with Heracles. If not, see <https://www.gnu.org/licenses/>.
"""module for common core functionality"""

from __future__ import annotations

from collections import UserDict
from collections.abc import Mapping, Sequence
from typing import Any, Callable, TypeVar

import numpy as np

T = TypeVar("T")


def toc_match(key, include=None, exclude=None):
"""return whether a tocdict entry matches include/exclude criteria"""
Expand All @@ -49,21 +54,31 @@ def toc_filter(obj, include=None, exclude=None):
raise TypeError(msg)


def toc_nearest(obj, key):
"""return the closest match to *key* in *obj*."""
if isinstance(key, Sequence):
t = tuple(key)
def multi_value_getter(obj: T | Mapping[Any, T]) -> Callable[[Any], T]:
"""Return a getter for values or mappings."""
if isinstance(obj, Mapping):

def getter(key: Any) -> T:
if isinstance(key, Sequence):
t = tuple(key)
else:
t = (key,)
while t:
if t in obj:
return obj[t]
if len(t) == 1 and t[0] in obj:
return obj[t[0]]
t = t[:-1]
if t in obj:
return obj[t]
raise KeyError(key)

else:
t = (key,)
while t:
if t in obj:
return obj[t]
if len(t) == 1 and t[0] in obj:
return obj[t[0]]
t = t[:-1]
if t in obj:
return obj[t]
raise KeyError(key)

def getter(key: Any) -> T:
return obj

return getter


# subclassing UserDict here since that returns the correct type from methods
Expand Down
8 changes: 4 additions & 4 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def __call__(
# map catalogue data asynchronously
async for page in _pages(catalog, progress):
lon, lat = page.get(*col)
mapper(lon, lat, [pos])
mapper.map_values(lon, lat, [pos])

ngal += page.size

Expand Down Expand Up @@ -300,7 +300,7 @@ async def __call__(
lon, lat, v = page.get(*col)
w = page.get(wcol) if wcol is not None else None

mapper(lon, lat, [val], [v], w)
mapper.map_values(lon, lat, [val], [v], w)

ngal += page.size
if w is None:
Expand Down Expand Up @@ -391,7 +391,7 @@ async def __call__(

w = page.get(wcol) if wcol is not None else None

mapper(lon, lat, [val[0], val[1]], [re, im], w)
mapper.map_values(lon, lat, [val[0], val[1]], [re, im], w)

ngal += page.size
if w is None:
Expand Down Expand Up @@ -509,7 +509,7 @@ async def __call__(
else:
w = page.get(wcol)

mapper(lon, lat, [wht], None, w)
mapper.map_values(lon, lat, [wht], None, w)

del page, lon, lat, w

Expand Down
51 changes: 49 additions & 2 deletions heracles/maps/_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import numpy as np
from numba import njit

from heracles.core import update_metadata

from ._mapper import Mapper

if TYPE_CHECKING:
Expand Down Expand Up @@ -118,14 +120,21 @@ class Healpix(Mapper, kernel="healpix"):
available as the *nside* property.
"""

def __init__(self, nside: int, dtype: DTypeLike = np.float64) -> None:
def __init__(
self,
nside: int,
dtype: DTypeLike = np.float64,
*,
datapath: str | None = None,
) -> None:
"""
Mapper for HEALPix maps with the given *nside* parameter.
"""
super().__init__()
self.__nside = nside
self.__npix = hp.nside2npix(nside)
self.__dtype = np.dtype(dtype)
self.__datapath = datapath
self._metadata |= {
"nside": nside,
}
Expand Down Expand Up @@ -158,7 +167,7 @@ def area(self) -> float:
"""
return hp.nside2pixarea(self.__nside)

def __call__(
def map_values(
self,
lon: ArrayLike,
lat: ArrayLike,
Expand Down Expand Up @@ -186,3 +195,41 @@ def __call__(
_map(ipix, maps, values)
else:
_mapw(ipix, maps, values, weight)

def transform(
self,
maps: ArrayLike,
lmax: int | None = None,
) -> ArrayLike | tuple[ArrayLike, ArrayLike]:
"""
Spherical harmonic transform of HEALPix maps.
"""

md = maps.dtype.metadata or {}
spin = md.get("spin", 0)

if spin == 0:
pol = False
elif spin == 2:
maps = [np.zeros(self.__npix), maps[0], maps[1]]
pol = True
else:
msg = f"spin-{spin} maps not yet supported"
raise NotImplementedError(msg)

alms = hp.map2alm(
maps,
lmax=lmax,
pol=pol,
use_pixel_weights=True,
datapath=self.__datapath,
)

if spin == 0:
update_metadata(alms, **md)
else:
alms = (alms[1], alms[2])
update_metadata(alms[0], **md)
update_metadata(alms[1], **md)

return alms
12 changes: 11 additions & 1 deletion heracles/maps/_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def area(self) -> float:
"""

@abstractmethod
def __call__(
def map_values(
self,
lon: ArrayLike,
lat: ArrayLike,
Expand All @@ -116,3 +116,13 @@ def __call__(
"""
Add values to maps.
"""

@abstractmethod
def transform(
self,
maps: ArrayLike,
lmax: int | None = None,
) -> ArrayLike | tuple[ArrayLike, ArrayLike]:
"""
The spherical harmonic transform for this mapper.
"""
59 changes: 21 additions & 38 deletions heracles/maps/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@

from __future__ import annotations

from collections.abc import Mapping, MutableMapping, Sequence
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any

import coroutines
import numpy as np

from heracles.core import TocDict, toc_match, toc_nearest, update_metadata
from heracles.core import TocDict, multi_value_getter, toc_match

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping, Sequence

from numpy.typing import NDArray

from heracles.catalog import Catalog
Expand Down Expand Up @@ -85,6 +85,9 @@ def map_catalogs(
if out is None:
out = TocDict()

# getter for mapper value or dict
mappergetter = multi_value_getter(mapper)

# collect groups of items to go through
# items are tuples of (key, field, catalog)
groups = [
Expand Down Expand Up @@ -115,11 +118,7 @@ def map_catalogs(
keys, coros = [], []
for key, field, catalog in items:
if toc_match(key, include, exclude):
# find the mapper for this key
if isinstance(mapper, Mapping):
_mapper = toc_nearest(mapper, key)
else:
_mapper = mapper
_mapper = mappergetter(key)

coro = _map_progress(key, field, catalog, _mapper, prog)

Expand All @@ -144,6 +143,7 @@ def map_catalogs(


def transform_maps(
mapper: Mapper | Mapping[Any, Mapper],
maps: Mapping[tuple[Any, Any], NDArray],
*,
lmax: int | Mapping[Any, int] | None = None,
Expand All @@ -153,12 +153,14 @@ def transform_maps(
) -> MutableMapping[tuple[Any, Any], NDArray]:
"""transform a set of maps to alms"""

import healpy as hp

# the output toc dict
if out is None:
out = TocDict()

# getter for values or dicts
mappergetter = multi_value_getter(mapper)
lmaxgetter = multi_value_getter(lmax)

# display a progress bar if asked to
progressbar: Progress | nullcontext
if progress:
Expand All @@ -172,14 +174,6 @@ def transform_maps(
# convert maps to alms, taking care of complex and spin-weighted maps
with progressbar as prog:
for (k, i), m in maps.items():
if isinstance(lmax, Mapping):
_lmax = toc_nearest(lmax, (k, i))
else:
_lmax = lmax

md = m.dtype.metadata or {}
spin = md.get("spin", 0)

if progress:
subtask = prog.task(
f"[{k}, {i}]",
Expand All @@ -188,29 +182,18 @@ def transform_maps(
total=None,
)

tfm: NDArray | list[NDArray]
if spin == 0:
tfm = m
pol = False
elif spin == 2:
tfm = [np.zeros(np.shape(m)[-1]), m[0], m[1]]
pol = True
else:
msg = f"spin-{spin} maps not yet supported"
raise NotImplementedError(msg)
_mapper = mappergetter((k, i))
_lmax = lmaxgetter((k, i))

alms = hp.map2alm(tfm, lmax=_lmax, pol=pol, **kwargs)
alms = _mapper.transform(m, _lmax)

if spin == 0:
alms = {(k, i): alms}
elif spin == 2:
alms = {(f"{k}_E", i): alms[1], (f"{k}_B", i): alms[2]}

for ki, alm in alms.items():
update_metadata(alm, **md)
out[ki] = alm
if isinstance(alms, tuple):
out[f"{k}_E", i] = alms[0]
out[f"{k}_B", i] = alms[1]
else:
out[k, i] = alms

del m, tfm, alms, alm
del m, alms

if progress:
subtask.remove()
Expand Down
Loading

0 comments on commit 6fc50b7

Please sign in to comment.