diff --git a/src/pygama/math/functions/crystal_ball.py b/src/pygama/math/functions/crystal_ball.py index dbd0d9ccd..a291f3763 100644 --- a/src/pygama/math/functions/crystal_ball.py +++ b/src/pygama/math/functions/crystal_ball.py @@ -9,8 +9,8 @@ from math import erf from pygama.math.functions.pygama_continuous import pygama_continuous -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults diff --git a/src/pygama/math/functions/exgauss.py b/src/pygama/math/functions/exgauss.py index e15cd4021..7922cb441 100644 --- a/src/pygama/math/functions/exgauss.py +++ b/src/pygama/math/functions/exgauss.py @@ -13,8 +13,8 @@ from pygama.math.functions.error_function import nb_erf, nb_erfc from pygama.math.functions.gauss import nb_gauss_cdf, nb_gauss_pdf -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults limit = np.log(sys.float_info.max)/10 diff --git a/src/pygama/math/functions/exponential.py b/src/pygama/math/functions/exponential.py index c8d6e8b55..b900e2fa5 100644 --- a/src/pygama/math/functions/exponential.py +++ b/src/pygama/math/functions/exponential.py @@ -8,8 +8,8 @@ import numpy as np from pygama.math.functions.pygama_continuous import pygama_continuous -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults diff --git a/src/pygama/math/functions/gauss.py b/src/pygama/math/functions/gauss.py index 7df5f2775..ab52dbaf5 100644 --- a/src/pygama/math/functions/gauss.py +++ b/src/pygama/math/functions/gauss.py @@ -9,7 +9,7 @@ from pygama.math.functions.error_function import nb_erf from pygama.math.functions.pygama_continuous import pygama_continuous -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults as nb_defaults @nb.njit(**nb_defaults(parallel=False)) diff --git a/src/pygama/math/functions/linear.py b/src/pygama/math/functions/linear.py index cfb2dd222..39a9c5853 100644 --- a/src/pygama/math/functions/linear.py +++ b/src/pygama/math/functions/linear.py @@ -6,8 +6,8 @@ from numba import prange from pygama.math.functions.pygama_continuous import pygama_continuous -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults @nb.njit(**nb_kwargs) diff --git a/src/pygama/math/functions/moyal.py b/src/pygama/math/functions/moyal.py index 76424c6ba..f16d5e1ca 100644 --- a/src/pygama/math/functions/moyal.py +++ b/src/pygama/math/functions/moyal.py @@ -10,8 +10,8 @@ from math import erfc from pygama.math.functions.pygama_continuous import pygama_continuous -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults diff --git a/src/pygama/math/functions/poisson.py b/src/pygama/math/functions/poisson.py index 837b78276..524011496 100644 --- a/src/pygama/math/functions/poisson.py +++ b/src/pygama/math/functions/poisson.py @@ -8,8 +8,8 @@ from numba import prange from scipy.stats import rv_discrete -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults @nb.njit(**nb_defaults(parallel=False)) diff --git a/src/pygama/math/functions/polynomial.py b/src/pygama/math/functions/polynomial.py index f78349501..d888d53dd 100644 --- a/src/pygama/math/functions/polynomial.py +++ b/src/pygama/math/functions/polynomial.py @@ -1,6 +1,6 @@ import numba as nb import numpy as np -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults as nb_defaults @nb.njit(**nb_defaults(parallel=False)) diff --git a/src/pygama/math/functions/step.py b/src/pygama/math/functions/step.py index dc35bbaf4..c33c89b7b 100644 --- a/src/pygama/math/functions/step.py +++ b/src/pygama/math/functions/step.py @@ -11,11 +11,9 @@ from pygama.math.functions.gauss import nb_gauss from pygama.math.functions.pygama_continuous import pygama_continuous -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults -kwd = {"parallel": False, "fastmath": True} -kwd_parallel = {"parallel": True, "fastmath": True} @nb.njit(**nb_defaults(parallel=False)) def nb_step_int(x: float, mu: float, sigma: float, hstep: float) -> np.ndarray: diff --git a/src/pygama/math/functions/uniform.py b/src/pygama/math/functions/uniform.py index b51c1c3f2..72c9cc3af 100644 --- a/src/pygama/math/functions/uniform.py +++ b/src/pygama/math/functions/uniform.py @@ -10,8 +10,8 @@ from numba import prange from pygama.math.functions.pygama_continuous import pygama_continuous -from ..utils import numba_math_defaults_kwargs as nb_kwargs -from ..utils import numba_math_defaults as nb_defaults +from pygama.utils import numba_math_defaults_kwargs as nb_kwargs +from pygama.utils import numba_math_defaults as nb_defaults @nb.njit(**nb_kwargs) diff --git a/src/pygama/math/utils.py b/src/pygama/math/utils.py index a29d025f7..48bd6e365 100644 --- a/src/pygama/math/utils.py +++ b/src/pygama/math/utils.py @@ -3,9 +3,7 @@ """ import logging import sys -import os -from typing import Optional, Union, Callable, Any, Iterator -from collections.abc import MutableMapping +from typing import Optional, Union, Callable import numpy as np @@ -105,75 +103,3 @@ def print_fit_results(pars: np.ndarray, cov: np.ndarray, func: Optional[Callable log.info(f"{par_names[i]} = {mean} +/- {sigma}") if pad: log.info("") - - -def getenv_bool(name: str, default: bool = False) -> bool: - """Get environment value as a boolean, returning True for 1, t and true - (caps-insensitive), and False for any other value and default if undefined. - """ - val = os.getenv(name) - if not val: - return default - elif val.lower() in ("1", "t", "true"): - return True - else: - return False - -class NumbaMathDefaults(MutableMapping): - """Bare-bones class to store some Numba default options. Defaults values - are set from environment variables - - Examples - -------- - Set all default option values for a processor at once by expanding the - provided dictionary: - - >>> from numba import guvectorize - >>> from pygama.math.utils import numba_defaults_kwargs as nb_kwargs - >>> @guvectorize([], "", **nb_kwargs, nopython=True) # def proc(...): ... - - Customize one argument but still set defaults for the others: - - >>> from pygama.math.utils import numba_defaults as nb_defaults - >>> @guvectorize([], "", **nb_defaults(cache=False) # def proc(...): ... - - Override global options at runtime: - - >>> from pygama.math.utils import numba_defaults - >>> # must set options before explicitly importing pygama.math.distributions! - >>> numba_defaults.cache = False - """ - - def __init__(self) -> None: - self.parallel: bool = getenv_bool("MATH_PARALLEL", default=True) - self.fastmath: bool = getenv_bool("MATH_FAST", default=True) - - def __getitem__(self, item: str) -> Any: - return self.__dict__[item] - - def __setitem__(self, item: str, val: Any) -> None: - self.__dict__[item] = val - - def __delitem__(self, item: str) -> None: - del self.__dict__[item] - - def __iter__(self) -> Iterator: - return self.__dict__.__iter__() - - def __len__(self) -> int: - return len(self.__dict__) - - def __call__(self, **kwargs) -> dict: - mapping = self.__dict__.copy() - mapping.update(**kwargs) - return mapping - - def __str__(self) -> str: - return str(self.__dict__) - - def __repr__(self) -> str: - return str(self.__dict__) - - -numba_math_defaults = NumbaMathDefaults() -numba_math_defaults_kwargs = numba_math_defaults \ No newline at end of file diff --git a/src/pygama/utils.py b/src/pygama/utils.py new file mode 100644 index 000000000..0b28a299b --- /dev/null +++ b/src/pygama/utils.py @@ -0,0 +1,76 @@ +import os +from collections.abc import MutableMapping +from typing import Any, Iterator + + +def getenv_bool(name: str, default: bool = False) -> bool: + """Get environment value as a boolean, returning True for 1, t and true + (caps-insensitive), and False for any other value and default if undefined. + """ + val = os.getenv(name) + if not val: + return default + elif val.lower() in ("1", "t", "true"): + return True + else: + return False + + +class NumbaPygamaDefaults(MutableMapping): + """Bare-bones class to store some Numba default options. Defaults values + are set from environment variables. Useful for the pygama.math distributions + + Examples + -------- + Set all default option values for a numba wrapped function at once by expanding the + provided dictionary: + + >>> from numba import njit + >>> from pygama.utils import numba_math_defaults_kwargs as nb_kwargs + >>> @njit([], "", **nb_kwargs, nopython=True) # def dist(...): ... + + Customize one argument but still set defaults for the others: + + >>> from pygama.utils import numba_math_defaults as nb_defaults + >>> @njit([], "", **nb_defaults(cache=False) # def dist(...): ... + + Override global options at runtime: + + >>> from pygama.utils import numba_math_defaults + >>> # must set options before explicitly importing pygama.math.distributions! + >>> numba_math_defaults.cache = False + """ + + def __init__(self) -> None: + self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=True) + self.fastmath: bool = getenv_bool("PYGAMA_FAST", default=True) + + def __getitem__(self, item: str) -> Any: + return self.__dict__[item] + + def __setitem__(self, item: str, val: Any) -> None: + self.__dict__[item] = val + + def __delitem__(self, item: str) -> None: + del self.__dict__[item] + + def __iter__(self) -> Iterator: + return self.__dict__.__iter__() + + def __len__(self) -> int: + return len(self.__dict__) + + def __call__(self, **kwargs) -> dict: + mapping = self.__dict__.copy() + mapping.update(**kwargs) + return mapping + + def __str__(self) -> str: + return str(self.__dict__) + + def __repr__(self) -> str: + return str(self.__dict__) + + +numba_math_defaults = NumbaPygamaDefaults() +numba_math_defaults_kwargs = numba_math_defaults diff --git a/tests/math/test_math_utils.py b/tests/math/test_math_utils.py index 44f96feb5..ce8a9abe7 100644 --- a/tests/math/test_math_utils.py +++ b/tests/math/test_math_utils.py @@ -31,11 +31,3 @@ def test_print_fit_results(caplog): "p2 = 1.0 +/- 1.0", "", ] == [rec.message for rec in caplog.records] - - -def test_math_numba_defaults(): - assert pgu.numba_math_defaults_kwargs.fastmath - assert pgu.numba_math_defaults_kwargs.parallel - - pgu.numba_math_defaults.fastmath = False - assert ~pgu.numba_math_defaults.fastmath diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..ba70ad382 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,9 @@ +import pygama.utils as pgu + + +def test_math_numba_defaults(): + assert pgu.numba_math_defaults_kwargs.fastmath + assert pgu.numba_math_defaults_kwargs.parallel + + pgu.numba_math_defaults.fastmath = False + assert ~pgu.numba_math_defaults.fastmath