From 9b2109e168b7b71484608e38eb65bc7bd80ff3b6 Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Fri, 5 Jan 2024 22:34:04 +0000 Subject: [PATCH] command line interface --- examples/heracles.cfg | 72 +++ heracles/__main__.py | 25 ++ heracles/cli.py | 893 ++++++++++++++++++++++++++++++++++++++ heracles/io.py | 25 +- heracles/maps/_mapping.py | 7 +- pyproject.toml | 11 +- tests/test_cli.py | 412 ++++++++++++++++++ tests/test_io.py | 6 +- 8 files changed, 1432 insertions(+), 19 deletions(-) create mode 100644 examples/heracles.cfg create mode 100644 heracles/__main__.py create mode 100644 heracles/cli.py create mode 100644 tests/test_cli.py diff --git a/examples/heracles.cfg b/examples/heracles.cfg new file mode 100644 index 0000000..75d321f --- /dev/null +++ b/examples/heracles.cfg @@ -0,0 +1,72 @@ +# example config file for Heracles +# values from [defaults] are applied to all sections + +[defaults] +lmin = 10 +bins = 32 log 2l+1 + +[spectra:clustering] +include = D, D +lmax = 2000 +l2max = 4000 + +[spectra:shear] +include = + G_E, G_E + G_B, G_B + G_E, G_B +lmax = 3000 +l2max = 5000 + +[spectra:ggl] +include = + D, G_E + D, G_B +lmax = 1000 +l2max = 2000 + +[fields:D] +type = positions +columns = + SHE_RA + SHE_DEC +mask = V +nside = 2048 +lmax = 2000 + +[fields:G] +type = shears +columns = + SHE_RA + SHE_DEC + SHE_E1_CAL + -SHE_E2_CAL + SHE_WEIGHT +mask = W +nside = 2048 +lmax = 3000 + +[fields:V] +type = visibility +nside = 4096 +lmax = 6000 + +[fields:W] +type = weights +columns = + SHE_RA + SHE_DEC + SHE_WEIGHT +nside = 8192 +lmax = 8000 + +[catalogs:fs2-dr1n-noia] +source = catalog.fits +selections = + 0 = TOM_BIN_ID==0 + 1 = TOM_BIN_ID==1 + 2 = TOM_BIN_ID==2 +visibility = + 0 = vmap.0.fits + 1 = vmap.1.fits + 2 = vmap.2.fits diff --git a/heracles/__main__.py b/heracles/__main__.py new file mode 100644 index 0000000..41bacb8 --- /dev/null +++ b/heracles/__main__.py @@ -0,0 +1,25 @@ +# Heracles: Euclid code for harmonic-space statistics on the sphere +# +# Copyright (C) 2023 Euclid Science Ground Segment +# +# This file is part of Heracles. +# +# Heracles is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Heracles is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with Heracles. If not, see . +"""Executable module.""" + +import sys + +from .cli import main + +sys.exit(main()) diff --git a/heracles/cli.py b/heracles/cli.py new file mode 100644 index 0000000..8e11901 --- /dev/null +++ b/heracles/cli.py @@ -0,0 +1,893 @@ +# Heracles: Euclid code for harmonic-space statistics on the sphere +# +# Copyright (C) 2023 Euclid Science Ground Segment +# +# This file is part of Heracles. +# +# Heracles is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Heracles is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with Heracles. If not, see . +"""The Heracles command line interface.""" + +from __future__ import annotations + +import argparse +import configparser +import logging +import os +from collections.abc import Iterable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Callable, Union + +import numpy as np + +if TYPE_CHECKING: + from numpy.typing import NDArray + +# valid option keys +FIELD_TYPES = { + "positions": "heracles.fields:Positions", + "shears": "heracles.fields:Shears", + "visibility": "heracles.fields:Visibility", + "weights": "heracles.fields:Weights", +} + +# global path to healpy's data files +HEALPIX_DATAPATH: str | None = None + + +def getlist(value: str) -> list[str]: + """Convert to list.""" + return list(filter(None, map(str.strip, value.splitlines()))) + + +def getdict(value: str) -> dict[str, str]: + """Convert to dictionary.""" + out = {} + for line in map(str.strip, value.splitlines()): + if not line: + continue + key, sep, val = line.partition("=") + if sep != "=": + msg = f"Invalid value: {line!r} (expected 'KEY = VALUE')" + raise ValueError(msg) + out[key.rstrip()] = val.lstrip() + return out + + +def getchoice(value: str, choices: dict[str, Any]) -> Any: + """Get choice from a fixed set of option values.""" + try: + return choices[value] + except KeyError: + expected = ", ".join(map(repr, choices)) + msg = f"Invalid value: {value!r} (expected {expected})" + raise ValueError(msg) from None + + +def getpath(value: str) -> str: + "Convert to path, expanding environment variables." + return os.path.expanduser(os.path.expandvars(value)) + + +def getfilter(value: str) -> list[tuple[Any, ...]]: + """Convert to list of include or exclude filters.""" + filt = [] + for row in getlist(value): + item = [] + for part in map(str.strip, row.split(",")): + if part == "...": + item.append(...) + elif part.isdigit(): + item.append(int(part)) + else: + item.append(part) + filt.append(tuple(item)) + return filt + + +class ConfigParser(configparser.ConfigParser): + """ConfigParser with additional getters.""" + + _UNSET = configparser._UNSET + + def __init__(self) -> None: + # fully specify parent class + super().__init__( + defaults=None, + dict_type=dict, + allow_no_value=False, + delimiters=("=",), + comment_prefixes=("#",), + inline_comment_prefixes=None, + strict=True, + empty_lines_in_values=False, + default_section="defaults", + interpolation=None, + converters={ + "list": getlist, + "dict": getdict, + "path": getpath, + "filter": getfilter, + }, + ) + + def getchoice( + self, + section, + option, + choices, + *, + raw=False, + vars=None, # noqa: A002 + fallback=_UNSET, + ): + """Get choice from a fixed set of option values.""" + try: + value = self.get(section, option, raw=False, vars=None) + except (configparser.NoSectionError, configparser.NoOptionError): + if fallback is not self._UNSET: + return fallback + raise + return getchoice(value, choices) + + def sections(self, prefix: str | None = None) -> list[str]: + """ + Return all the configuration section names. If given, only + sections starting with *prefix* are returned. + """ + + sections = super().sections() + if prefix is not None: + sections = [s for s in sections if s.startswith(prefix)] + return sections + + def subsections(self, group: str) -> dict[str, str]: + """ + Return a mapping of subsections in *group*. + """ + sections = self.sections(f"{group}:") + return {s.rpartition(":")[-1].strip(): s for s in sections} + + +def field_from_config(config, section): + """Construct a field instance from config.""" + + from pkgutil import resolve_name + + _type = config.getchoice(section, "type", FIELD_TYPES) + if isinstance(_type, str): + try: + cls = resolve_name(_type) + except (ValueError, ImportError, AttributeError) as exc: + value = config.get(section, "type") + msg = ( + f"Internal error: field type {value!r} maps to type {_type!r}, " + f"which raised the following error: {exc!s}" + ) + raise RuntimeError(msg) from None + else: + cls = _type + columns = config.getlist(section, "columns", fallback=()) + mask = config.get(section, "mask", fallback=None) + return cls(*columns, mask=mask) + + +def fields_from_config(config): + """Construct all field instances from config.""" + sections = config.subsections("fields") + return { + name: field_from_config(config, section) for name, section in sections.items() + } + + +def mapper_from_config(config, section): + """Construct a mapper instance from config.""" + + from .maps import Healpix + + nside = config.getint(section, "nside") + return Healpix(nside, datapath=HEALPIX_DATAPATH) + + +def mappers_from_config(config): + """Construct all mapper instances from config.""" + sections = config.subsections("fields") + return { + name: mapper_from_config(config, section) for name, section in sections.items() + } + + +def catalog_from_config(config, section, label=None, *, out=None): + """Construct a catalogue instance from config.""" + + from .catalog import FitsCatalog + from .io import read_vmap + + # TODO support non-FITS catalogue sources + source = config.getpath(section, "source") + # check if visibility is per catalogue or per selection + visibility: str | Mapping[str, str] + visibility = config.get(section, "visibility", fallback=None) + # check if visibility is a mapping + if visibility and "\n" in visibility: + visibility = config.getdict(section, "visibility") + selections = config.getdict(section, "selections") + # build the base catalogue + base_catalog = FitsCatalog(source) + base_catalog.label = label + # set base catalogue's visibility if just one was given + if isinstance(visibility, str): + try: + base_catalog.visibility = read_vmap(getpath(visibility)) + except (TypeError, ValueError, OSError) as exc: + msg = f"Cannot load visibility: {exc!s}" + raise ValueError(msg) + # create a view of the base catalogue for each selection + # since `out` can be given, also keep track of selections added here + if out is None: + out = {} + added = set() + for key, where in selections.items(): + # convert key to number and make sure it wasn't used before + num = int(key) + if out and num in out: + msg = f"Duplicate selection: {num}" + raise ValueError(msg) + # create view from selection string, if present + # otherwise, give the base catalog itself + if where: + catalog = base_catalog.where(where) + else: + catalog = base_catalog + # store the selection + out[num] = catalog + added.add(num) + # assign visibilities to individual selections if a mapping was given + # only allow visibilities for selections added here + if isinstance(visibility, Mapping): + for key, vmap in visibility.items(): + num = int(key) + if num not in added: + msg = f"Invalid value: unknown selection '{num}'" + raise ValueError(msg) + try: + out[num].visibility = read_vmap(getpath(vmap)) + except (TypeError, ValueError, OSError) as exc: + msg = f"Cannot load visibility: {exc!s}" + raise ValueError(msg) + # all done, return `out` unconditionally + return out + + +def catalogs_from_config(config): + """Construct all catalog instances from config.""" + sections = config.subsections("catalogs") + catalogs = {} + for label, section in sections.items(): + catalog_from_config(config, section, label, out=catalogs) + return catalogs + + +def lmax_from_config(config): + """Construct a dictionary with LMAX values for all fields.""" + sections = config.subsections("fields") + return {name: config.getint(section, "lmax") for name, section in sections.items()} + + +def bins_from_config(config, section): + """Construct angular bins from config.""" + + # dictionary of {spacing: (op, invop)} + spacings = { + "linear": (lambda x: x, lambda x: x), + "log": (np.log10, lambda x: 10**x), + "sqrt": (np.sqrt, np.square), + "log1p": (np.log1p, np.expm1), + } + + # dictionary of known weights + weights = { + None, + "2l+1", + "l(l+1)", + } + + bins = config.get(section, "bins", fallback="none") + + if bins == "none": + return None, None + + binopts = bins.split() + + if not 2 <= len(binopts) <= 3: + msg = f"{section}: bins should be of the form ' []'" + raise ValueError(msg) + + n = int(binopts[0]) + s = binopts[1] + w = binopts[2] if len(binopts) > 2 else None + + if n < 2: + msg = f"Invalid bin size '{n}' in section {section}" + raise ValueError(msg) + if s not in spacings: + msg = f"Invalid bin spacing '{s}' in section {section}" + raise ValueError(msg) + if w is not None and w not in weights: + msg = f"Invalid bin weights '{w}' in section {section}" + raise ValueError(msg) + + lmin = config.getint(section, "lmin", fallback=1) + lmax = config.getint(section, "lmax") + + op, inv = spacings[s] + arr = inv(np.linspace(op(lmin), op(lmax + 1), n + 1)) + # fix first and last array element to be exact + arr[0], arr[-1] = lmin, lmax + 1 + + return arr, w + + +def spectrum_from_config(config, section): + """Construct info dict for angular power spectra from config.""" + + options = config[section] + + info = {} + if "lmax" in options: + info["lmax"] = options.getint("lmax") + if "l2max" in options: + info["l2max"] = options.getint("l2max") + if "l3max" in options: + info["l3max"] = options.getint("l3max") + if "include" in options: + info["include"] = options.getfilter("include") + if "exclude" in options: + info["exclude"] = options.getfilter("exclude") + if "debias" in options: + info["debias"] = options.getboolean("debias") + if "bins" in options: + info["bins"] = bins_from_config(config, section) + + return info + + +def spectra_from_config(config): + """Construct pairs of labels and *kwargs* for angular power spectra.""" + sections = config.subsections("spectra") + spectra = [] + for label, section in sections.items(): + spectra += [(label, spectrum_from_config(config, section))] + if not spectra: + spectra += [(None, {})] + return spectra + + +# the type of a single path +Path = Union[str, os.PathLike] + +# the type of one or multiple paths +Paths = Union[Path, Iterable[Path]] + +# the type of loader functions for load_xyz() +ConfigLoader = Callable[[Paths], ConfigParser] + + +def configloader(path: Paths) -> ConfigParser: + """Load a config file using configparser.""" + + if isinstance(path, (str, os.PathLike)): + path = (path,) + + config = ConfigParser() + for p in path: + with open(p) as fp: + config.read_file(fp) + return config + + +# this constant sets the default loader +DEFAULT_LOADER = configloader + + +def map_all_selections( + config: ConfigParser, + logger: logging.Logger, +) -> Iterator: + """Iteratively map the catalogues defined in config.""" + + from .maps import map_catalogs + + # load catalogues, mappers, and fields to process + catalogs = catalogs_from_config(config) + mappers = mappers_from_config(config) + fields = fields_from_config(config) + + logger.info("fields %s", ", ".join(map(repr, fields))) + + # process each catalogue separately into maps + for key, catalog in catalogs.items(): + logger.info( + "%s%s", + f"catalog {catalog.label!r}, " if catalog.label else "", + f"selection {key}", + ) + + # maps for single catalogue + yield map_catalogs( + mappers, + fields, + {key: catalog}, + parallel=True, # process everything at this level in one go + progress=True, + ) + + +def load_all_maps(paths: Paths, logger: logging.Logger) -> Iterator: + """Iterate over MapFits from a path or list of paths.""" + + from .io import MapFits + + # make iterable if single path is given + if isinstance(paths, (str, os.PathLike)): + paths = (paths,) + + for path in paths: + logger.info("reading maps from %s", path) + yield MapFits(path, clobber=False) + + +def maps( + path: Path, + *, + files: Paths, + logger: logging.Logger, + loader: ConfigLoader = DEFAULT_LOADER, +) -> None: + """compute maps""" + + from .io import MapFits + + # load the config file, this contains the maps definition + logger.info("reading configuration from %s", files) + config = loader(files) + + # iterator over the individual maps + # this generates maps on the fly + itermaps = map_all_selections(config, logger) + + # output goes into a FITS-backed tocdict so we don't fill memory up + out = MapFits(path, clobber=True) + + # iterate over maps, keeping only one in memory at a time + for maps in itermaps: + # write to disk + logger.info("writing maps to %s", path) + out.update(maps) + # forget maps before next turn to free some memory + del maps + + +def alms( + path: Path, + *, + files: Paths | None, + maps: Paths | None, + healpix_datapath: Path | None = None, + logger: logging.Logger, + loader: ConfigLoader = DEFAULT_LOADER, +) -> None: + """compute spherical harmonic coefficients + + Compute spherical harmonic coefficients (alms) from catalogues or + maps. For catalogue input, the maps for each selection are created + in memory and discarded after its alms have been computed. + + """ + + global HEALPIX_DATAPATH + + from .io import AlmFits + from .maps import transform_maps + + # load the config file, this contains alms setting and maps definition + logger.info("reading configuration from %s", files) + config = loader(files) + + # set the HEALPix datapath + # FIXME: make this part of a configurable mapper interface + HEALPIX_DATAPATH = healpix_datapath + + # load the mappers to perform the transformation of each field + mappers = mappers_from_config(config) + + # load the individual lmax values for each field into a dictionary + lmax = lmax_from_config(config) + + # process either catalogues or maps + # everything is loaded via iterators to keep memory use low + itermaps: Iterator + if maps: + itermaps = load_all_maps(maps, logger) + else: + itermaps = map_all_selections(config, logger) + + # output goes into a FITS-backed tocdict so we don't fill up memory + logger.info("writing alms to %s", path) + out = AlmFits(path, clobber=True) + + # iterate over maps and transform each + for maps in itermaps: + logger.info("transforming %d maps", len(maps)) + transform_maps( + mappers, + maps, + lmax=lmax, + progress=True, + out=out, + ) + del maps + + +def chained_alms(alms: Paths | None) -> Mapping[Any, NDArray] | None: + """Return a ChainMap of AlmFits from all input alms, or None.""" + from collections import ChainMap + + from .io import AlmFits + + if alms is None: + return None + return ChainMap(*(AlmFits(alm) for alm in reversed(alms))) + + +def spectra( + path: Path, + *, + files: Paths, + alms: Paths, + alms2: Paths | None, + logger: logging.Logger, + loader: ConfigLoader = DEFAULT_LOADER, +) -> None: + """compute angular power spectra""" + + from .io import ClsFits + from .twopoint import angular_power_spectra, binned_cls, debias_cls + + # load the config file, this contains angular binning settings + logger.info("reading configuration from %s", files) + config = loader(files) + + # collect angular power spectra settings from config + spectra = spectra_from_config(config) + + # link all alms together + all_alms, all_alms2 = chained_alms(alms), chained_alms(alms2) + + # create an empty cls file, then fill it iteratively with alm combinations + out = ClsFits(path, clobber=True) + + total = 0 + logger.info("using %d set(s) of alms", len(all_alms)) + if all_alms2 is not None: + logger.info("using %d set(s) of cross-alms", len(all_alms2)) + for label, info in spectra: + logger.info( + "computing %s spectra", + repr(label) if label is not None else "all", + ) + # compute spectra + cls = angular_power_spectra( + all_alms, + all_alms2, + lmax=info.get("lmax"), + include=info.get("include"), + exclude=info.get("exclude"), + ) + # debias spectra if configured + if info.get("debias", True): + logger.info("debiasing %d spectra", len(cls)) + debias_cls(cls, inplace=True) + # bin spectra if configured + if info.get("bins") is not None: + bins, weights = info["bins"] + logger.info( + "binning %d spectra into %d bins using %s weights", + len(cls), + len(bins) - 1, + repr(weights) if weights is not None else "no", + ) + binned_cls(cls, bins, weights=weights, out=cls) + logger.info("writing %d spectra to %s", len(cls), path) + out.update(cls) + total += len(cls) + del cls + logger.info("finished computing %d spectra", total) + + +def mixmats( + path: Path, + *, + files: Paths, + alms: Paths, + alms2: Paths | None, + logger: logging.Logger, + loader: ConfigLoader = DEFAULT_LOADER, +) -> None: + """compute mixing matrices""" + + from .fields import get_masks + from .io import MmsFits + from .twopoint import angular_power_spectra, binned_mms, mixing_matrices + + # load the config file, this contains angular binning settings + logger.info("reading configuration from %s", files) + config = loader(files) + + # collect the defined fields from config + fields = fields_from_config(config) + + # collect angular power spectra settings from config + spectra = spectra_from_config(config) + + # link all alms together + all_alms, all_alms2 = chained_alms(alms), chained_alms(alms2) + + # create an empty mms file, then fill it iteratively + out = MmsFits(path, clobber=True) + + # go through all alm combinations one by one + total = 0 + logger.info("using %d set(s) of alms", len(all_alms)) + if all_alms2 is not None: + logger.info("using %d set(s) of cross-alms", len(all_alms2)) + for label, info in spectra: + # get mask combinations for fields included in these spectra + include, exclude = info.get("include"), info.get("exclude") + include_masks = get_masks( + fields, + comb=2, + include=include, + exclude=exclude, + append_eb=True, + ) + if not include_masks: + logger.info( + "missing masks for %s spectra, skipping...", + repr(label) if label is not None else "all", + ) + continue + logger.info( + "computing %s mask spectra for %s", + repr(label) if label is not None else "all", + ", ".join(map(str, include_masks)), + ) + # determine the various lmax values + lmax, l2max, l3max = info.get("lmax"), info.get("l2max"), info.get("l3max") + # compute spectra of masks + mask_cls = angular_power_spectra( + all_alms, + all_alms2, + lmax=l3max, + include=include_masks, + ) + # now compute the mixing matrices from these spectra + logger.info( + "computing %s mixing matrices from %d spectra", + repr(label) if label is not None else "all", + len(mask_cls), + ) + mms = mixing_matrices( + fields, + mask_cls, + l1max=lmax, + l2max=l2max, + l3max=l3max, + progress=True, + ) + # bin mixing matrices if configured + if info.get("bins") is not None: + bins, weights = info["bins"] + logger.info( + "binning %d mixing matrices into %d bins using %s weights", + len(mms), + len(bins) - 1, + repr(weights) if weights is not None else "no", + ) + binned_mms(mms, bins, weights=weights, out=mms) + logger.info("writing %d mixing matrices to %s", len(mms), path) + out.update(mms) + total += len(mms) + del mask_cls, mms + logger.info("finished computing %d mixing matrices", total) + + +class MainFormatter(argparse.RawDescriptionHelpFormatter): + """Formatter that keeps order of arguments for usage.""" + + def add_usage(self, usage, actions, groups, prefix=None): + self.actions = actions + super().add_usage(usage, actions, groups, prefix) + + def _format_actions_usage(self, actions, groups): + return super()._format_actions_usage(self.actions, groups) + + +def main(): + """Main method of the `heracles` command. + + Parses arguments and calls the appropriate subcommand. + + """ + + def add_command(func): + """Create a subparser for a command given by a function.""" + + name = func.__name__ + doc = func.__doc__.strip() + help_, _, description = doc.partition("\n") + + parser = commands.add_parser( + name, + help=help_, + description=description, + parents=[cmd_parser], + formatter_class=MainFormatter, + ) + parser.set_defaults(cmd=func) + return parser + + # common parser for all subcommands + cmd_parser = argparse.ArgumentParser( + add_help=False, + ) + cmd_parser.add_argument( + "-c", + "--config", + default="heracles.cfg", + help="configuration file (can be repeated)", + metavar="", + action="append", + dest="files", + ) + + # main parser for CLI invokation + main_parser = argparse.ArgumentParser( + prog="heracles", + epilog="Made in the Euclid Science Ground Segment", + formatter_class=MainFormatter, + ) + main_parser.set_defaults(cmd=None) + + commands = main_parser.add_subparsers( + title="commands", + metavar="", + help="the processing step to carry out", + ) + + ######## + # maps # + ######## + + parser = add_command(maps) + group = parser.add_argument_group("output") + group.add_argument( + "path", + help="output FITS file for maps", + metavar="", + ) + + ######## + # alms # + ######## + + parser = add_command(alms) + parser.add_argument( + "--healpix-datapath", + help="path to HEALPix data files", + metavar="", + ) + group = parser.add_argument_group("output") + group.add_argument( + "path", + help="output FITS file for alms", + metavar="", + ) + group = parser.add_argument_group("inputs") + group.add_argument( + "maps", + nargs="*", + default=None, + help="input FITS file(s) for maps", + metavar="", + ) + + ########### + # spectra # + ########### + + parser = add_command(spectra) + group = parser.add_argument_group("output") + group.add_argument( + "path", + help="output FITS file for spectra", + metavar="", + ) + group = parser.add_argument_group("inputs") + group.add_argument( + "alms", + nargs="+", + help="input FITS file(s) for alms", + metavar="", + ) + group.add_argument( + "-X", + nargs="+", + help="input FITS file(s) for cross-spectra", + metavar="", + dest="alms2", + ) + + ########### + # mixmats # + ########### + + parser = add_command(mixmats) + group = parser.add_argument_group("output") + group.add_argument( + "path", + help="output FITS file for mixing matrices", + metavar="", + ) + group = parser.add_argument_group("inputs") + group.add_argument( + "alms", + nargs="+", + help="input FITS file(s) for alms", + metavar="", + ) + group.add_argument( + "-X", + nargs="+", + help="input FITS file(s) for cross-spectra", + metavar="", + dest="alms2", + ) + + ####### + # run # + ####### + + args = main_parser.parse_args() + + # show full help if no command is given + if args.cmd is None: + main_parser.print_help() + return 1 + + # get keyword args + kwargs = vars(args) + cmd = kwargs.pop("cmd") + + # set up logger for CLI output + logger = logging.getLogger(__name__) + logger.addHandler(logging.StreamHandler()) + logger.setLevel(logging.DEBUG) + + try: + cmd(**kwargs, logger=logger) + except Exception as exc: # noqa: BLE001 + logger.debug("Exception", exc_info=exc) + logger.error(f"ERROR: {exc!s}") + return 1 + else: + return 0 diff --git a/heracles/io.py b/heracles/io.py index 2977e97..669926b 100644 --- a/heracles/io.py +++ b/heracles/io.py @@ -56,7 +56,8 @@ def _write_metadata(hdu, metadata): """write array metadata to FITS HDU""" md = metadata or {} for key, value in md.items(): - hdu.write_key("META " + key.upper(), value, _METADATA_COMMENTS.get(key)) + comment = _METADATA_COMMENTS.get(key, "") + hdu.write_key("META " + key.upper(), value, comment) def _read_metadata(hdu): @@ -531,8 +532,8 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud # write a new TOC extension if FITS doesn't already contain one if "MMTOC" not in fits: fits.create_table_hdu( - names=["EXT", "NAME", "BIN1", "BIN2"], - formats=["10A", "10A", "I", "I"], + names=["EXT", "NAME1", "NAME2", "BIN1", "BIN2"], + formats=["10A", "10A", "10A", "I", "I"], extname="MMTOC", ) @@ -545,12 +546,12 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud mmn += 1 # write every mixing matrix - for (n, i1, i2), mm in mms.items(): + for (k1, k2, i1, i2), mm in mms.items(): # skip if not selected - if not toc_match((n, i1, i2), include=include, exclude=exclude): + if not toc_match((k1, k2, i1, i2), include=include, exclude=exclude): continue - logger.info("writing mixing matrix %s for bins %s, %s", n, i1, i2) + logger.info("writing %s x %s mm for bins %s, %s", k1, k2, i1, i2) # the mm extension name ext = f"MM{mmn}" @@ -560,7 +561,7 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud _write_twopoint(fits, ext, mm, "MM") # write the TOC entry - tocentry[0] = (ext, n, i1, i2) + tocentry[0] = (ext, k1, k2, i1, i2) fits["MMTOC"].append(tocentry) logger.info("done with %d mm(s)", len(mms)) @@ -584,16 +585,16 @@ def read_mms(filename, workdir=".", *, include=None, exclude=None): # read every entry in the TOC, add it to the list, then read the mms for entry in fits_toc: - ext, n, i1, i2 = entry[["EXT", "NAME", "BIN1", "BIN2"]] + ext, k1, k2, i1, i2 = entry[["EXT", "NAME1", "NAME2", "BIN1", "BIN2"]] # skip if not selected - if not toc_match((n, i1, i2), include=include, exclude=exclude): + if not toc_match((k1, k2, i1, i2), include=include, exclude=exclude): continue - logger.info("reading mixing matrix %s for bins %s, %s", n, i1, i2) + logger.info("writing %s x %s mm for bins %s, %s", k1, k2, i1, i2) # read the mixing matrix from the extension and store in set of mms - mms[n, i1, i2] = _read_twopoint(fits, ext) + mms[k1, k2, i1, i2] = _read_twopoint(fits, ext) logger.info("done with %d mm(s)", len(mms)) @@ -889,6 +890,6 @@ class MmsFits(TocFits): """FITS-backed mapping for mixing matrices.""" tag = "MM" - columns = {"NAME": "10A", "BIN1": "I", "BIN2": "I"} + columns = {"NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"} reader = staticmethod(_read_twopoint) writer = partial(_write_twopoint, name=tag) diff --git a/heracles/maps/_mapping.py b/heracles/maps/_mapping.py index 064a09b..f436237 100644 --- a/heracles/maps/_mapping.py +++ b/heracles/maps/_mapping.py @@ -126,7 +126,12 @@ def map_catalogs( coros.append(coro) # run all coroutines concurrently - results = coroutines.run(coroutines.gather(*coros)) + try: + results = coroutines.run(coroutines.gather(*coros)) + finally: + # force-close coroutines to prevent "never awaited" warnings + for coro in coros: + coro.close() # store results for key, value in zip(keys, results): diff --git a/pyproject.toml b/pyproject.toml index f29b6f5..3d7fca9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,14 @@ optional-dependencies = {all = [ ]} readme = "README.md" requires-python = ">=3.9" -urls.Documentation = "https://heracles.readthedocs.io/" -urls.Homepage = "https://github.com/heracles-ec/heracles" -urls.Issues = "https://github.com/heracles-ec/heracles/issues" + +[project.scripts] +heracles = "heracles.cli:main" + +[project.urls] +Documentation = "https://heracles.readthedocs.io/" +Homepage = "https://github.com/heracles-ec/heracles" +Issues = "https://github.com/heracles-ec/heracles/issues" [tool.black] force-exclude = """ diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..f3e5f87 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,412 @@ +from unittest.mock import patch + +import numpy as np +import numpy.testing as npt +import pytest + + +def test_getlist(): + from heracles.cli import getlist + + assert ( + getlist( + """ + x + y + z + """, + ) + == ["x", "y", "z"] + ) + assert getlist("xyz") == ["xyz"] + + +def test_getdict(): + from heracles.cli import getdict + + assert ( + getdict( + """ + x=1 + y = 2 + z= 3 + """, + ) + == {"x": "1", "y": "2", "z": "3"} + ) + + with pytest.raises(ValueError, match="Invalid value"): + getdict( + """ + 1 + 2 + 3 + """, + ) + + +def test_getchoice(): + from heracles.cli import getchoice + + choices = { + "first": 1, + "second": 2, + } + + assert getchoice("first", choices) == 1 + with pytest.raises(ValueError, match="Invalid value"): + getchoice("third", choices) + + +@patch.dict("os.environ", {"HOME": "/home/user", "TEST": "folder"}) +def test_getpath(): + from heracles.cli import getpath + + assert getpath("~/${TEST}/file.txt") == "/home/user/folder/file.txt" + + +def test_getfilter(): + from heracles.cli import getfilter + + assert getfilter("a") == [("a",)] + assert getfilter("a, ..., 1, 2") == [("a", ..., 1, 2)] + assert ( + getfilter( + """ + a, 1 + b, 2 + """, + ) + == [("a", 1), ("b", 2)] + ) + + +def test_subsections(): + from heracles.cli import ConfigParser + + config = ConfigParser() + config.read_dict( + { + "getme:a": {}, + "getme: b ": {}, + "getmenot:c": {}, + }, + ) + + assert config.subsections("getme") == {"a": "getme:a", "b": "getme: b "} + + +def test_field_from_config(): + from unittest.mock import Mock + + from heracles.cli import ConfigParser, field_from_config + + mock, other_mock = Mock(), Mock() + + mock_field_types = { + "test": mock, + "other_test": other_mock, + "error": "", + } + + config = ConfigParser() + config.read_dict( + { + "a": { + "type": "test", + "columns": """ + COL1 + -COL2 + """, + "mask": "x", + }, + "b": { + "type": "other_test", + }, + "c": { + "type": "error", + }, + }, + ) + + with patch.dict("heracles.cli.FIELD_TYPES", mock_field_types): + a = field_from_config(config, "a") + b = field_from_config(config, "b") + with pytest.raises(RuntimeError, match="Internal error"): + field_from_config(config, "c") + + mock.assert_called_once_with("COL1", "-COL2", mask="x") + assert mock.return_value is a + other_mock.assert_called_once_with(mask=None) + assert other_mock.return_value is b + + +@patch("heracles.cli.field_from_config") +def test_fields_from_config(mock): + from heracles.cli import ConfigParser, fields_from_config + + config = ConfigParser() + config.read_dict( + { + "fields:a": {}, + "fields:b": {}, + "fields:c": {}, + }, + ) + + m = fields_from_config(config) + + assert m == { + "a": mock.return_value, + "b": mock.return_value, + "c": mock.return_value, + } + assert mock.call_args_list == [ + ((config, "fields:a"),), + ((config, "fields:b"),), + ((config, "fields:c"),), + ] + + +@patch("heracles.cli.HEALPIX_DATAPATH") +@patch("heracles.maps.Healpix") +def test_mapper_from_config(mock, mock_datapath): + from heracles.cli import ConfigParser, mapper_from_config + + config = ConfigParser() + config.read_dict( + { + "a": { + "nside": "1", + }, + }, + ) + + a = mapper_from_config(config, "a") + mock.assert_called_once_with(1, datapath=mock_datapath) + assert mock.return_value is a + + +@patch("heracles.cli.mapper_from_config") +def test_mappers_from_config(mock): + from heracles.cli import ConfigParser, mappers_from_config + + config = ConfigParser() + config.read_dict( + { + "fields:a": {}, + "fields:b": {}, + "fields:c": {}, + }, + ) + + m = mappers_from_config(config) + + assert m == { + "a": mock.return_value, + "b": mock.return_value, + "c": mock.return_value, + } + assert mock.call_args_list == [ + ((config, "fields:a"),), + ((config, "fields:b"),), + ((config, "fields:c"),), + ] + + +@patch("heracles.io.read_vmap") +def test_catalog_from_config(mock): + from heracles.cli import ConfigParser, catalog_from_config + + # single visibility + + config = ConfigParser() + config.read_dict( + { + "test_with_single_visibility": { + "source": "catalog.fits", + "selections": """ + 0 = TOM_BIN_ID==0 + 1 = TOM_BIN_ID==1 + 2 = TOM_BIN_ID==2 + """, + "visibility": "vmap.fits", + }, + "test_with_many_visibilities": { + "source": "catalog.fits", + "selections": """ + 0 = TOM_BIN_ID==0 + 1 = TOM_BIN_ID==1 + 2 = TOM_BIN_ID==2 + """, + "visibility": """ + 0 = vmap.0.fits + 2 = vmap.2.fits + """, + }, + }, + ) + + catalog = catalog_from_config(config, "test_with_single_visibility", "label 1") + + assert catalog.keys() == {0, 1, 2} + assert catalog[0].base.__class__.__name__ == "FitsCatalog" + assert catalog[0].base.path == "catalog.fits" + assert catalog[0].base.visibility is mock.return_value + assert catalog[0].base.label == "label 1" + assert catalog[1].base is catalog[0].base + assert catalog[2].base is catalog[0].base + assert catalog[0].label is catalog[0].base.label + assert catalog[1].label is catalog[0].base.label + assert catalog[2].label is catalog[0].base.label + assert catalog[0].selection == "TOM_BIN_ID==0" + assert catalog[1].selection == "TOM_BIN_ID==1" + assert catalog[2].selection == "TOM_BIN_ID==2" + assert catalog[0].visibility is catalog[0].base.visibility + assert catalog[1].visibility is catalog[0].base.visibility + assert catalog[2].visibility is catalog[0].base.visibility + assert mock.call_args_list == [(("vmap.fits",),)] + + mock.reset_mock() + + catalog = catalog_from_config(config, "test_with_many_visibilities", "label 2") + + assert catalog.keys() == {0, 1, 2} + assert catalog[0].base.__class__.__name__ == "FitsCatalog" + assert catalog[0].base.path == "catalog.fits" + assert catalog[0].base.visibility is None + assert catalog[0].base.label == "label 2" + assert catalog[1].base is catalog[0].base + assert catalog[2].base is catalog[0].base + assert catalog[0].label is catalog[0].base.label + assert catalog[1].label is catalog[0].base.label + assert catalog[2].label is catalog[0].base.label + assert catalog[0].selection == "TOM_BIN_ID==0" + assert catalog[1].selection == "TOM_BIN_ID==1" + assert catalog[2].selection == "TOM_BIN_ID==2" + assert catalog[0].visibility is mock.return_value + assert catalog[1].visibility is None + assert catalog[2].visibility is mock.return_value + assert mock.call_args_list == [ + (("vmap.0.fits",),), + (("vmap.2.fits",),), + ] + + with pytest.raises(ValueError, match="Duplicate selection"): + catalog_from_config(config, "test_with_single_visibility", out=catalog) + + +@patch("heracles.cli.catalog_from_config") +def test_catalogs_from_config(mock): + from heracles.cli import ConfigParser, catalogs_from_config + + config = ConfigParser() + config.read_dict( + { + "catalogs:a": {}, + "catalogs:b": {}, + "catalogs:c": {}, + }, + ) + + c = catalogs_from_config(config) + + assert mock.call_args_list == [ + ((config, "catalogs:a", "a"), {"out": c}), + ((config, "catalogs:b", "b"), {"out": c}), + ((config, "catalogs:c", "c"), {"out": c}), + ] + + +def test_lmax_from_config(): + from heracles.cli import ConfigParser, lmax_from_config + + config = ConfigParser() + config.read_dict( + { + "defaults": {"lmax": 30}, + "fields:a": {"lmax": 10}, + "fields:b": {"lmax": 20}, + "fields:c": {}, # should use defaults + }, + ) + + assert lmax_from_config(config) == {"a": 10, "b": 20, "c": 30} + + +def test_bins_from_config(): + from heracles.cli import ConfigParser, bins_from_config + + config = ConfigParser() + config.read_dict( + { + "linear_bins": { + "lmin": "0", + "lmax": "4", + "bins": "5 linear", + }, + "log_bins": { + "lmin": "1", + "lmax": "999_999", + "bins": "6 log", + }, + "sqrt_bins": { + "lmin": "1", + "lmax": "35", + "bins": "5 sqrt", + }, + "log1p_bins": { + "lmin": "0", + "lmax": "9", + "bins": "5 log1p", + }, + }, + ) + + npt.assert_array_equal( + bins_from_config(config, "linear_bins")[0], + [0, 1, 2, 3, 4, 5], + ) + + npt.assert_array_equal( + bins_from_config(config, "log_bins")[0], + [1, 10, 100, 1000, 10000, 100000, 1000000], + ) + + npt.assert_array_equal( + bins_from_config(config, "sqrt_bins")[0], + [1, 4, 9, 16, 25, 36], + ) + + npt.assert_allclose( + bins_from_config(config, "log1p_bins")[0], + np.expm1(np.linspace(np.log1p(0), np.log1p(10), 6)), + ) + + +@patch("heracles.cli.bins_from_config") +def test_spectrum_from_config(mock): + from heracles.cli import ConfigParser, spectrum_from_config + + config = ConfigParser() + config.read_dict( + { + "a": { + "lmax": 10, + "l2max": 12, + "l3max": 20, + "include": "x", + "exclude": "y", + "bins": "...", + }, + }, + ) + + assert spectrum_from_config(config, "a") == { + "lmax": 10, + "l2max": 12, + "l3max": 20, + "include": [("x",)], + "exclude": [("y",)], + "bins": mock.return_value, + } diff --git a/tests/test_io.py b/tests/test_io.py index 968d32b..f99694e 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -233,9 +233,9 @@ def test_write_read_mms(rng, tmp_path): workdir = str(tmp_path) mms = { - ("00", 0, 1): rng.standard_normal((10, 10)), - ("0+", 1, 2): rng.standard_normal((20, 5)), - ("++", 2, 3): rng.standard_normal((10, 5, 2)), + ("P", "P", 0, 1): rng.standard_normal((10, 10)), + ("P", "G_E", 1, 2): rng.standard_normal((20, 5)), + ("G_E", "G_E", 2, 3): rng.standard_normal((10, 5, 2)), } write_mms(filename, mms, workdir=workdir)