From 8b8d18883ed70982b68c544ee3b81ee8d49bd289 Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Fri, 5 Jan 2024 08:49:03 +0000 Subject: [PATCH] ENH(fields): return weights for a set of fields (#97) Add the `weights_for_fields()` function which returns a set of weight function names for a set of fields. It has a `comb=` parameter that can be used to produce tuples of weight function names for combinations of *N* fields. --- heracles/fields.py | 60 ++++++++++++++++++++++++++++++++++++++++++-- tests/test_fields.py | 40 +++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/heracles/fields.py b/heracles/fields.py index 3cea1c3..65bcca5 100644 --- a/heracles/fields.py +++ b/heracles/fields.py @@ -22,17 +22,19 @@ import warnings from abc import ABCMeta, abstractmethod +from functools import partial +from itertools import combinations_with_replacement, product from types import MappingProxyType from typing import TYPE_CHECKING import coroutines import numpy as np -from .core import update_metadata +from .core import toc_match, update_metadata if TYPE_CHECKING: from collections.abc import AsyncIterable, Mapping, Sequence - from typing import Any + from typing import Any, TypeGuard from numpy.typing import ArrayLike @@ -537,3 +539,57 @@ class Spin2Field(ComplexField, spin=2): Shears = Spin2Field Ellipticities = Spin2Field + + +def weights_for_fields( + fields: Mapping[str, Field], + *, + comb: int | None = None, + include: Sequence[Sequence[str]] | None = None, + exclude: Sequence[Sequence[str]] | None = None, + append_eb: bool = False, +) -> Sequence[str] | Sequence[tuple[str, ...]]: + """ + Return the weights for a given set of fields. + + If *comb* is given, produce combinations of weights for combinations + of a number *comb* of fields. + + The fields (not weights) can be filtered using the *include* and + *exclude* parameters. If *append_eb* is true, the filter is applied + to field names including the E/B-mode suffix when the spin weight is + non-zero. + + """ + + isgood = partial(toc_match, include=include, exclude=exclude) + + def _key_eb(key: str) -> tuple[str, ...]: + """Return the key of the given field with _E/_B appended (or not).""" + if append_eb and fields[key].spin != 0: + return (f"{key}_E", f"{key}_B") + return (key,) + + def _all_str(seq: tuple[str | None, ...]) -> TypeGuard[tuple[str, ...]]: + """Return true if all items in *seq* are strings.""" + return not any(item is None for item in seq) + + if comb is None: + weights_no_comb: list[str] = [] + for key, field in fields.items(): + if field.weight is None: + continue + if not any(map(isgood, _key_eb(key))): + continue + weights_no_comb.append(field.weight) + return weights_no_comb + + weights_comb: list[tuple[str, ...]] = [] + for keys in combinations_with_replacement(fields, comb): + item = tuple(fields[key].weight for key in keys) + if not _all_str(item): + continue + if not any(map(isgood, product(*map(_key_eb, keys)))): + continue + weights_comb.append(item) + return weights_comb diff --git a/tests/test_fields.py b/tests/test_fields.py index 718baf3..4a88dda 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -326,3 +326,43 @@ def test_weights(mapper, catalog): "nside": mapper.nside, } np.testing.assert_array_almost_equal(m, w / wbar) + + +def test_weights_for_fields(): + from unittest.mock import Mock + + from heracles.fields import weights_for_fields + + fields = { + "A": Mock(weight="X", spin=0), + "B": Mock(weight="Y", spin=2), + "C": Mock(weight=None), + } + + weights = weights_for_fields(fields) + + assert weights == ["X", "Y"] + + weights = weights_for_fields(fields, comb=1) + + assert weights == [("X",), ("Y",)] + + weights = weights_for_fields(fields, comb=2) + + assert weights == [("X", "X"), ("X", "Y"), ("Y", "Y")] + + weights = weights_for_fields(fields, comb=2, include=[("A",)]) + + assert weights == [("X", "X"), ("X", "Y")] + + weights = weights_for_fields(fields, comb=2, exclude=[("A", "B")]) + + assert weights == [("X", "X"), ("Y", "Y")] + + weights = weights_for_fields(fields, comb=2, include=[("A", "B")], append_eb=True) + + assert weights == [] + + weights = weights_for_fields(fields, comb=2, include=[("A", "B_E")], append_eb=True) + + assert weights == [("X", "Y")]