Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: make bias work for HEALPix and spherical convolution #111

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions examples/example.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions heracles/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,11 @@ def update_metadata(array, *sources, **metadata):
raise ValueError(msg)
# set the new dtype in array
array.dtype = dt


def items_with_suffix(d: Mapping[str, Any], suffix: str) -> Mapping[str, Any]:
    """
    Select the items of *d* whose keys end in *suffix*.

    The returned mapping has one entry per matching key, with *suffix*
    stripped from the key.
    """
    selected: dict[str, Any] = {}
    for key, value in d.items():
        if key.endswith(suffix):
            selected[key.removesuffix(suffix)] = value
    return selected
69 changes: 42 additions & 27 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ async def __call__(

# total weighted variance from online algorithm
ngal = 0
wmean, w2mean, var = 0.0, 0.0, 0.0
wmean, var = 0.0, 0.0

# go through pages in catalogue and map values
async for page in _pages(catalog, progress):
Expand All @@ -338,7 +338,6 @@ async def __call__(
var += (v**2 - var).sum() / ngal
else:
wmean += (w - wmean).sum() / ngal
w2mean += (w**2 - w2mean).sum() / ngal
var += ((w * v) ** 2 - var).sum() / ngal

del lon, lat, v, w
Expand All @@ -348,7 +347,7 @@ async def __call__(

# fix mean weight if there was no column for it
if wcol is None:
wmean = w2mean = 1.0
wmean = 1.0

# compute mean visibility
if catalog.visibility is None:
Expand All @@ -365,11 +364,8 @@ async def __call__(
# compute bias from variance (per object)
bias = 4 * np.pi * vbar**2 * (var / wmean**2) / ngal

# bias correction factor for intrinsic variance
bcor = 4 * np.pi * vbar**2 * (w2mean / wmean**2) / ngal

# set metadata of array
update_metadata(val, self, catalog, mapper, wbar=wbar, bias=bias, bcor=bcor)
update_metadata(val, self, catalog, mapper, wbar=wbar, bias=bias)

# return the value map
return val
Expand Down Expand Up @@ -402,7 +398,7 @@ async def __call__(

# total weighted variance from online algorithm
ngal = 0
wmean, w2mean, var = 0.0, 0.0, 0.0
wmean, var = 0.0, 0.0

# go through pages in catalogue and get the shear values,
async for page in _pages(catalog, progress):
Expand All @@ -421,7 +417,6 @@ async def __call__(
var += (re**2 + im**2 - var).sum() / ngal
else:
wmean += (w - wmean).sum() / ngal
w2mean += (w**2 - w2mean).sum() / ngal
var += ((w * re) ** 2 + (w * im) ** 2 - var).sum() / ngal

del lon, lat, re, im, w
Expand All @@ -430,7 +425,7 @@ async def __call__(

# set mean weight if there was no column for it
if wcol is None:
wmean = w2mean = 1.0
wmean = 1.0

# compute mean visibility
if catalog.visibility is None:
Expand All @@ -447,11 +442,8 @@ async def __call__(
# bias from measured variance, for E/B decomposition
bias = 2 * np.pi * vbar**2 * (var / wmean**2) / ngal

# bias correction factor for intrinsic field variance
bcor = 2 * np.pi * vbar**2 * (w2mean / wmean**2) / ngal

# set metadata of array
update_metadata(val, self, catalog, mapper, wbar=wbar, bias=bias, bcor=bcor)
update_metadata(val, self, catalog, mapper, wbar=wbar, bias=bias)

# return the shear map
return val
Expand Down Expand Up @@ -513,29 +505,52 @@ async def __call__(
# weight map
wht = np.zeros(mapper.size, mapper.dtype)

# total weighted variance from online algorithm
ngal = 0
wmean, w2mean = 0.0, 0.0

# map catalogue
async for page in _pages(catalog, progress):
lon, lat = page.get(*col)
if wcol is not None:
page.delete(page[wcol] == 0)

if wcol is None:
w = None
else:
w = page.get(wcol)
if page.size:
lon, lat = page.get(*col)

w = page.get(wcol) if wcol is not None else None

mapper.map_values(lon, lat, [wht], None, w)

ngal += page.size
if w is not None:
wmean += (w - wmean).sum() / ngal
w2mean += (w**2 - w2mean).sum() / ngal

mapper.map_values(lon, lat, [wht], None, w)
del lon, lat, w

del page

# set mean weight if there was no column for it
if wcol is None:
wmean = w2mean = 1.0

del page, lon, lat, w
# compute mean visibility
if catalog.visibility is None:
vbar = 1
else:
vbar = np.mean(catalog.visibility)

# compute average weight in nonzero pixels
wbar = wht.mean()
if catalog.visibility is not None:
wbar /= np.mean(catalog.visibility)
# mean weight per effective mapper "pixel"
wbar = ngal / (4 * np.pi * vbar) * wmean * mapper.area

# normalise the map
wht /= wbar

# set metadata of arrays
update_metadata(wht, self, catalog, mapper, wbar=wbar)
# bias from weights
bias = 4 * np.pi * vbar**2 * (w2mean / wmean**2) / ngal

# set metadata of array
update_metadata(wht, self, catalog, mapper, wbar=wbar, bias=bias)

# return the weight map
return wht
Expand Down
3 changes: 2 additions & 1 deletion heracles/maps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
"Mapper",
"get_kernels",
"map_catalogs",
"mapper_from_dict",
"transform_maps",
)

from ._healpix import Healpix
from ._mapper import Mapper, get_kernels
from ._mapper import Mapper, get_kernels, mapper_from_dict
from ._mapping import map_catalogs, transform_maps
48 changes: 32 additions & 16 deletions heracles/maps/_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
from ._mapper import Mapper

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Self

from numpy.typing import ArrayLike, DTypeLike
from numpy.typing import ArrayLike, DTypeLike, NDArray


def _asnative(arr):
Expand Down Expand Up @@ -120,12 +121,25 @@ class Healpix(Mapper, kernel="healpix"):
available as the *nside* property.
"""

DATAPATH: str | None = None

@classmethod
def from_dict(cls, d: Mapping[str, Any]) -> Self:
"""
Create a HEALPix mapper from a dictionary.
"""
nside = d["nside"]
if isinstance(nside, str) and nside.isdigit():
nside = int(nside)
elif not isinstance(nside, int):
msg = f"value for key 'nside' is {type(nside).__name__}, expected int"
raise ValueError(msg)
return cls(nside)

def __init__(
self,
nside: int,
dtype: DTypeLike = np.float64,
*,
datapath: str | None = None,
) -> None:
"""
Mapper for HEALPix maps with the given *nside* parameter.
Expand All @@ -134,7 +148,6 @@ def __init__(
self.__nside = nside
self.__npix = hp.nside2npix(nside)
self.__dtype = np.dtype(dtype)
self.__datapath = datapath
self._metadata |= {
"nside": nside,
}
Expand Down Expand Up @@ -222,7 +235,7 @@ def transform(
lmax=lmax,
pol=pol,
use_pixel_weights=True,
datapath=self.__datapath,
datapath=self.DATAPATH,
)

if spin == 0:
Expand All @@ -234,23 +247,26 @@ def transform(

return alms

def deconvolve(self, alm: ArrayLike, *, inplace: bool = False) -> ArrayLike:
def kl(self, lmax: int, spin: int = 0) -> NDArray[Any]:
"""
Remove HEALPix pixel window function from *alm*.
Return the HEALPix pixel window function.
"""

lmax = hp.Alm.getlmax(alm.size)

md = alm.dtype.metadata or {}
spin = md.get("spin", 0)

pw: NDArray[Any]
if spin == 0:
pw = hp.pixwin(self.__nside, lmax=lmax)
elif spin == 2:
_, pw = hp.pixwin(self.__nside, lmax=lmax, pol=True)
pw[:2] = 1.0
else:
msg = f"unsupported spin for deconvolve: {spin}"
msg = f"unsupported spin: {spin}"
raise ValueError(msg)

return hp.almxfl(alm, 1 / pw, inplace=inplace)
return pw

def bl(self, lmax: int, spin: int = 0) -> NDArray[Any]:
"""
Return the biasing kernel for HEALPix.
"""
kl = self.kl(lmax, spin)
where = np.arange(lmax + 1) >= abs(spin)
return np.divide(1.0, kl, where=where, out=np.zeros(lmax + 1))
57 changes: 44 additions & 13 deletions heracles/maps/_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing import Self

from numpy.typing import ArrayLike, DTypeLike
from numpy.typing import ArrayLike, DTypeLike, NDArray

# dictionary of kernel names and their corresponding Mapper classes
_KERNELS: dict[str, type[Mapper]] = {}
Expand All @@ -43,39 +44,63 @@ def get_kernels() -> Mapping[str, type[Mapper]]:
return MappingProxyType(_KERNELS)


def mapper_from_dict(d: Mapping[str, Any]) -> Mapper:
    """
    Return a mapper that matches the given metadata.

    Looks up ``d["kernel"]`` in the registry of known mapper classes
    and delegates construction to that class's ``from_dict()``.

    Raises ValueError if *d* has no "kernel" key or if the kernel name
    is not registered.
    """
    # (fix: removed interleaved residue of the deleted MapperMeta class
    # that was mixed into this function's lines)
    try:
        kernel = d["kernel"]
    except KeyError:
        msg = "no 'kernel' in mapping"
        raise ValueError(msg) from None

    try:
        cls = _KERNELS[kernel]
    except KeyError:
        msg = f"unknown kernel: {kernel}"
        raise ValueError(msg) from None

    return cls.from_dict(d)


class Mapper(metaclass=ABCMeta):
    """
    Abstract base class for mappers.
    """

    # kernel name under which each concrete subclass is registered;
    # set by __init_subclass__
    __kernel: str

    def __init_subclass__(cls, /, kernel: str, **kwargs) -> None:
        """
        Initialise mapper subclasses with a *kernel* parameter.

        Stores the kernel name on the subclass and registers the
        subclass in the module-level _KERNELS registry.
        """
        # (fix: removed interleaved residue of the old metaclass-based
        # header and the old cls._kernel assignment)
        super().__init_subclass__(**kwargs)
        cls.__kernel = kernel
        _KERNELS[kernel] = cls

    @classmethod
    @abstractmethod
    def from_dict(cls, d: Mapping[str, Any]) -> Self:
        """
        Create a new mapper instance from a dictionary of parameters.

        Concrete subclasses read their configuration from *d*; this is
        the hook used by ``mapper_from_dict()``.
        """

def __init__(self) -> None:
"""
Initialise a new mapper instance.
"""
self._metadata: dict[str, Any] = {
"kernel": self.__class__.kernel,
"kernel": self.__kernel,
}

    @property
    def kernel(self) -> str:
        """
        Return the name of the kernel for this mapper.

        This is the *kernel* value the subclass was registered with in
        ``__init_subclass__``.
        """
        return self.__kernel

@property
def metadata(self) -> Mapping[str, Any]:
"""
Expand Down Expand Up @@ -128,7 +153,13 @@ def transform(
"""

@abstractmethod
def deconvolve(self, alm: ArrayLike, *, inplace: bool = False) -> ArrayLike:
def kl(self, lmax: int, spin: int = 0) -> NDArray[Any]:
"""
Return the convolution kernel in harmonic space.
"""

def bl(self, lmax: int, spin: int = 0) -> None | NDArray[Any]:
"""
Remove this mapper's convolution kernel from *alm*.
Return the biasing kernel in harmonic space.
"""
return None
Loading
Loading