Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT: create a "core" module for shared functionality #32

Merged
merged 2 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .commitlint.rules.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@ module.exports = {
"scope-enum": [
2,
"always",
["catalog", "covariance", "io", "maps", "plot", "twopoint", "util"],
[
"catalog",
"core",
"covariance",
"io",
"maps",
"plot",
"twopoint",
"util",
],
],
"scope-case": [0, "always", "lower-case"],
"type-enum": [
Expand Down
2 changes: 1 addition & 1 deletion examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@
"metadata": {},
"outputs": [],
"source": [
"from heracles.util import TocDict\n",
"from heracles.core import TocDict\n",
"\n",
"theory_cls = TocDict()\n",
"for i, s1 in enumerate(sources):\n",
Expand Down
88 changes: 88 additions & 0 deletions heracles/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Heracles: Euclid code for harmonic-space statistics on the sphere
#
# Copyright (C) 2023 Euclid Science Ground Segment
#
# This file is part of Heracles.
#
# Heracles is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Heracles is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with Heracles. If not, see <https://www.gnu.org/licenses/>.
"""module for common core functionality"""

from collections import UserDict
from collections.abc import Mapping, Sequence


def toc_match(key, include=None, exclude=None):
    """Return whether a toc dict key passes the include/exclude criteria.

    *key* is an iterable of entries, normally a tuple.  *include* and
    *exclude* are each either None or an iterable of patterns, where a
    pattern is itself an iterable of entries.

    A pattern is compared entry by entry against the leading entries of
    *key*, and an Ellipsis (``...``) in the pattern matches any entry.
    A pattern with more entries than *key* never matches.

    If *include* is not None, *key* must match at least one of its
    patterns, so an empty *include* rejects every key.  If *exclude* is
    not None, *key* must not match any of its patterns.  With both left
    as None, every key is accepted.
    """

    def matches(pattern):
        """Return whether *pattern* matches the leading entries of key."""
        entries = tuple(key)
        pattern = tuple(pattern)
        # zip() stops at the shorter input, so surplus pattern entries
        # would be silently ignored; a key that is too short cannot
        # possibly match the pattern
        if len(pattern) > len(entries):
            return False
        return all(p is Ellipsis or p == k for p, k in zip(pattern, entries))

    if include is not None and not any(matches(p) for p in include):
        return False
    if exclude is not None and any(matches(p) for p in exclude):
        return False
    return True


def toc_filter(obj, include=None, exclude=None):
    """Return a filtered copy of the toc dict(s) in *obj*.

    If *obj* is a mapping, the result is a new plain ``dict`` holding
    only those items whose key passes ``toc_match(key, include,
    exclude)``.  If *obj* is a sequence, the result is a list with each
    item filtered in the same way, recursively.

    Raises TypeError for any other input, including text and byte
    strings.
    """
    # str is a Sequence whose items are again str, so recursing into it
    # would never terminate; treat strings as invalid input instead
    is_string = isinstance(obj, (str, bytes, bytearray))
    if isinstance(obj, Sequence) and not is_string:
        return [toc_filter(item, include, exclude) for item in obj]
    if isinstance(obj, Mapping):
        return {k: v for k, v in obj.items() if toc_match(k, include, exclude)}
    msg = "invalid input type"
    raise TypeError(msg)


# subclassing UserDict here since that returns the correct type from methods
# such as __copy__(), __or__(), etc.
class TocDict(UserDict):
"""Table-of-contents dictionary with pattern-based lookup"""

def __getitem__(self, pattern):
"""look up one or many keys in dict"""
# first, see if pattern is a valid entry in the dict
# might fail with KeyError (no such entry) or TypeError (not hashable)
try:
return self.data[pattern]
except (KeyError, TypeError):
pass
# pattern might be a single object such as e.g. "X"
if not isinstance(pattern, tuple):
pattern = (pattern,)
# no pattern == matches everything
if not pattern:
return self.copy()
# go through all keys in the dict and match them against the pattern
# return an object of the same type
found = self.__class__()
for key, value in self.data.items():
if isinstance(key, tuple):
# key too short, cannot possibly match pattern
if len(key) < len(pattern):
continue
# match every part of pattern against the given key
# Ellipsis (...) is a wildcard and comparison is skipped
if all(p == k for p, k in zip(pattern, key) if p is not ...):
found[key] = value
else:
# key is a single entry, pattern must match it
if pattern == (...,) or pattern == (key,):
found[key] = value
# nothing matched the pattern, treat as usual dict lookup error
if not found:
raise KeyError(pattern)
return found
2 changes: 1 addition & 1 deletion heracles/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import healpy as hp
import numpy as np

from .util import TocDict, toc_match
from .core import TocDict, toc_match

logger = logging.getLogger(__name__)

Expand Down
6 changes: 5 additions & 1 deletion heracles/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import numpy as np
from numba import njit

from .util import Progress, TocDict, toc_match
from .core import TocDict, toc_match

if t.TYPE_CHECKING:
from .catalog import Catalog, CatalogPage
Expand Down Expand Up @@ -706,6 +706,8 @@ def map_catalogs(

# display a progress bar if asked to
if progress:
from .util import Progress

prog = Progress()

# collect groups of catalogues to go through if parallel
Expand Down Expand Up @@ -815,6 +817,8 @@ def transform_maps(

# display a progress bar if asked to
if progress:
from .util import Progress

prog = Progress()
prog.start(len(maps))

Expand Down
2 changes: 1 addition & 1 deletion heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
from convolvecl import mixmat, mixmat_eb

from .core import TocDict, toc_match
from .maps import (
map_catalogs as _map_catalogs,
)
Expand All @@ -36,7 +37,6 @@
from .maps import (
update_metadata,
)
from .util import TocDict, toc_match

logger = logging.getLogger(__name__)

Expand Down
68 changes: 0 additions & 68 deletions heracles/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,77 +21,9 @@
import os
import sys
import time
from collections import UserDict
from collections.abc import Mapping, Sequence
from datetime import timedelta


def toc_match(key, include=None, exclude=None):
    """Return whether a toc dict key passes the include/exclude criteria.

    *key* is an iterable of entries, normally a tuple.  *include* and
    *exclude* are each either None or an iterable of patterns, where a
    pattern is itself an iterable of entries.

    A pattern is compared entry by entry against the leading entries of
    *key*, and an Ellipsis (``...``) in the pattern matches any entry.
    A pattern with more entries than *key* never matches.

    If *include* is not None, *key* must match at least one of its
    patterns, so an empty *include* rejects every key.  If *exclude* is
    not None, *key* must not match any of its patterns.  With both left
    as None, every key is accepted.
    """

    def matches(pattern):
        """Return whether *pattern* matches the leading entries of key."""
        entries = tuple(key)
        pattern = tuple(pattern)
        # zip() stops at the shorter input, so surplus pattern entries
        # would be silently ignored; a key that is too short cannot
        # possibly match the pattern
        if len(pattern) > len(entries):
            return False
        return all(p is Ellipsis or p == k for p, k in zip(pattern, entries))

    if include is not None and not any(matches(p) for p in include):
        return False
    if exclude is not None and any(matches(p) for p in exclude):
        return False
    return True


def toc_filter(obj, include=None, exclude=None):
    """Return a filtered copy of the toc dict(s) in *obj*.

    If *obj* is a mapping, the result is a new plain ``dict`` holding
    only those items whose key passes ``toc_match(key, include,
    exclude)``.  If *obj* is a sequence, the result is a list with each
    item filtered in the same way, recursively.

    Raises TypeError for any other input, including text and byte
    strings.
    """
    # str is a Sequence whose items are again str, so recursing into it
    # would never terminate; treat strings as invalid input instead
    is_string = isinstance(obj, (str, bytes, bytearray))
    if isinstance(obj, Sequence) and not is_string:
        return [toc_filter(item, include, exclude) for item in obj]
    if isinstance(obj, Mapping):
        return {k: v for k, v in obj.items() if toc_match(k, include, exclude)}
    msg = "invalid input type"
    raise TypeError(msg)


# subclassing UserDict here since that returns the correct type from methods
# such as __copy__(), __or__(), etc.
class TocDict(UserDict):
    """Dictionary of table-of-contents entries that supports pattern lookup.

    Indexing with an existing key returns the stored value, like a normal
    dict.  Any other index is treated as a pattern, and a new dictionary
    of the same class holding all matching items is returned.
    """

    def __getitem__(self, pattern):
        """Return the value for an exact key, or all items matching a pattern.

        A pattern is a tuple of entries (a non-tuple is treated as a
        one-entry pattern) that is compared against the leading entries of
        every tuple key; ``...`` matches any entry.  Non-tuple keys only
        match the patterns ``(...,)`` and ``(key,)``.  The empty pattern
        returns a copy of the whole dictionary.

        Raises KeyError if neither an exact key nor any match is found.
        """
        # an exact entry always wins; a missing entry raises KeyError and
        # an unhashable pattern raises TypeError, both of which simply
        # mean we carry on with pattern matching
        try:
            return self.data[pattern]
        except (KeyError, TypeError):
            pass

        # normalise a lone object such as "X" to a one-entry pattern
        parts = pattern if isinstance(pattern, tuple) else (pattern,)

        # the empty pattern selects everything
        if not parts:
            return self.copy()

        # collect the matches in a fresh instance of our own class
        matches = self.__class__()
        for key, value in self.data.items():
            if not isinstance(key, tuple):
                # single-entry key: only the wildcard or the key itself match
                if parts == (...,) or parts == (key,):
                    matches[key] = value
            elif len(key) >= len(parts):
                # keys shorter than the pattern were ruled out above;
                # compare entry by entry, letting ... stand for anything
                if all(p is ... or p == k for p, k in zip(parts, key)):
                    matches[key] = value

        # an empty result is reported like a failed ordinary lookup
        if not matches:
            raise KeyError(parts)
        return matches


class Progress:
"""simple progress bar for operations"""

Expand Down
80 changes: 80 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest


def test_toc_match():
    """Check toc_match() against include and exclude pattern lists."""
    from heracles.core import toc_match

    key = ("aa", 1, 2)

    # argument tuples for which the key must be accepted
    accepted = [
        (("a",),),
        (("a",), None, []),
        (key, [("aa", 1, 2)], None),
        (key, [("aa", 1, 2)], []),
        (key, [("aa", 1, 2)], [("ab", 1, 2)]),
        (key, [("aa",)], None),
        (key, [(..., 1)], None),
        (key, [(..., ..., 2)], None),
    ]

    # argument tuples for which the key must be rejected; an empty
    # include list matches nothing, and exclude wins over include
    rejected = [
        (("a",), [], None),
        (key, None, [("aa", 1, 2)]),
        (key, [], [("aa", 1, 2)]),
        (key, [("aa", 1, 2)], [("aa", 1, 2)]),
        (key, None, [("aa",)]),
        (key, None, [(..., 1)]),
        (key, None, [(..., ..., 2)]),
    ]

    for args in accepted:
        assert toc_match(*args), args
    for args in rejected:
        assert not toc_match(*args), args


def test_toc_filter():
    """Check toc_filter() on a mapping, a list of mappings, and bad input."""
    from heracles.core import toc_filter

    ab_only = {("a", "b"): 1}
    full = {("a", "b"): 1, ("c", "d"): 2}

    # single mapping, filtered by leading entry or by wildcard pattern
    assert toc_filter(full, [("a",)]) == ab_only
    assert toc_filter(full, [(..., "b")]) == ab_only

    # several include patterns are combined: matching any one is enough
    assert toc_filter(full, [("a",), (..., "d")]) == full

    # a list of mappings is filtered item by item
    assert toc_filter([full, full], [("a",)]) == [ab_only, ab_only]

    # anything that is neither a sequence nor a mapping is rejected
    with pytest.raises(TypeError):
        toc_filter(object())


def test_tocdict():
    """Check exact lookup, pattern lookup, and type preservation of TocDict."""
    from copy import copy, deepcopy

    from heracles.core import TocDict

    ab1 = {("a", "b", 1): "ab1"}
    ac1 = {("a", "c", 1): "ac1"}
    bc2 = {("b", "c", 2): "bc2"}

    d = TocDict({**ab1, **ac1, **bc2})

    # exact keys return the stored value
    assert d["a", "b", 1] == "ab1"
    assert d["a", "c", 1] == "ac1"
    assert d["b", "c", 2] == "bc2"
    with pytest.raises(KeyError):
        d["b", "c", 1]

    # patterns return every matching entry
    assert d["a"] == {**ab1, **ac1}
    assert d["a", ..., 1] == {**ab1, **ac1}
    assert d[..., ..., 1] == {**ab1, **ac1}
    assert d[..., "c", 1] == ac1
    assert d[..., "c"] == {**ac1, **bc2}
    assert d[..., ..., 2] == bc2
    with pytest.raises(KeyError):
        d["c"]

    # non-tuple keys: plain lookup, wildcard and empty pattern
    d = TocDict(a=1, b=2)
    assert d["a"] == 1
    assert d["b"] == 2
    assert d[...] == d
    assert d[()] == d

    # all ways of copying keep the TocDict type
    for duplicate in (d.copy(), copy(d), deepcopy(d)):
        assert type(duplicate) is type(d)

    # so does merging with the | operator
    d = TocDict(a=1) | TocDict(b=2)
    assert type(d) is TocDict
    assert d == {"a": 1, "b": 2}
Loading
Loading