Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove time check from Jax tests to improve stability #3963

Merged
merged 2 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

## Breaking changes

- Removed support for Python 3.8 ([#3961](https://github.com/pybamm-team/PyBaMM/pull/3961))
- Renamed "ocp_soc_0_dimensional" to "ocp_soc_0" and "ocp_soc_100_dimensional" to "ocp_soc_100" ([#3942](https://github.com/pybamm-team/PyBaMM/pull/3942))
- The ODES solver was removed due to compatibility issues. Users should use IDAKLU, Casadi, or JAX instead. ([#3932](https://github.com/pybamm-team/PyBaMM/pull/3932))
- Integrated the `[pandas]` extra into the core PyBaMM package, deprecating the `pybamm[pandas]` optional dependency. Pandas is now a required dependency and will be installed upon installing PyBaMM ([#3892](https://github.com/pybamm-team/PyBaMM/pull/3892))
Expand Down
15 changes: 0 additions & 15 deletions tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from tests import get_mesh_for_testing
from tests import TestCase
import sys
import time
import numpy as np

if pybamm.have_jax():
Expand Down Expand Up @@ -36,19 +35,12 @@ def test_solver_(self): # Trailing _ manipulates the random seed
def fun(y, t):
return rhs(t=t, y=y).reshape(-1)

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8)
t1 = time.perf_counter() - t0

# test accuracy
np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-6, atol=1e-6)

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8)
t2 = time.perf_counter() - t0

# second run should be much quicker
self.assertLess(t2, t1)

# test second run is accurate
np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-6, atol=1e-6)
Expand All @@ -66,21 +58,14 @@ def fun(y, t):
# this as a guess
y0 = jax.numpy.array([1.0, 1.5])

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8)
t1 = time.perf_counter() - t0

# test accuracy
soln = np.exp(0.05 * t_eval)
np.testing.assert_allclose(y[:, 0], soln, rtol=1e-7, atol=1e-7)
np.testing.assert_allclose(y[:, 1], 2.0 * soln, rtol=1e-7, atol=1e-7)

t0 = time.perf_counter()
y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8)
t2 = time.perf_counter() - t0

# second run should be much quicker
self.assertLess(t2, t1)

# test second run is accurate
np.testing.assert_allclose(y[:, 0], np.exp(0.05 * t_eval), rtol=1e-7, atol=1e-7)
Expand Down
19 changes: 0 additions & 19 deletions tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from tests import get_mesh_for_testing
from tests import TestCase
import sys
import time
import numpy as np

if pybamm.have_jax():
Expand Down Expand Up @@ -32,9 +31,7 @@ def test_model_solver(self):
# Solve
solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 1, 80)
t0 = time.perf_counter()
solution = solver.solve(model, t_eval)
t_first_solve = time.perf_counter() - t0
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(
solution.y[0], np.exp(0.1 * solution.t), rtol=1e-6, atol=1e-6
Expand All @@ -46,11 +43,8 @@ def test_model_solver(self):
)
self.assertEqual(solution.termination, "final time")

t0 = time.perf_counter()
second_solution = solver.solve(model, t_eval)
t_second_solve = time.perf_counter() - t0

self.assertLess(t_second_solve, t_first_solve)
np.testing.assert_array_equal(second_solution.y, solution.y)

def test_semi_explicit_model(self):
Expand All @@ -75,9 +69,7 @@ def test_semi_explicit_model(self):
# Solve
solver = pybamm.JaxSolver(method="BDF", rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 1, 80)
t0 = time.perf_counter()
solution = solver.solve(model, t_eval)
t_first_solve = time.perf_counter() - t0
np.testing.assert_array_equal(solution.t, t_eval)
soln = np.exp(0.1 * solution.t)
np.testing.assert_allclose(solution.y[0], soln, rtol=1e-7, atol=1e-7)
Expand All @@ -89,11 +81,7 @@ def test_semi_explicit_model(self):
)
self.assertEqual(solution.termination, "final time")

t0 = time.perf_counter()
second_solution = solver.solve(model, t_eval)
t_second_solve = time.perf_counter() - t0

self.assertLess(t_second_solve, t_first_solve)
np.testing.assert_array_equal(second_solution.y, solution.y)

def test_solver_sensitivities(self):
Expand Down Expand Up @@ -205,25 +193,18 @@ def test_model_solver_with_inputs(self):
# Solve
solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 80)

t0 = time.perf_counter()
solution = solver.solve(model, t_eval, inputs={"rate": 0.1})
t_first_solve = time.perf_counter() - t0

np.testing.assert_allclose(
solution.y[0], np.exp(-0.1 * solution.t), rtol=1e-6, atol=1e-6
)

t0 = time.perf_counter()
solution = solver.solve(model, t_eval, inputs={"rate": 0.2})
t_second_solve = time.perf_counter() - t0

np.testing.assert_allclose(
solution.y[0], np.exp(-0.2 * solution.t), rtol=1e-6, atol=1e-6
)

self.assertLess(t_second_solve, t_first_solve)

def test_get_solve(self):
# Create model
model = pybamm.BaseModel()
Expand Down
Loading