Skip to content

Commit

Permalink
ENH(maps): allow setting lmax per map in transform_maps() (#52)
Browse files Browse the repository at this point in the history
Make it possible to pass a mapping to the `lmax=` parameter of
`transform_maps()` that contains individual `lmax` values for each map.

Closes: #48
  • Loading branch information
ntessore authored Oct 10, 2023
1 parent 9d7ca1a commit abae466
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
9 changes: 7 additions & 2 deletions heracles/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import typing as t
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Generator
from collections.abc import Generator, Mapping
from functools import partial, wraps

import healpy as hp
Expand Down Expand Up @@ -829,6 +829,7 @@ def map_catalogs(
def transform_maps(
maps: t.Mapping[tuple[t.Any, t.Any], MapData],
*,
lmax: t.Union[int, t.Mapping[t.Any, int], None] = None,
out: t.MutableMapping[t.Any, t.Any] = None,
progress: bool = False,
**kwargs,
Expand All @@ -849,6 +850,10 @@ def transform_maps(
# convert maps to alms, taking care of complex and spin-weighted maps
for (k, i), m in maps.items():
nside = hp.get_nside(m)
if isinstance(lmax, Mapping):
_lmax = lmax.get((k, i)) or lmax.get(k)
else:
_lmax = lmax

md = m.dtype.metadata or {}
spin = md.get("spin", 0)
Expand All @@ -864,7 +869,7 @@ def transform_maps(
msg = f"spin-{spin} maps not yet supported"
raise NotImplementedError(msg)

alms = hp.map2alm(m, pol=pol, **kwargs)
alms = hp.map2alm(m, lmax=_lmax, pol=pol, **kwargs)

if spin == 0:
alms = {(k, i): alms}
Expand Down
11 changes: 11 additions & 0 deletions tests/test_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,17 @@ def test_transform_maps(rng):
assert alms["P_E", 1].dtype.metadata["nside"] == nside
assert alms["P_B", 1].dtype.metadata["nside"] == nside

# explicit lmax per map
maps = {("T", 0): t, ("P", 1): p}
lmax = {"T": 10, "P": 20}
alms = transform_maps(maps, lmax=lmax)

assert len(alms) == 3
assert alms.keys() == {("T", 0), ("P_E", 1), ("P_B", 1)}
assert alms["T", 0].size == (lmax["T"] + 1) * (lmax["T"] + 2) // 2
assert alms["P_E", 1].size == (lmax["P"] + 1) * (lmax["P"] + 2) // 2
assert alms["P_B", 1].size == (lmax["P"] + 1) * (lmax["P"] + 2) // 2


def test_update_metadata():
from heracles.maps import update_metadata
Expand Down

0 comments on commit abae466

Please sign in to comment.