Skip to content

Commit

Permalink
Remove time check from Jax tests to improve stability (#3963)
Browse files Browse the repository at this point in the history
* Remove time checks in tests

* Old changelog
  • Loading branch information
kratman authored Apr 4, 2024
1 parent 426e441 commit 8512040
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 34 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

## Breaking changes

- Removed support for Python 3.8 ([#3961](https://github.com/pybamm-team/PyBaMM/pull/3961))
- Renamed "ocp_soc_0_dimensional" to "ocp_soc_0" and "ocp_soc_100_dimensional" to "ocp_soc_100" ([#3942](https://github.com/pybamm-team/PyBaMM/pull/3942))
- The ODES solver was removed due to compatibility issues. Users should use IDAKLU, Casadi, or JAX instead. ([#3932](https://github.com/pybamm-team/PyBaMM/pull/3932))
- Integrated the `[pandas]` extra into the core PyBaMM package, deprecating the `pybamm[pandas]` optional dependency. Pandas is now a required dependency and will be installed upon installing PyBaMM ([#3892](https://github.com/pybamm-team/PyBaMM/pull/3892))
Expand Down
15 changes: 0 additions & 15 deletions tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from tests import get_mesh_for_testing
from tests import TestCase
import sys
import time
import numpy as np

if pybamm.have_jax():
Expand Down Expand Up @@ -36,19 +35,12 @@ def test_solver_(self): # Trailing _ manipulates the random seed
def fun(y, t):
return rhs(t=t, y=y).reshape(-1)

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8)
t1 = time.perf_counter() - t0

# test accuracy
np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-6, atol=1e-6)

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8)
t2 = time.perf_counter() - t0

# second run should be much quicker
self.assertLess(t2, t1)

# test second run is accurate
np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-6, atol=1e-6)
Expand All @@ -66,21 +58,14 @@ def fun(y, t):
# this as a guess
y0 = jax.numpy.array([1.0, 1.5])

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8)
t1 = time.perf_counter() - t0

# test accuracy
soln = np.exp(0.05 * t_eval)
np.testing.assert_allclose(y[:, 0], soln, rtol=1e-7, atol=1e-7)
np.testing.assert_allclose(y[:, 1], 2.0 * soln, rtol=1e-7, atol=1e-7)

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8)
t2 = time.perf_counter() - t0

# second run should be much quicker
self.assertLess(t2, t1)

# test second run is accurate
np.testing.assert_allclose(y[:, 0], np.exp(0.05 * t_eval), rtol=1e-7, atol=1e-7)
Expand Down
19 changes: 0 additions & 19 deletions tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from tests import get_mesh_for_testing
from tests import TestCase
import sys
import time
import numpy as np

if pybamm.have_jax():
Expand Down Expand Up @@ -32,9 +31,7 @@ def test_model_solver(self):
# Solve
solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 1, 80)
t0 = time.perf_counter()
solution = solver.solve(model, t_eval)
t_first_solve = time.perf_counter() - t0
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(
solution.y[0], np.exp(0.1 * solution.t), rtol=1e-6, atol=1e-6
Expand All @@ -46,11 +43,8 @@ def test_model_solver(self):
)
self.assertEqual(solution.termination, "final time")

t0 = time.perf_counter()
second_solution = solver.solve(model, t_eval)
t_second_solve = time.perf_counter() - t0

self.assertLess(t_second_solve, t_first_solve)
np.testing.assert_array_equal(second_solution.y, solution.y)

def test_semi_explicit_model(self):
Expand All @@ -75,9 +69,7 @@ def test_semi_explicit_model(self):
# Solve
solver = pybamm.JaxSolver(method="BDF", rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 1, 80)
t0 = time.perf_counter()
solution = solver.solve(model, t_eval)
t_first_solve = time.perf_counter() - t0
np.testing.assert_array_equal(solution.t, t_eval)
soln = np.exp(0.1 * solution.t)
np.testing.assert_allclose(solution.y[0], soln, rtol=1e-7, atol=1e-7)
Expand All @@ -89,11 +81,7 @@ def test_semi_explicit_model(self):
)
self.assertEqual(solution.termination, "final time")

t0 = time.perf_counter()
second_solution = solver.solve(model, t_eval)
t_second_solve = time.perf_counter() - t0

self.assertLess(t_second_solve, t_first_solve)
np.testing.assert_array_equal(second_solution.y, solution.y)

def test_solver_sensitivities(self):
Expand Down Expand Up @@ -205,25 +193,18 @@ def test_model_solver_with_inputs(self):
# Solve
solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 80)

t0 = time.perf_counter()
solution = solver.solve(model, t_eval, inputs={"rate": 0.1})
t_first_solve = time.perf_counter() - t0

np.testing.assert_allclose(
solution.y[0], np.exp(-0.1 * solution.t), rtol=1e-6, atol=1e-6
)

t0 = time.perf_counter()
solution = solver.solve(model, t_eval, inputs={"rate": 0.2})
t_second_solve = time.perf_counter() - t0

np.testing.assert_allclose(
solution.y[0], np.exp(-0.2 * solution.t), rtol=1e-6, atol=1e-6
)

self.assertLess(t_second_solve, t_first_solve)

def test_get_solve(self):
# Create model
model = pybamm.BaseModel()
Expand Down

0 comments on commit 8512040

Please sign in to comment.