Skip to content

Commit

Permalink
ENH(fields): return weights for a set of fields (#97)
Browse files Browse the repository at this point in the history
Add the `weights_for_fields()` function which returns a set of weight
function names for a set of fields. It has a `comb=` parameter that can
be used to produce tuples of weight function names for combinations of
*N* fields.
  • Loading branch information
ntessore authored Jan 5, 2024
1 parent 20221df commit 8b8d188
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 2 deletions.
60 changes: 58 additions & 2 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,19 @@

import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from itertools import combinations_with_replacement, product
from types import MappingProxyType
from typing import TYPE_CHECKING

import coroutines
import numpy as np

from .core import update_metadata
from .core import toc_match, update_metadata

if TYPE_CHECKING:
from collections.abc import AsyncIterable, Mapping, Sequence
from typing import Any
from typing import Any, TypeGuard

from numpy.typing import ArrayLike

Expand Down Expand Up @@ -537,3 +539,57 @@ class Spin2Field(ComplexField, spin=2):

Shears = Spin2Field
Ellipticities = Spin2Field


def weights_for_fields(
fields: Mapping[str, Field],
*,
comb: int | None = None,
include: Sequence[Sequence[str]] | None = None,
exclude: Sequence[Sequence[str]] | None = None,
append_eb: bool = False,
) -> Sequence[str] | Sequence[tuple[str, ...]]:
"""
Return the weights for a given set of fields.
If *comb* is given, produce combinations of weights for combinations
of a number *comb* of fields.
The fields (not weights) can be filtered using the *include* and
*exclude* parameters. If *append_eb* is true, the filter is applied
to field names including the E/B-mode suffix when the spin weight is
non-zero.
"""

isgood = partial(toc_match, include=include, exclude=exclude)

def _key_eb(key: str) -> tuple[str, ...]:
"""Return the key of the given field with _E/_B appended (or not)."""
if append_eb and fields[key].spin != 0:
return (f"{key}_E", f"{key}_B")
return (key,)

def _all_str(seq: tuple[str | None, ...]) -> TypeGuard[tuple[str, ...]]:
"""Return true if all items in *seq* are strings."""
return not any(item is None for item in seq)

if comb is None:
weights_no_comb: list[str] = []
for key, field in fields.items():
if field.weight is None:
continue
if not any(map(isgood, _key_eb(key))):
continue
weights_no_comb.append(field.weight)
return weights_no_comb

weights_comb: list[tuple[str, ...]] = []
for keys in combinations_with_replacement(fields, comb):
item = tuple(fields[key].weight for key in keys)
if not _all_str(item):
continue
if not any(map(isgood, product(*map(_key_eb, keys)))):
continue
weights_comb.append(item)
return weights_comb
40 changes: 40 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,43 @@ def test_weights(mapper, catalog):
"nside": mapper.nside,
}
np.testing.assert_array_almost_equal(m, w / wbar)


def test_weights_for_fields():
from unittest.mock import Mock

from heracles.fields import weights_for_fields

fields = {
"A": Mock(weight="X", spin=0),
"B": Mock(weight="Y", spin=2),
"C": Mock(weight=None),
}

weights = weights_for_fields(fields)

assert weights == ["X", "Y"]

weights = weights_for_fields(fields, comb=1)

assert weights == [("X",), ("Y",)]

weights = weights_for_fields(fields, comb=2)

assert weights == [("X", "X"), ("X", "Y"), ("Y", "Y")]

weights = weights_for_fields(fields, comb=2, include=[("A",)])

assert weights == [("X", "X"), ("X", "Y")]

weights = weights_for_fields(fields, comb=2, exclude=[("A", "B")])

assert weights == [("X", "X"), ("Y", "Y")]

weights = weights_for_fields(fields, comb=2, include=[("A", "B")], append_eb=True)

assert weights == []

weights = weights_for_fields(fields, comb=2, include=[("A", "B_E")], append_eb=True)

assert weights == [("X", "Y")]

0 comments on commit 8b8d188

Please sign in to comment.