Skip to content

Commit

Permalink
renamed numba defaults class and moved it to top directory
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBorden committed Mar 22, 2024
1 parent c57d48d commit 2a76a94
Show file tree
Hide file tree
Showing 14 changed files with 104 additions and 103 deletions.
4 changes: 2 additions & 2 deletions src/pygama/math/functions/crystal_ball.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from math import erf

from pygama.math.functions.pygama_continuous import pygama_continuous
from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults



Expand Down
4 changes: 2 additions & 2 deletions src/pygama/math/functions/exgauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from pygama.math.functions.error_function import nb_erf, nb_erfc
from pygama.math.functions.gauss import nb_gauss_cdf, nb_gauss_pdf

from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults

limit = np.log(sys.float_info.max)/10

Expand Down
4 changes: 2 additions & 2 deletions src/pygama/math/functions/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import numpy as np

from pygama.math.functions.pygama_continuous import pygama_continuous
from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults



Expand Down
2 changes: 1 addition & 1 deletion src/pygama/math/functions/gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pygama.math.functions.error_function import nb_erf
from pygama.math.functions.pygama_continuous import pygama_continuous
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults as nb_defaults


@nb.njit(**nb_defaults(parallel=False))
Expand Down
4 changes: 2 additions & 2 deletions src/pygama/math/functions/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from numba import prange

from pygama.math.functions.pygama_continuous import pygama_continuous
from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults


@nb.njit(**nb_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions src/pygama/math/functions/moyal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from math import erfc

from pygama.math.functions.pygama_continuous import pygama_continuous
from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults



Expand Down
4 changes: 2 additions & 2 deletions src/pygama/math/functions/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from numba import prange

from scipy.stats import rv_discrete
from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults


@nb.njit(**nb_defaults(parallel=False))
Expand Down
2 changes: 1 addition & 1 deletion src/pygama/math/functions/polynomial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numba as nb
import numpy as np
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults as nb_defaults


@nb.njit(**nb_defaults(parallel=False))
Expand Down
6 changes: 2 additions & 4 deletions src/pygama/math/functions/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@

from pygama.math.functions.gauss import nb_gauss
from pygama.math.functions.pygama_continuous import pygama_continuous
from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults

kwd = {"parallel": False, "fastmath": True}
kwd_parallel = {"parallel": True, "fastmath": True}

@nb.njit(**nb_defaults(parallel=False))
def nb_step_int(x: float, mu: float, sigma: float, hstep: float) -> np.ndarray:
Expand Down
4 changes: 2 additions & 2 deletions src/pygama/math/functions/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from numba import prange

from pygama.math.functions.pygama_continuous import pygama_continuous
from ..utils import numba_math_defaults_kwargs as nb_kwargs
from ..utils import numba_math_defaults as nb_defaults
from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
from pygama.utils import numba_math_defaults as nb_defaults


@nb.njit(**nb_kwargs)
Expand Down
76 changes: 1 addition & 75 deletions src/pygama/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
"""
import logging
import sys
import os
from typing import Optional, Union, Callable, Any, Iterator
from collections.abc import MutableMapping
from typing import Optional, Union, Callable

import numpy as np

Expand Down Expand Up @@ -105,75 +103,3 @@ def print_fit_results(pars: np.ndarray, cov: np.ndarray, func: Optional[Callable
log.info(f"{par_names[i]} = {mean} +/- {sigma}")
if pad:
log.info("")


def getenv_bool(name: str, default: bool = False) -> bool:
"""Get environment value as a boolean, returning True for 1, t and true
(caps-insensitive), and False for any other value and default if undefined.
"""
val = os.getenv(name)
if not val:
return default
elif val.lower() in ("1", "t", "true"):
return True
else:
return False

class NumbaMathDefaults(MutableMapping):
"""Bare-bones class to store some Numba default options. Defaults values
are set from environment variables
Examples
--------
Set all default option values for a processor at once by expanding the
provided dictionary:
>>> from numba import guvectorize
>>> from pygama.math.utils import numba_defaults_kwargs as nb_kwargs
>>> @guvectorize([], "", **nb_kwargs, nopython=True) # def proc(...): ...
Customize one argument but still set defaults for the others:
>>> from pygama.math.utils import numba_defaults as nb_defaults
>>> @guvectorize([], "", **nb_defaults(cache=False) # def proc(...): ...
Override global options at runtime:
>>> from pygama.math.utils import numba_defaults
>>> # must set options before explicitly importing pygama.math.distributions!
>>> numba_defaults.cache = False
"""

def __init__(self) -> None:
self.parallel: bool = getenv_bool("MATH_PARALLEL", default=True)
self.fastmath: bool = getenv_bool("MATH_FAST", default=True)

def __getitem__(self, item: str) -> Any:
return self.__dict__[item]

def __setitem__(self, item: str, val: Any) -> None:
self.__dict__[item] = val

def __delitem__(self, item: str) -> None:
del self.__dict__[item]

def __iter__(self) -> Iterator:
return self.__dict__.__iter__()

def __len__(self) -> int:
return len(self.__dict__)

def __call__(self, **kwargs) -> dict:
mapping = self.__dict__.copy()
mapping.update(**kwargs)
return mapping

def __str__(self) -> str:
return str(self.__dict__)

def __repr__(self) -> str:
return str(self.__dict__)


numba_math_defaults = NumbaMathDefaults()
numba_math_defaults_kwargs = numba_math_defaults
76 changes: 76 additions & 0 deletions src/pygama/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import os
from collections.abc import MutableMapping
from typing import Any, Iterator


def getenv_bool(name: str, default: bool = False) -> bool:
"""Get environment value as a boolean, returning True for 1, t and true
(caps-insensitive), and False for any other value and default if undefined.
"""
val = os.getenv(name)
if not val:
return default
elif val.lower() in ("1", "t", "true"):
return True
else:
return False


class NumbaPygamaDefaults(MutableMapping):
"""Bare-bones class to store some Numba default options. Defaults values
are set from environment variables. Useful for the pygama.math distributions
Examples
--------
Set all default option values for a numba wrapped function at once by expanding the
provided dictionary:
>>> from numba import njit
>>> from pygama.utils import numba_math_defaults_kwargs as nb_kwargs
>>> @njit([], "", **nb_kwargs, nopython=True) # def dist(...): ...
Customize one argument but still set defaults for the others:
>>> from pygama.utils import numba_math_defaults as nb_defaults
>>> @njit([], "", **nb_defaults(cache=False) # def dist(...): ...
Override global options at runtime:
>>> from pygama.utils import numba_math_defaults
>>> # must set options before explicitly importing pygama.math.distributions!
>>> numba_math_defaults.cache = False
"""

def __init__(self) -> None:
self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=True)
self.fastmath: bool = getenv_bool("PYGAMA_FAST", default=True)

def __getitem__(self, item: str) -> Any:
return self.__dict__[item]

def __setitem__(self, item: str, val: Any) -> None:
self.__dict__[item] = val

def __delitem__(self, item: str) -> None:
del self.__dict__[item]

def __iter__(self) -> Iterator:
return self.__dict__.__iter__()

def __len__(self) -> int:
return len(self.__dict__)

def __call__(self, **kwargs) -> dict:
mapping = self.__dict__.copy()
mapping.update(**kwargs)
return mapping

def __str__(self) -> str:
return str(self.__dict__)

def __repr__(self) -> str:
return str(self.__dict__)


numba_math_defaults = NumbaPygamaDefaults()
numba_math_defaults_kwargs = numba_math_defaults
8 changes: 0 additions & 8 deletions tests/math/test_math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,3 @@ def test_print_fit_results(caplog):
"p2 = 1.0 +/- 1.0",
"",
] == [rec.message for rec in caplog.records]


def test_math_numba_defaults():
assert pgu.numba_math_defaults_kwargs.fastmath
assert pgu.numba_math_defaults_kwargs.parallel

pgu.numba_math_defaults.fastmath = False
assert ~pgu.numba_math_defaults.fastmath
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pygama.utils as pgu


def test_math_numba_defaults():
assert pgu.numba_math_defaults_kwargs.fastmath
assert pgu.numba_math_defaults_kwargs.parallel

pgu.numba_math_defaults.fastmath = False
assert ~pgu.numba_math_defaults.fastmath

0 comments on commit 2a76a94

Please sign in to comment.