From 8512040f7ca4a96b2ee79fc563d567efe62784e6 Mon Sep 17 00:00:00 2001 From: "Eric G. Kratz" Date: Wed, 3 Apr 2024 22:07:02 -0400 Subject: [PATCH] Remove time check from Jax tests to improve stability (#3963) * Remove time checks in tests * Old changelog --- CHANGELOG.md | 1 + .../unit/test_solvers/test_jax_bdf_solver.py | 15 --------------- tests/unit/test_solvers/test_jax_solver.py | 19 ------------------- 3 files changed, 1 insertion(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d13c17ad25..351b644ac8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ ## Breaking changes +- Removed support for Python 3.8 ([#3961](https://github.com/pybamm-team/PyBaMM/pull/3961)) - Renamed "ocp_soc_0_dimensional" to "ocp_soc_0" and "ocp_soc_100_dimensional" to "ocp_soc_100" ([#3942](https://github.com/pybamm-team/PyBaMM/pull/3942)) - The ODES solver was removed due to compatibility issues. Users should use IDAKLU, Casadi, or JAX instead. ([#3932](https://github.com/pybamm-team/PyBaMM/pull/3932)) - Integrated the `[pandas]` extra into the core PyBaMM package, deprecating the `pybamm[pandas]` optional dependency. Pandas is now a required dependency and will be installed upon installing PyBaMM ([#3892](https://github.com/pybamm-team/PyBaMM/pull/3892)) diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index 92fb710ea9..854a618fba 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -3,7 +3,6 @@ from tests import get_mesh_for_testing from tests import TestCase import sys -import time import numpy as np if pybamm.have_jax(): @@ -36,19 +35,12 @@ def test_solver_(self): # Trailing _ manipulates the random seed def fun(y, t): return rhs(t=t, y=y).reshape(-1) - t0 = time.perf_counter() y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8) - t1 = time.perf_counter() - t0 # test accuracy np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-6, atol=1e-6) - t0 = time.perf_counter() y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8) - t2 = time.perf_counter() - t0 - - # second run should be much quicker - self.assertLess(t2, t1) # test second run is accurate np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-6, atol=1e-6) @@ -66,21 +58,14 @@ def fun(y, t): # this as a guess y0 = jax.numpy.array([1.0, 1.5]) - t0 = time.perf_counter() y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8) - t1 = time.perf_counter() - t0 # test accuracy soln = np.exp(0.05 * t_eval) np.testing.assert_allclose(y[:, 0], soln, rtol=1e-7, atol=1e-7) np.testing.assert_allclose(y[:, 1], 2.0 * soln, rtol=1e-7, atol=1e-7) - t0 = time.perf_counter() y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8) - t2 = time.perf_counter() - t0 - - # second run should be much quicker - self.assertLess(t2, t1) # test second run is accurate np.testing.assert_allclose(y[:, 0], np.exp(0.05 * t_eval), rtol=1e-7, atol=1e-7) diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 1a84f2bea4..9df28e8ac2 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -3,7 +3,6 @@ from tests import get_mesh_for_testing from tests import TestCase import sys -import time import numpy as np if pybamm.have_jax(): @@ -32,9 +31,7 @@ def test_model_solver(self): # Solve solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 1, 80) - t0 = time.perf_counter() solution = solver.solve(model, t_eval) - t_first_solve = time.perf_counter() - t0 np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_allclose( solution.y[0], np.exp(0.1 * solution.t), rtol=1e-6, atol=1e-6 @@ -46,11 +43,8 @@ def test_model_solver(self): ) self.assertEqual(solution.termination, "final time") - t0 = time.perf_counter() second_solution = solver.solve(model, t_eval) - t_second_solve = time.perf_counter() - t0 - self.assertLess(t_second_solve, t_first_solve) np.testing.assert_array_equal(second_solution.y, solution.y) def test_semi_explicit_model(self): @@ -75,9 +69,7 @@ def test_semi_explicit_model(self): # Solve solver = pybamm.JaxSolver(method="BDF", rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 1, 80) - t0 = time.perf_counter() solution = solver.solve(model, t_eval) - t_first_solve = time.perf_counter() - t0 np.testing.assert_array_equal(solution.t, t_eval) soln = np.exp(0.1 * solution.t) np.testing.assert_allclose(solution.y[0], soln, rtol=1e-7, atol=1e-7) @@ -89,11 +81,7 @@ def test_semi_explicit_model(self): ) self.assertEqual(solution.termination, "final time") - t0 = time.perf_counter() second_solution = solver.solve(model, t_eval) - t_second_solve = time.perf_counter() - t0 - - self.assertLess(t_second_solve, t_first_solve) np.testing.assert_array_equal(second_solution.y, solution.y) def test_solver_sensitivities(self): @@ -205,25 +193,18 @@ def test_model_solver_with_inputs(self): # Solve solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 5, 80) - - t0 = time.perf_counter() solution = solver.solve(model, t_eval, inputs={"rate": 0.1}) - t_first_solve = time.perf_counter() - t0 np.testing.assert_allclose( solution.y[0], np.exp(-0.1 * solution.t), rtol=1e-6, atol=1e-6 ) - t0 = time.perf_counter() solution = solver.solve(model, t_eval, inputs={"rate": 0.2}) - t_second_solve = time.perf_counter() - t0 np.testing.assert_allclose( solution.y[0], np.exp(-0.2 * solution.t), rtol=1e-6, atol=1e-6 ) - self.assertLess(t_second_solve, t_first_solve) - def test_get_solve(self): # Create model model = pybamm.BaseModel()