From 9cd5b0f1795aaee8219f659cd06ab0909d40db08 Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Tue, 12 Sep 2023 11:45:04 +0100 Subject: [PATCH 1/2] create a "core" module for shared functionality --- .commitlint.rules.js | 11 ++++- examples/example.ipynb | 2 +- heracles/core.py | 88 +++++++++++++++++++++++++++++++++ heracles/io.py | 2 +- heracles/maps.py | 6 ++- heracles/twopoint.py | 2 +- heracles/util.py | 68 ------------------------- tests/test_core.py | 110 +++++++++++++++++++++++++++++++++++++++++ tests/test_util.py | 82 ------------------------------ 9 files changed, 216 insertions(+), 155 deletions(-) create mode 100644 heracles/core.py create mode 100644 tests/test_core.py diff --git a/.commitlint.rules.js b/.commitlint.rules.js index a6d1425..9c3f21b 100644 --- a/.commitlint.rules.js +++ b/.commitlint.rules.js @@ -5,7 +5,16 @@ module.exports = { "scope-enum": [ 2, "always", - ["catalog", "covariance", "io", "maps", "plot", "twopoint", "util"], + [ + "catalog", + "core", + "covariance", + "io", + "maps", + "plot", + "twopoint", + "util", + ], ], "scope-case": [0, "always", "lower-case"], "type-enum": [ diff --git a/examples/example.ipynb b/examples/example.ipynb index 5d3904f..506919e 100644 --- a/examples/example.ipynb +++ b/examples/example.ipynb @@ -1448,7 +1448,7 @@ "metadata": {}, "outputs": [], "source": [ - "from heracles.util import TocDict\n", + "from heracles.core import TocDict\n", "\n", "theory_cls = TocDict()\n", "for i, s1 in enumerate(sources):\n", diff --git a/heracles/core.py b/heracles/core.py new file mode 100644 index 0000000..bebf96c --- /dev/null +++ b/heracles/core.py @@ -0,0 +1,88 @@ +# Heracles: Euclid code for harmonic-space statistics on the sphere +# +# Copyright (C) 2023 Euclid Science Ground Segment +# +# This file is part of Heracles. 
+# +# Heracles is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Heracles is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with Heracles. If not, see <https://www.gnu.org/licenses/>. +"""module for common core functionality""" + +from collections import UserDict +from collections.abc import Mapping, Sequence + + +def toc_match(key, include=None, exclude=None): + """return whether a tocdict entry matches include/exclude criteria""" + if include is not None: + for pattern in include: + if all(p is Ellipsis or p == k for p, k in zip(pattern, key)): + break + else: + return False + if exclude is not None: + for pattern in exclude: + if all(p is Ellipsis or p == k for p, k in zip(pattern, key)): + return False + return True + + +def toc_filter(obj, include=None, exclude=None): + """return a filtered toc dict ``d``""" + if isinstance(obj, Sequence): + return [toc_filter(item, include, exclude) for item in obj] + if isinstance(obj, Mapping): + return {k: v for k, v in obj.items() if toc_match(k, include, exclude)} + msg = "invalid input type" + raise TypeError(msg) + + +# subclassing UserDict here since that returns the correct type from methods +# such as __copy__(), __or__(), etc. 
+class TocDict(UserDict): + """Table-of-contents dictionary with pattern-based lookup""" + + def __getitem__(self, pattern): + """look up one or many keys in dict""" + # first, see if pattern is a valid entry in the dict + # might fail with KeyError (no such entry) or TypeError (not hashable) + try: + return self.data[pattern] + except (KeyError, TypeError): + pass + # pattern might be a single object such as e.g. "X" + if not isinstance(pattern, tuple): + pattern = (pattern,) + # no pattern == matches everything + if not pattern: + return self.copy() + # go through all keys in the dict and match them against the pattern + # return an object of the same type + found = self.__class__() + for key, value in self.data.items(): + if isinstance(key, tuple): + # key too short, cannot possibly match pattern + if len(key) < len(pattern): + continue + # match every part of pattern against the given key + # Ellipsis (...) is a wildcard and comparison is skipped + if all(p == k for p, k in zip(pattern, key) if p is not ...): + found[key] = value + else: + # key is a single entry, pattern must match it + if pattern == (...,) or pattern == (key,): + found[key] = value + # nothing matched the pattern, treat as usual dict lookup error + if not found: + raise KeyError(pattern) + return found diff --git a/heracles/io.py b/heracles/io.py index 8857628..2f88a40 100644 --- a/heracles/io.py +++ b/heracles/io.py @@ -25,7 +25,7 @@ import healpy as hp import numpy as np -from .util import TocDict, toc_match +from .core import TocDict, toc_match logger = logging.getLogger(__name__) diff --git a/heracles/maps.py b/heracles/maps.py index dd6d042..f7e6732 100644 --- a/heracles/maps.py +++ b/heracles/maps.py @@ -29,7 +29,7 @@ import numpy as np from numba import njit -from .util import Progress, TocDict, toc_match +from .core import TocDict, toc_match if t.TYPE_CHECKING: from .catalog import Catalog, CatalogPage @@ -706,6 +706,8 @@ def map_catalogs( # display a progress bar if asked to if 
progress: + from .util import Progress + prog = Progress() # collect groups of catalogues to go through if parallel @@ -815,6 +817,8 @@ def transform_maps( # display a progress bar if asked to if progress: + from .util import Progress + prog = Progress() prog.start(len(maps)) diff --git a/heracles/twopoint.py b/heracles/twopoint.py index 264e6e8..1f49d15 100644 --- a/heracles/twopoint.py +++ b/heracles/twopoint.py @@ -27,6 +27,7 @@ import numpy as np from convolvecl import mixmat, mixmat_eb +from .core import TocDict, toc_match from .maps import ( map_catalogs as _map_catalogs, ) @@ -36,7 +37,6 @@ from .maps import ( update_metadata, ) -from .util import TocDict, toc_match logger = logging.getLogger(__name__) diff --git a/heracles/util.py b/heracles/util.py index 5e8b145..f485d4f 100644 --- a/heracles/util.py +++ b/heracles/util.py @@ -21,77 +21,9 @@ import os import sys import time -from collections import UserDict -from collections.abc import Mapping, Sequence from datetime import timedelta -def toc_match(key, include=None, exclude=None): - """return whether a tocdict entry matches include/exclude criteria""" - if include is not None: - for pattern in include: - if all(p is Ellipsis or p == k for p, k in zip(pattern, key)): - break - else: - return False - if exclude is not None: - for pattern in exclude: - if all(p is Ellipsis or p == k for p, k in zip(pattern, key)): - return False - return True - - -def toc_filter(obj, include=None, exclude=None): - """return a filtered toc dict ``d``""" - if isinstance(obj, Sequence): - return [toc_filter(item, include, exclude) for item in obj] - if isinstance(obj, Mapping): - return {k: v for k, v in obj.items() if toc_match(k, include, exclude)} - msg = "invalid input type" - raise TypeError(msg) - - -# subclassing UserDict here since that returns the correct type from methods -# such as __copy__(), __or__(), etc. 
-class TocDict(UserDict): - """Table-of-contents dictionary with pattern-based lookup""" - - def __getitem__(self, pattern): - """look up one or many keys in dict""" - # first, see if pattern is a valid entry in the dict - # might fail with KeyError (no such entry) or TypeError (not hashable) - try: - return self.data[pattern] - except (KeyError, TypeError): - pass - # pattern might be a single object such as e.g. "X" - if not isinstance(pattern, tuple): - pattern = (pattern,) - # no pattern == matches everything - if not pattern: - return self.copy() - # go through all keys in the dict and match them against the pattern - # return an object of the same type - found = self.__class__() - for key, value in self.data.items(): - if isinstance(key, tuple): - # key too short, cannot possibly match pattern - if len(key) < len(pattern): - continue - # match every part of pattern against the given key - # Ellipsis (...) is a wildcard and comparison is skipped - if all(p == k for p, k in zip(pattern, key) if p is not ...): - found[key] = value - else: - # key is a single entry, pattern must match it - if pattern == (...,) or pattern == (key,): - found[key] = value - # nothing matched the pattern, treat as usual dict lookup error - if not found: - raise KeyError(pattern) - return found - - class Progress: """simple progress bar for operations""" diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..ad35454 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,110 @@ +import pytest + + +def test_toc_match(): + from heracles.core import toc_match + + assert toc_match(("a",)) + assert toc_match(("a",), None, []) + assert not toc_match(("a",), [], None) + + assert toc_match(("aa", 1, 2), [("aa", 1, 2)], None) + assert toc_match(("aa", 1, 2), [("aa", 1, 2)], []) + assert toc_match(("aa", 1, 2), [("aa", 1, 2)], [("ab", 1, 2)]) + assert toc_match(("aa", 1, 2), [("aa",)], None) + assert toc_match(("aa", 1, 2), [(..., 1)], None) + assert toc_match(("aa", 
1, 2), [(..., ..., 2)], None) + + assert not toc_match(("aa", 1, 2), None, [("aa", 1, 2)]) + assert not toc_match(("aa", 1, 2), [], [("aa", 1, 2)]) + assert not toc_match(("aa", 1, 2), [("aa", 1, 2)], [("aa", 1, 2)]) + assert not toc_match(("aa", 1, 2), None, [("aa",)]) + assert not toc_match(("aa", 1, 2), None, [(..., 1)]) + assert not toc_match(("aa", 1, 2), None, [(..., ..., 2)]) + + +def test_toc_filter(): + from heracles.core import toc_filter + + full = {("a", "b"): 1, ("c", "d"): 2} + + assert toc_filter(full, [("a",)]) == {("a", "b"): 1} + assert toc_filter(full, [(..., "b")]) == {("a", "b"): 1} + assert toc_filter(full, [("a",), (..., "d")]) == full + assert toc_filter([full] * 2, [("a",)]) == [{("a", "b"): 1}] * 2 + + with pytest.raises(TypeError): + toc_filter(object()) + + +def test_tocdict(): + from copy import copy, deepcopy + + from heracles.core import TocDict + + d = TocDict( + { + ("a", "b", 1): "ab1", + ("a", "c", 1): "ac1", + ("b", "c", 2): "bc2", + }, + ) + + assert d["a", "b", 1] == "ab1" + assert d["a", "c", 1] == "ac1" + assert d["b", "c", 2] == "bc2" + with pytest.raises(KeyError): + d["b", "c", 1] + + assert d["a"] == {("a", "b", 1): "ab1", ("a", "c", 1): "ac1"} + assert d["a", ..., 1] == {("a", "b", 1): "ab1", ("a", "c", 1): "ac1"} + assert d[..., ..., 1] == {("a", "b", 1): "ab1", ("a", "c", 1): "ac1"} + assert d[..., "c", 1] == {("a", "c", 1): "ac1"} + assert d[..., "c"] == {("a", "c", 1): "ac1", ("b", "c", 2): "bc2"} + assert d[..., ..., 2] == {("b", "c", 2): "bc2"} + with pytest.raises(KeyError): + d["c"] + + d = TocDict(a=1, b=2) + assert d["a"] == 1 + assert d["b"] == 2 + assert d[...] 
== d + assert d[()] == d + + assert type(d.copy()) == type(d) + assert type(copy(d)) == type(d) + assert type(deepcopy(d)) == type(d) + + d = TocDict(a=1) | TocDict(b=2) + assert type(d) is TocDict + assert d == {"a": 1, "b": 2} + + +def test_progress(): + from io import StringIO + + from heracles.util import Progress + + f = StringIO() + prog = Progress(f) + prog.start(10, "my title") + s = f.getvalue() + assert s.count("\r") == 1 + assert s.count("\n") == 0 + assert "my title" in s + assert "0/10" in s + prog.update() + s = f.getvalue() + assert s.count("\r") == 2 + assert s.count("\n") == 0 + assert "1/10" in s + prog.update(5) + s = f.getvalue() + assert s.count("\r") == 3 + assert s.count("\n") == 0 + assert "6/10" in s + prog.stop() + s = f.getvalue() + assert s.count("\r") == 4 + assert s.count("\n") == 1 + assert "10/10" in s diff --git a/tests/test_util.py b/tests/test_util.py index 8ca2247..ba33f01 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,85 +1,3 @@ -import pytest - - -def test_toc_match(): - from heracles.util import toc_match - - assert toc_match(("a",)) - assert toc_match(("a",), None, []) - assert not toc_match(("a",), [], None) - - assert toc_match(("aa", 1, 2), [("aa", 1, 2)], None) - assert toc_match(("aa", 1, 2), [("aa", 1, 2)], []) - assert toc_match(("aa", 1, 2), [("aa", 1, 2)], [("ab", 1, 2)]) - assert toc_match(("aa", 1, 2), [("aa",)], None) - assert toc_match(("aa", 1, 2), [(..., 1)], None) - assert toc_match(("aa", 1, 2), [(..., ..., 2)], None) - - assert not toc_match(("aa", 1, 2), None, [("aa", 1, 2)]) - assert not toc_match(("aa", 1, 2), [], [("aa", 1, 2)]) - assert not toc_match(("aa", 1, 2), [("aa", 1, 2)], [("aa", 1, 2)]) - assert not toc_match(("aa", 1, 2), None, [("aa",)]) - assert not toc_match(("aa", 1, 2), None, [(..., 1)]) - assert not toc_match(("aa", 1, 2), None, [(..., ..., 2)]) - - -def test_toc_filter(): - from heracles.util import toc_filter - - full = {("a", "b"): 1, ("c", "d"): 2} - - assert 
toc_filter(full, [("a",)]) == {("a", "b"): 1} - assert toc_filter(full, [(..., "b")]) == {("a", "b"): 1} - assert toc_filter(full, [("a",), (..., "d")]) == full - assert toc_filter([full] * 2, [("a",)]) == [{("a", "b"): 1}] * 2 - - with pytest.raises(TypeError): - toc_filter(object()) - - -def test_tocdict(): - from copy import copy, deepcopy - - from heracles.util import TocDict - - d = TocDict( - { - ("a", "b", 1): "ab1", - ("a", "c", 1): "ac1", - ("b", "c", 2): "bc2", - }, - ) - - assert d["a", "b", 1] == "ab1" - assert d["a", "c", 1] == "ac1" - assert d["b", "c", 2] == "bc2" - with pytest.raises(KeyError): - d["b", "c", 1] - - assert d["a"] == {("a", "b", 1): "ab1", ("a", "c", 1): "ac1"} - assert d["a", ..., 1] == {("a", "b", 1): "ab1", ("a", "c", 1): "ac1"} - assert d[..., ..., 1] == {("a", "b", 1): "ab1", ("a", "c", 1): "ac1"} - assert d[..., "c", 1] == {("a", "c", 1): "ac1"} - assert d[..., "c"] == {("a", "c", 1): "ac1", ("b", "c", 2): "bc2"} - assert d[..., ..., 2] == {("b", "c", 2): "bc2"} - with pytest.raises(KeyError): - d["c"] - - d = TocDict(a=1, b=2) - assert d["a"] == 1 - assert d["b"] == 2 - assert d[...] 
== d - assert d[()] == d - - assert type(d.copy()) == type(d) - assert type(copy(d)) == type(d) - assert type(deepcopy(d)) == type(d) - - d = TocDict(a=1) | TocDict(b=2) - assert type(d) is TocDict - assert d == {"a": 1, "b": 2} - - def test_progress(): from io import StringIO From ab0a5441275eec4c7cdaa2b62cc3947127bd7f10 Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Tue, 12 Sep 2023 11:48:59 +0100 Subject: [PATCH 2/2] remove duplicated test --- tests/test_core.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index ad35454..b8e692b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -78,33 +78,3 @@ def test_tocdict(): d = TocDict(a=1) | TocDict(b=2) assert type(d) is TocDict assert d == {"a": 1, "b": 2} - - -def test_progress(): - from io import StringIO - - from heracles.util import Progress - - f = StringIO() - prog = Progress(f) - prog.start(10, "my title") - s = f.getvalue() - assert s.count("\r") == 1 - assert s.count("\n") == 0 - assert "my title" in s - assert "0/10" in s - prog.update() - s = f.getvalue() - assert s.count("\r") == 2 - assert s.count("\n") == 0 - assert "1/10" in s - prog.update(5) - s = f.getvalue() - assert s.count("\r") == 3 - assert s.count("\n") == 0 - assert "6/10" in s - prog.stop() - s = f.getvalue() - assert s.count("\r") == 4 - assert s.count("\n") == 1 - assert "10/10" in s