Skip to content

Commit

Permalink
Merge branch 'develop' into migrate-to-sbc
Browse files Browse the repository at this point in the history
  • Loading branch information
cringeyburger authored Aug 30, 2024
2 parents d01562c + ad41dbe commit 02e4f97
Show file tree
Hide file tree
Showing 23 changed files with 59 additions and 59 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

## Breaking changes

- Replaced `have_jax` with `has_jax`, `have_idaklu` with `has_idaklu`, and
`have_iree` with `has_iree` ([#4398](https://github.com/pybamm-team/PyBaMM/pull/4398))
- Remove deprecated function `pybamm_install_jax` ([#4362](https://github.com/pybamm-team/PyBaMM/pull/4362))
- Removed legacy python-IDAKLU solver. ([#4326](https://github.com/pybamm-team/PyBaMM/pull/4326))

Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/util.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ Utility functions

.. autofunction:: pybamm.load

.. autofunction:: pybamm.have_jax
.. autofunction:: pybamm.has_jax

.. autofunction:: pybamm.is_jax_compatible
2 changes: 1 addition & 1 deletion examples/scripts/compare_dae_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
casadi_sol = pybamm.CasadiSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
solutions = [casadi_sol]

if pybamm.have_idaklu():
if pybamm.has_idaklu():
klu_sol = pybamm.IDAKLUSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
solutions.append(klu_sol)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from .util import (
get_parameters_filepath,
have_jax,
has_jax,
import_optional_dependency,
is_jax_compatible,
get_git_commit_info,
Expand Down Expand Up @@ -170,7 +170,7 @@
from .solvers.jax_bdf_solver import jax_bdf_integrate

from .solvers.idaklu_jax import IDAKLUJax
from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu, have_iree
from .solvers.idaklu_solver import IDAKLUSolver, has_idaklu, has_iree

# Experiments
from .experiment.experiment import Experiment
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pybamm

if pybamm.have_jax():
if pybamm.has_jax():
import jax

platform = jax.lib.xla_bridge.get_backend().platform.casefold()
Expand Down Expand Up @@ -43,7 +43,7 @@ class JaxCooMatrix:
def __init__(
self, row: ArrayLike, col: ArrayLike, data: ArrayLike, shape: tuple[int, int]
):
if not pybamm.have_jax(): # pragma: no cover
if not pybamm.has_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down Expand Up @@ -527,7 +527,7 @@ class EvaluatorJax:
"""

def __init__(self, symbol: pybamm.Symbol):
if not pybamm.have_jax(): # pragma: no cover
if not pybamm.has_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/solvers/idaklu_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
except ImportError: # pragma: no cover
idaklu_spec = None

if pybamm.have_jax():
if pybamm.has_jax():
import jax
from jax import lax
from jax import numpy as jnp
Expand Down Expand Up @@ -57,11 +57,11 @@ def __init__(
calculate_sensitivities=True,
t_interp=None,
):
if not pybamm.have_jax():
if not pybamm.has_jax():
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
) # pragma: no cover
if not pybamm.have_idaklu():
if not pybamm.has_idaklu():
raise ModuleNotFoundError(
"IDAKLU is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html"
) # pragma: no cover
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import warnings


if pybamm.have_jax():
if pybamm.has_jax():
import jax
from jax import numpy as jnp

Expand All @@ -33,11 +33,11 @@
idaklu_spec = None


def have_idaklu():
def has_idaklu():
return idaklu_spec is not None


def have_iree():
def has_iree():
try:
import iree.compiler # noqa: F401

Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pybamm

if pybamm.have_jax():
if pybamm.has_jax():
import jax
import jax.numpy as jnp
from jax import core, dtypes
Expand Down Expand Up @@ -1007,7 +1007,7 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
calculated state vector at each of the m time points
"""
if not pybamm.have_jax():
if not pybamm.has_jax():
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pybamm

if pybamm.have_jax():
if pybamm.has_jax():
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
extrap_tol=None,
extra_options=None,
):
if not pybamm.have_jax():
if not pybamm.has_jax():
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def get_parameters_filepath(path):
return os.path.join(pybamm.__path__[0], path)


def have_jax():
def has_jax():
"""
Check if jax and jaxlib are installed with the correct versions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_sensitivities(self):
param = pybamm.ParameterValues("Ecker2015")
rtol = 1e-6
atol = 1e-6
if pybamm.have_idaklu():
if pybamm.has_idaklu():
solver = pybamm.IDAKLUSolver(rtol=rtol, atol=atol)
else:
solver = pybamm.CasadiSolver(rtol=rtol, atol=atol)
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_optimisations(self):
to_python = optimtest.evaluate_model(to_python=True)
np.testing.assert_array_almost_equal(original, to_python)

if pybamm.have_jax():
if pybamm.has_jax():
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_optimisations(self):
to_python = optimtest.evaluate_model(to_python=True)
np.testing.assert_array_almost_equal(original, to_python)

if pybamm.have_jax():
if pybamm.has_jax():
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_solvers/test_idaklu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


@pytest.mark.skipif(not pybamm.have_idaklu(), reason="idaklu solver is not installed")
@pytest.mark.skipif(not pybamm.has_idaklu(), reason="idaklu solver is not installed")
class TestIDAKLUSolver:
def test_on_spme(self):
model = pybamm.lithium_ion.SPMe()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,14 @@ def test_solver_citations(self):
assert "Virtanen2020" in citations._papers_to_cite
assert "Virtanen2020" in citations._citation_tags.keys()

if pybamm.have_idaklu():
if pybamm.has_idaklu():
citations._reset()
assert "Hindmarsh2005" not in citations._papers_to_cite
pybamm.IDAKLUSolver()
assert "Hindmarsh2005" in citations._papers_to_cite
assert "Hindmarsh2005" in citations._citation_tags.keys()

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_jax_citations(self):
citations = pybamm.citations
citations._reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_run_experiment_multiple_times(self):
sol1["Voltage [V]"].data, sol2["Voltage [V]"].data
)

@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed")
@unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed")
def test_run_experiment_cccv_solvers(self):
experiment_2step = pybamm.Experiment(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections import OrderedDict
import re

if pybamm.have_jax():
if pybamm.has_jax():
import jax
from tests import (
function_test,
Expand Down Expand Up @@ -446,7 +446,7 @@ def test_evaluator_python(self):
result = evaluator(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_find_symbols_jax(self):
# test sparse conversion
constant_symbols = OrderedDict()
Expand All @@ -459,7 +459,7 @@ def test_find_symbols_jax(self):
next(iter(constant_symbols.values())).toarray(), A.entries.toarray()
)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
Expand Down Expand Up @@ -621,7 +621,7 @@ def test_evaluator_jax(self):
result = evaluator(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_jacobian(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -636,7 +636,7 @@ def test_evaluator_jax_jacobian(self):
result_true = evaluator_jac(t=None, y=y)
np.testing.assert_allclose(result_test, result_true)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_jvp(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -656,23 +656,23 @@ def test_evaluator_jax_jvp(self):
np.testing.assert_allclose(result_test, result_true)
np.testing.assert_allclose(result_test_times_v, result_true_times_v)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_debug(self):
a = pybamm.StateVector(slice(0, 1))
expr = a**2
y_test = np.array([2.0, 3.0])
evaluator = pybamm.EvaluatorJax(expr)
evaluator.debug(y=y_test)

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_inputs(self):
a = pybamm.InputParameter("a")
expr = a**2
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator(inputs={"a": 2})
assert result == 4

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_evaluator_jax_demotion(self):
for demote in [True, False]:
pybamm.demote_expressions_to_32bit = demote # global flag
Expand Down Expand Up @@ -734,7 +734,7 @@ def test_evaluator_jax_demotion(self):
assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col)
pybamm.demote_expressions_to_32bit = False

@pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed")
@pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed")
def test_jax_coo_matrix(self):
A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2))
Adense = jax.numpy.array([[1.0, 0], [0, 2.0]])
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,12 @@ def test_multiprocess_context(self):
assert solver.get_platform_context("Linux") == "fork"
assert solver.get_platform_context("Darwin") == "fork"

@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed")
@unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed")
def test_sensitivities(self):
def exact_diff_a(y, a, b):
return np.array([[y[0] ** 2 + 2 * a], [y[0]]])

@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed")
@unittest.skipIf(not pybamm.has_jax(), "jax or jaxlib is not installed")
def exact_diff_b(y, a, b):
return np.array([[y[0]], [0]])

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_solvers/test_idaklu_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys

testcase = []
if pybamm.have_idaklu() and pybamm.have_jax():
if pybamm.has_idaklu() and pybamm.has_jax():
from jax.tree_util import tree_flatten
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -89,7 +89,7 @@ def no_jit(f):

# Check the interface throws an appropriate error if either IDAKLU or JAX not available
@unittest.skipIf(
pybamm.have_idaklu() and pybamm.have_jax(),
pybamm.has_idaklu() and pybamm.has_jax(),
"Both IDAKLU and JAX are available",
)
class TestIDAKLUJax_NoJax(unittest.TestCase):
Expand All @@ -99,7 +99,7 @@ def test_instantiate_fails(self):


@unittest.skipIf(
not pybamm.have_idaklu() or not pybamm.have_jax(),
not pybamm.has_idaklu() or not pybamm.has_jax(),
"IDAKLU Solver and/or JAX are not available",
)
@pytest.mark.skipif(
Expand Down
Loading

0 comments on commit 02e4f97

Please sign in to comment.