diff --git a/CHANGELOG.md b/CHANGELOG.md index 38747ea7fb..6c532e839c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ ## Breaking changes +- Replaced `have_jax` with `has_jax`, `have_idaklu` with `has_idaklu`, and + `have_iree` with `has_iree` ([#4398](https://github.com/pybamm-team/PyBaMM/pull/4398)) - Remove deprecated function `pybamm_install_jax` ([#4362](https://github.com/pybamm-team/PyBaMM/pull/4362)) - Removed legacy python-IDAKLU solver. ([#4326](https://github.com/pybamm-team/PyBaMM/pull/4326)) diff --git a/docs/source/api/util.rst b/docs/source/api/util.rst index 7496b59554..824ec6126d 100644 --- a/docs/source/api/util.rst +++ b/docs/source/api/util.rst @@ -16,6 +16,6 @@ Utility functions .. autofunction:: pybamm.load -.. autofunction:: pybamm.have_jax +.. autofunction:: pybamm.has_jax .. autofunction:: pybamm.is_jax_compatible diff --git a/examples/scripts/compare_dae_solver.py b/examples/scripts/compare_dae_solver.py index 815b458f1a..52ead1a242 100644 --- a/examples/scripts/compare_dae_solver.py +++ b/examples/scripts/compare_dae_solver.py @@ -28,7 +28,7 @@ casadi_sol = pybamm.CasadiSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval) solutions = [casadi_sol] -if pybamm.have_idaklu(): +if pybamm.has_idaklu(): klu_sol = pybamm.IDAKLUSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval) solutions.append(klu_sol) else: diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index 75f5f4f160..36ad0b137a 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -15,7 +15,7 @@ ) from .util import ( get_parameters_filepath, - have_jax, + has_jax, import_optional_dependency, is_jax_compatible, get_git_commit_info, @@ -170,7 +170,7 @@ from .solvers.jax_bdf_solver import jax_bdf_integrate from .solvers.idaklu_jax import IDAKLUJax -from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu, have_iree +from .solvers.idaklu_solver import IDAKLUSolver, has_idaklu, has_iree # Experiments from .experiment.experiment import Experiment diff --git a/src/pybamm/expression_tree/operations/evaluate_python.py b/src/pybamm/expression_tree/operations/evaluate_python.py index 20a6d4b4a2..a8a37ea7b2 100644 --- a/src/pybamm/expression_tree/operations/evaluate_python.py +++ b/src/pybamm/expression_tree/operations/evaluate_python.py @@ -11,7 +11,7 @@ import pybamm -if pybamm.have_jax(): +if pybamm.has_jax(): import jax platform = jax.lib.xla_bridge.get_backend().platform.casefold() @@ -43,7 +43,7 @@ class JaxCooMatrix: def __init__( self, row: ArrayLike, col: ArrayLike, data: ArrayLike, shape: tuple[int, int] ): - if not pybamm.have_jax(): # pragma: no cover + if not pybamm.has_jax(): # pragma: no cover raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" ) @@ -527,7 +527,7 @@ class EvaluatorJax: """ def __init__(self, symbol: pybamm.Symbol): - if not pybamm.have_jax(): # pragma: no cover + if not pybamm.has_jax(): # pragma: no cover raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" ) diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index 5a73d42c6e..991c16775e 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -22,7 +22,7 @@ except ImportError: # pragma: no cover idaklu_spec = None -if pybamm.have_jax(): +if pybamm.has_jax(): import jax from jax import lax from jax import numpy as jnp @@ -57,11 +57,11 @@ def __init__( calculate_sensitivities=True, t_interp=None, ): - if not pybamm.have_jax(): + if not pybamm.has_jax(): raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" ) # pragma: no cover - if not pybamm.have_idaklu(): + if not pybamm.has_idaklu(): raise ModuleNotFoundError( "IDAKLU is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html" ) # pragma: no cover diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index b92006d12d..41e0c8855f 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -14,7 +14,7 @@ import warnings -if pybamm.have_jax(): +if pybamm.has_jax(): import jax from jax import numpy as jnp @@ -33,11 +33,11 @@ idaklu_spec = None -def have_idaklu(): +def has_idaklu(): return idaklu_spec is not None -def have_iree(): +def has_iree(): try: import iree.compiler # noqa: F401 diff --git a/src/pybamm/solvers/jax_bdf_solver.py b/src/pybamm/solvers/jax_bdf_solver.py index 2c7bdc6d17..6f0c62b9a8 100644 --- a/src/pybamm/solvers/jax_bdf_solver.py +++ b/src/pybamm/solvers/jax_bdf_solver.py @@ -7,7 +7,7 @@ import pybamm -if pybamm.have_jax(): +if pybamm.has_jax(): import jax import jax.numpy as jnp from jax import core, dtypes @@ -1007,7 +1007,7 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): calculated state vector at each of the m time points """ - if not pybamm.have_jax(): + if not pybamm.has_jax(): raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" ) diff --git a/src/pybamm/solvers/jax_solver.py b/src/pybamm/solvers/jax_solver.py index 26a069e0fe..da5fd4983a 100644 --- a/src/pybamm/solvers/jax_solver.py +++ b/src/pybamm/solvers/jax_solver.py @@ -6,7 +6,7 @@ import pybamm -if pybamm.have_jax(): +if pybamm.has_jax(): import jax import jax.numpy as jnp from jax.experimental.ode import odeint @@ -59,7 +59,7 @@ def __init__( extrap_tol=None, extra_options=None, ): - if not pybamm.have_jax(): + if not pybamm.has_jax(): raise ModuleNotFoundError( "Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" ) diff --git a/src/pybamm/util.py b/src/pybamm/util.py index fd94eb88f4..527c55f526 100644 --- a/src/pybamm/util.py +++ b/src/pybamm/util.py @@ -264,7 +264,7 @@ def get_parameters_filepath(path): return os.path.join(pybamm.__path__[0], path) -def have_jax(): +def has_jax(): """ Check if jax and jaxlib are installed with the correct versions diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py index eddf2aa1e4..60e8dfb819 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/base_lithium_ion_tests.py @@ -22,7 +22,7 @@ def test_sensitivities(self): param = pybamm.ParameterValues("Ecker2015") rtol = 1e-6 atol = 1e-6 - if pybamm.have_idaklu(): + if pybamm.has_idaklu(): solver = pybamm.IDAKLUSolver(rtol=rtol, atol=atol) else: solver = pybamm.CasadiSolver(rtol=rtol, atol=atol) @@ -53,7 +53,7 @@ def test_optimisations(self): to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, to_python) - if pybamm.have_jax(): + if pybamm.has_jax(): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py index 6e67f349fa..00bce1a9d7 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_mpm.py @@ -25,7 +25,7 @@ def test_optimisations(self): to_python = optimtest.evaluate_model(to_python=True) np.testing.assert_array_almost_equal(original, to_python) - if pybamm.have_jax(): + if pybamm.has_jax(): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) diff --git a/tests/integration/test_solvers/test_idaklu.py b/tests/integration/test_solvers/test_idaklu.py index d70b64c783..88faa80dde 100644 --- a/tests/integration/test_solvers/test_idaklu.py +++ b/tests/integration/test_solvers/test_idaklu.py @@ -3,7 +3,7 @@ import numpy as np -@pytest.mark.skipif(not pybamm.have_idaklu(), reason="idaklu solver is not installed") +@pytest.mark.skipif(not pybamm.has_idaklu(), reason="idaklu solver is not installed") class TestIDAKLUSolver: def test_on_spme(self): model = pybamm.lithium_ion.SPMe() diff --git a/tests/unit/test_citations.py b/tests/unit/test_citations.py index 0928cc993c..7133cf234a 100644 --- a/tests/unit/test_citations.py +++ b/tests/unit/test_citations.py @@ -423,14 +423,14 @@ def test_solver_citations(self): assert "Virtanen2020" in citations._papers_to_cite assert "Virtanen2020" in citations._citation_tags.keys() - if pybamm.have_idaklu(): + if pybamm.has_idaklu(): citations._reset() assert "Hindmarsh2005" not in citations._papers_to_cite pybamm.IDAKLUSolver() assert "Hindmarsh2005" in citations._papers_to_cite assert "Hindmarsh2005" in citations._citation_tags.keys() - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_jax_citations(self): citations = pybamm.citations citations._reset() diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index c4f55889a1..3507d6e5c1 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -168,7 +168,7 @@ def test_run_experiment_multiple_times(self): sol1["Voltage [V]"].data, sol2["Voltage [V]"].data ) - @unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed") + @unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") def test_run_experiment_cccv_solvers(self): experiment_2step = pybamm.Experiment( [ diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index 518d8f8231..e6d8a0da83 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -12,7 +12,7 @@ from collections import OrderedDict import re -if pybamm.have_jax(): +if pybamm.has_jax(): import jax from tests import ( function_test, @@ -446,7 +446,7 @@ def test_evaluator_python(self): result = evaluator(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_find_symbols_jax(self): # test sparse conversion constant_symbols = OrderedDict() @@ -459,7 +459,7 @@ def test_find_symbols_jax(self): next(iter(constant_symbols.values())).toarray(), A.entries.toarray() ) - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_evaluator_jax(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.StateVector(slice(1, 2)) @@ -621,7 +621,7 @@ def test_evaluator_jax(self): result = evaluator(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_evaluator_jax_jacobian(self): a = pybamm.StateVector(slice(0, 1)) y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])] @@ -636,7 +636,7 @@ def test_evaluator_jax_jacobian(self): result_true = evaluator_jac(t=None, y=y) np.testing.assert_allclose(result_test, result_true) - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_evaluator_jax_jvp(self): a = pybamm.StateVector(slice(0, 1)) y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])] @@ -656,7 +656,7 @@ def test_evaluator_jax_jvp(self): np.testing.assert_allclose(result_test, result_true) np.testing.assert_allclose(result_test_times_v, result_true_times_v) - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_evaluator_jax_debug(self): a = pybamm.StateVector(slice(0, 1)) expr = a**2 @@ -664,7 +664,7 @@ def test_evaluator_jax_debug(self): evaluator = pybamm.EvaluatorJax(expr) evaluator.debug(y=y_test) - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_evaluator_jax_inputs(self): a = pybamm.InputParameter("a") expr = a**2 @@ -672,7 +672,7 @@ def test_evaluator_jax_inputs(self): result = evaluator(inputs={"a": 2}) assert result == 4 - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_evaluator_jax_demotion(self): for demote in [True, False]: pybamm.demote_expressions_to_32bit = demote # global flag @@ -734,7 +734,7 @@ def test_evaluator_jax_demotion(self): assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col) pybamm.demote_expressions_to_32bit = False - @pytest.mark.skipif(not pybamm.have_jax(), reason="jax or jaxlib is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_jax_coo_matrix(self): A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) Adense = jax.numpy.array([[1.0, 0], [0, 2.0]]) diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index e86b0f702e..a4b43e1dd2 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -355,12 +355,12 @@ def test_multiprocess_context(self): assert solver.get_platform_context("Linux") == "fork" assert solver.get_platform_context("Darwin") == "fork" - @unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed") + @unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") def test_sensitivities(self): def exact_diff_a(y, a, b): return np.array([[y[0] ** 2 + 2 * a], [y[0]]]) - @unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") + @unittest.skipIf(not pybamm.has_jax(), "jax or jaxlib is not installed") def exact_diff_b(y, a, b): return np.array([[y[0]], [0]]) diff --git a/tests/unit/test_solvers/test_idaklu_jax.py b/tests/unit/test_solvers/test_idaklu_jax.py index d44d895a0f..82790126f0 100644 --- a/tests/unit/test_solvers/test_idaklu_jax.py +++ b/tests/unit/test_solvers/test_idaklu_jax.py @@ -11,7 +11,7 @@ import sys testcase = [] -if pybamm.have_idaklu() and pybamm.have_jax(): +if pybamm.has_idaklu() and pybamm.has_jax(): from jax.tree_util import tree_flatten import jax import jax.numpy as jnp @@ -89,7 +89,7 @@ def no_jit(f): # Check the interface throws an appropriate error if either IDAKLU or JAX not available @unittest.skipIf( - pybamm.have_idaklu() and pybamm.have_jax(), + pybamm.has_idaklu() and pybamm.has_jax(), "Both IDAKLU and JAX are available", ) class TestIDAKLUJax_NoJax(unittest.TestCase): @@ -99,7 +99,7 @@ def test_instantiate_fails(self): @unittest.skipIf( - not pybamm.have_idaklu() or not pybamm.have_jax(), + not pybamm.has_idaklu() or not pybamm.has_jax(), "IDAKLU Solver and/or JAX are not available", ) @pytest.mark.skipif( diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 1a264d6445..37ba2c147a 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -14,14 +14,14 @@ @pytest.mark.cibw -@unittest.skipIf(not pybamm.have_idaklu(), "idaklu solver is not installed") +@unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") class TestIDAKLUSolver(unittest.TestCase): def test_ida_roberts_klu(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -67,7 +67,7 @@ def test_ida_roberts_klu(self): def test_model_events(self): for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -188,7 +188,7 @@ def test_model_events(self): def test_input_params(self): # test a mix of scalar and vector input params for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -246,9 +246,7 @@ def test_input_params(self): def test_sensitivities_initial_condition(self): for form in ["casadi", "iree"]: for output_variables in [[], ["2v"]]: - if (form == "iree") and ( - not pybamm.have_jax() or not pybamm.have_iree() - ): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -299,7 +297,7 @@ def test_ida_roberts_klu_sensitivities(self): # example provided in sundials # see sundials ida examples pdf for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -405,7 +403,7 @@ def test_ida_roberts_consistent_initialization(self): # example provided in sundials # see sundials ida examples pdf for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -447,7 +445,7 @@ def test_sensitivities_with_events(self): # example provided in sundials # see sundials ida examples pdf for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -606,7 +604,7 @@ def test_failures(self): def test_dae_solver_algebraic_model(self): for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -915,7 +913,7 @@ def test_with_output_variables_and_sensitivities(self): # equivalence for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.have_jax() or not pybamm.have_iree()): + if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): continue if form == "casadi": root_method = "casadi" @@ -1088,7 +1086,7 @@ def test_interpolate_time_step_start_offset(self): def test_python_idaklu_deprecation_errors(self): for form in ["python", "", "jax"]: - if form == "jax" and not pybamm.have_jax(): + if form == "jax" and not pybamm.has_jax(): continue model = pybamm.BaseModel() diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index e02bdb2510..e0064ae463 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -5,11 +5,11 @@ import sys import numpy as np -if pybamm.have_jax(): +if pybamm.has_jax(): import jax -@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") +@unittest.skipIf(not pybamm.has_jax(), "jax or jaxlib is not installed") class TestJaxBDFSolver(unittest.TestCase): def test_solver_(self): # Trailing _ manipulates the random seed # Create model diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 4f34497626..b1c293c2f2 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -5,11 +5,11 @@ import sys import numpy as np -if pybamm.have_jax(): +if pybamm.has_jax(): import jax -@unittest.skipIf(not pybamm.have_jax(), "jax or jaxlib is not installed") +@unittest.skipIf(not pybamm.has_jax(), "jax or jaxlib is not installed") class TestJaxSolver(unittest.TestCase): def test_model_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index c6afd16704..446206e95c 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -11,7 +11,7 @@ class TestScipySolver(unittest.TestCase): def test_model_solver_python_and_jax(self): - if pybamm.have_jax(): + if pybamm.has_jax(): formats = ["python", "jax"] else: formats = ["python"] @@ -339,7 +339,7 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) def test_model_solver_multiple_inputs_jax_format(self): - if pybamm.have_jax(): + if pybamm.has_jax(): # Create model model = pybamm.BaseModel() model.convert_to_format = "jax" diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 1b621d98f0..058b7d4a14 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -88,7 +88,7 @@ def test_get_parameters_filepath(self): path = os.path.join(package_dir, tempfile_obj.name) assert pybamm.get_parameters_filepath(tempfile_obj.name) == path - @pytest.mark.skipif(not pybamm.have_jax(), reason="JAX is not installed") + @pytest.mark.skipif(not pybamm.has_jax(), reason="JAX is not installed") def test_is_jax_compatible(self): assert pybamm.is_jax_compatible()