diff --git a/.github/workflows/work_precision_sets.yml b/.github/workflows/work_precision_sets.yml
index fafc5b1738..5bd8806c78 100644
--- a/.github/workflows/work_precision_sets.yml
+++ b/.github/workflows/work_precision_sets.yml
@@ -21,7 +21,7 @@ jobs:
run: python -m pip install pybamm==${{ env.VERSION }}
- name: Run time_vs_* benchmarks for PyBaMM v${{ env.VERSION }}
run: |
- python benchmarks/work_precision_sets/time_vs_dt_max.py
+ python benchmarks/work_precision_sets/time_vs_dt_event.py
python benchmarks/work_precision_sets/time_vs_mesh_size.py
python benchmarks/work_precision_sets/time_vs_no_of_states.py
python benchmarks/work_precision_sets/time_vs_reltols.py
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 11c32a0b49..734cfc969d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
- Renamed "electrode diffusivity" to "particle diffusivity" as a non-breaking change with a deprecation warning ([#3624](https://github.com/pybamm-team/PyBaMM/pull/3624))
- Add support for BPX version 0.4.0 which allows for blended electrodes and user-defined parameters in BPX([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414))
- Added the ability to specify a custom solver tolerance in `get_initial_stoichiometries` and related functions ([#3714](https://github.com/pybamm-team/PyBaMM/pull/3714))
+- Added `max_step` parameter to `BaseSolver` and passed it to dependent solvers ([#3106](https://github.com/pybamm-team/PyBaMM/pull/3106))
## Bug Fixes
diff --git a/benchmarks/release-work-precision-sets.md b/benchmarks/release-work-precision-sets.md
index 12ce444a98..b424f02366 100644
--- a/benchmarks/release-work-precision-sets.md
+++ b/benchmarks/release-work-precision-sets.md
@@ -12,9 +12,9 @@
-## Solve Time vs dt_max
+## Solve Time vs dt_event
-
+
## Solve Time vs Number of states
diff --git a/benchmarks/work_precision_sets/time_vs_dt_max.py b/benchmarks/work_precision_sets/time_vs_dt_event.py
similarity index 85%
rename from benchmarks/work_precision_sets/time_vs_dt_max.py
rename to benchmarks/work_precision_sets/time_vs_dt_event.py
index a1f8ca06bc..16068259cd 100644
--- a/benchmarks/work_precision_sets/time_vs_dt_max.py
+++ b/benchmarks/work_precision_sets/time_vs_dt_event.py
@@ -12,7 +12,7 @@
models = {"SPM": pybamm.lithium_ion.SPM(), "DFN": pybamm.lithium_ion.DFN()}
-dt_max = [
+dt_event = [
10,
20,
50,
@@ -70,8 +70,8 @@
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
- for t in dt_max:
- solver = pybamm.CasadiSolver(dt_max=t)
+ for t in dt_event:
+ solver = pybamm.CasadiSolver(dt_event=t)
solver.solve(model, t_eval=t_eval)
time = 0
@@ -85,20 +85,20 @@
ax.set_xscale("log")
ax.set_yscale("log")
- ax.set_xlabel("dt_max")
+ ax.set_xlabel("dt_event")
ax.set_ylabel("time(s)")
ax.set_title(f"{model_name}")
- ax.plot(dt_max, time_points)
+ ax.plot(dt_event, time_points)
plt.tight_layout()
plt.gca().legend(
parameters,
loc="upper right",
)
-plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_max_{pybamm.__version__}.png")
+plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_event_{pybamm.__version__}.png")
-content = f"## Solve Time vs dt_max\n\n"
+content = f"## Solve Time vs dt_event\n\n"
with open("./benchmarks/release_work_precision_sets.md") as original:
data = original.read()
diff --git a/docs/source/examples/notebooks/models/composite_particle.ipynb b/docs/source/examples/notebooks/models/composite_particle.ipynb
index 387ffd7512..c41f020e75 100644
--- a/docs/source/examples/notebooks/models/composite_particle.ipynb
+++ b/docs/source/examples/notebooks/models/composite_particle.ipynb
@@ -263,7 +263,7 @@
" sim = pybamm.Simulation(\n",
" model,\n",
" parameter_values=param,\n",
- " solver=pybamm.CasadiSolver(dt_max=5),\n",
+ " solver=pybamm.CasadiSolver(dt_event=5),\n",
" )\n",
" solution.append(sim.solve(t_eval=t_eval))\n",
"stop = timeit.default_timer()\n",
@@ -887,7 +887,7 @@
" model,\n",
" experiment=experiment,\n",
" parameter_values=param,\n",
- " solver=pybamm.CasadiSolver(dt_max=5),\n",
+ " solver=pybamm.CasadiSolver(dt_event=5),\n",
" )\n",
" solution.append(sim.solve(calc_esoh=False))\n",
"stop = timeit.default_timer()\n",
diff --git a/docs/source/examples/notebooks/models/rate-capability.ipynb b/docs/source/examples/notebooks/models/rate-capability.ipynb
index ef09a37909..43ef82aee2 100644
--- a/docs/source/examples/notebooks/models/rate-capability.ipynb
+++ b/docs/source/examples/notebooks/models/rate-capability.ipynb
@@ -100,7 +100,7 @@
" [f\"Discharge at {C_rate:.4f}C until 3.2V\"], period=f\"{10 / C_rate:.4f} seconds\"\n",
" )\n",
" sim = pybamm.Simulation(\n",
- " model, experiment=experiment, solver=pybamm.CasadiSolver(dt_max=120)\n",
+ " model, experiment=experiment, solver=pybamm.CasadiSolver(dt_event=120)\n",
" )\n",
" sim.solve()\n",
"\n",
diff --git a/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb b/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb
index ac92c06d15..96cdf9556f 100644
--- a/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb
+++ b/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb
@@ -144,7 +144,7 @@
"sim = pybamm.Simulation(\n",
" model,\n",
" parameter_values=param,\n",
- " solver=pybamm.CasadiSolver(dt_max=600),\n",
+ " solver=pybamm.CasadiSolver(dt_event=600),\n",
" var_pts=var_pts,\n",
")\n",
"solution = sim.solve(t_eval=[0, 3600], inputs={\"C-rate\": 1})\n",
diff --git a/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb b/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb
index c49c8926fb..5f2967ce30 100644
--- a/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb
+++ b/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb
@@ -366,7 +366,7 @@
"metadata": {},
"outputs": [],
"source": [
- "safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_max=30)\n",
+ "safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_event=30)\n",
"safe_sol_2 = sim.solve([0, 160], solver=safe_solver_2, inputs={\"Crate\": 10})"
]
},
@@ -374,14 +374,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Choosing dt_max to speed up the safe mode"
+ "### Choosing dt_event to speed up the safe mode"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "The parameter `dt_max` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events."
+ "The parameter `dt_event` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events."
]
},
{
@@ -393,24 +393,24 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "With dt_max=10, took 575.783 ms (integration time: 508.473 ms)\n",
- "With dt_max=20, took 575.500 ms (integration time: 510.705 ms)\n",
- "With dt_max=100, took 316.721 ms (integration time: 275.459 ms)\n",
- "With dt_max=1000, took 76.646 ms (integration time: 49.294 ms)\n",
- "With dt_max=3700, took 48.773 ms (integration time: 32.436 ms)\n",
+ "With dt_event=10, took 575.783 ms (integration time: 508.473 ms)\n",
+ "With dt_event=20, took 575.500 ms (integration time: 510.705 ms)\n",
+ "With dt_event=100, took 316.721 ms (integration time: 275.459 ms)\n",
+ "With dt_event=1000, took 76.646 ms (integration time: 49.294 ms)\n",
+ "With dt_event=3700, took 48.773 ms (integration time: 32.436 ms)\n",
"With 'fast' mode, took 42.224 ms (integration time: 32.177 ms)\n"
]
}
],
"source": [
- "for dt_max in [10, 20, 100, 1000, 3700]:\n",
+ "for dt_event in [10, 20, 100, 1000, 3700]:\n",
" safe_sol = sim.solve(\n",
" [0, 3600],\n",
- " solver=pybamm.CasadiSolver(mode=\"safe\", dt_max=dt_max),\n",
+ " solver=pybamm.CasadiSolver(mode=\"safe\", dt_event=dt_event),\n",
" inputs={\"Crate\": 1},\n",
" )\n",
" print(\n",
- " f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n",
+ " f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n",
" + f\"(integration time: {safe_sol.integration_time})\"\n",
" )\n",
"\n",
@@ -425,9 +425,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "In general, a larger value of `dt_max` gives a faster solution, since fewer integrator creations and calls are required.\n",
+ "In general, a larger value of `dt_event` gives a faster solution, since fewer integrator creations and calls are required.\n",
"\n",
- "Below the solution time interval of 36s, the value of `dt_max` does not affect the solve time, since steps must be at least 36s large.\n",
+ "Below the solution time interval of 36s, the value of `dt_event` does not affect the solve time, since steps must be at least 36s large.\n",
"The discrepancy between the solve time and integration time is due to the extra operations recorded by \"solve time\", such as creating the integrator. The \"fast\" solver does not need to do this (it reuses the first one it had already created), so the solve time is much closer to the integration time."
]
},
@@ -435,7 +435,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "The example above was a case where no events are triggered, so the largest `dt_max` works well. If we step over events, then it is possible to makes `dt_max` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:"
+ "The example above was a case where no events are triggered, so the largest `dt_event` works well. If we step over events, then it is possible to makes `dt_event` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:"
]
},
{
@@ -447,10 +447,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "With dt_max=10, took 504.163 ms (integration time: 419.740 ms)\n",
- "With dt_max=20, took 504.691 ms (integration time: 421.396 ms)\n",
- "With dt_max=100, took 286.620 ms (integration time: 238.390 ms)\n",
- "With dt_max=1000, took 98.500 ms (integration time: 60.880 ms)\n"
+ "With dt_event=10, took 504.163 ms (integration time: 419.740 ms)\n",
+ "With dt_event=20, took 504.691 ms (integration time: 421.396 ms)\n",
+ "With dt_event=100, took 286.620 ms (integration time: 238.390 ms)\n",
+ "With dt_event=1000, took 98.500 ms (integration time: 60.880 ms)\n"
]
},
{
@@ -466,22 +466,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "With dt_max=3600, took 645.118 ms (integration time: 32.601 ms)\n"
+ "With dt_event=3600, took 645.118 ms (integration time: 32.601 ms)\n"
]
}
],
"source": [
- "for dt_max in [10, 20, 100, 1000, 3600]:\n",
+ "for dt_event in [10, 20, 100, 1000, 3600]:\n",
" # Reduce max_num_steps to fail faster\n",
" safe_sol = sim.solve(\n",
" [0, 4500],\n",
" solver=pybamm.CasadiSolver(\n",
- " mode=\"safe\", dt_max=dt_max, extra_options_setup={\"max_num_steps\": 1000}\n",
+ " mode=\"safe\", dt_event=dt_event, extra_options_setup={\"max_num_steps\": 1000}\n",
" ),\n",
" inputs={\"Crate\": 1},\n",
" )\n",
" print(\n",
- " f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n",
+ " f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n",
" + f\"(integration time: {safe_sol.integration_time})\"\n",
" )"
]
@@ -490,7 +490,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "The integration time with `dt_max=3600` remains the fastest, but the solve time is the slowest due to all the failed steps."
+ "The integration time with `dt_event=3600` remains the fastest, but the solve time is the slowest due to all the failed steps."
]
},
{
@@ -504,7 +504,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_max`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve."
+ "The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_event`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve."
]
},
{
diff --git a/pybamm/__init__.py b/pybamm/__init__.py
index ab2e72ed28..add419f94c 100644
--- a/pybamm/__init__.py
+++ b/pybamm/__init__.py
@@ -219,6 +219,7 @@
from .solvers.jax_solver import JaxSolver
from .solvers.jax_bdf_solver import jax_bdf_integrate
+from .solvers.base_solver import validate_max_step
from .solvers.idaklu_jax import IDAKLUJax
from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu
diff --git a/pybamm/solvers/algebraic_solver.py b/pybamm/solvers/algebraic_solver.py
index d241d5b24c..e3ff430a61 100644
--- a/pybamm/solvers/algebraic_solver.py
+++ b/pybamm/solvers/algebraic_solver.py
@@ -6,6 +6,7 @@
import numpy as np
from scipy import optimize
from scipy.sparse import issparse
+from .base_solver import validate_max_step
class AlgebraicSolver(pybamm.BaseSolver):
@@ -24,14 +25,18 @@ class AlgebraicSolver(pybamm.BaseSolver):
specified in the form "lsq_methodname"
tol : float, optional
The tolerance for the solver (default is 1e-6).
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the rootfinder. Vary depending on which method is chosen.
Please consult `SciPy documentation `_ for
details.
"""
- def __init__(self, method="lm", tol=1e-6, extra_options=None):
+ def __init__(self, method="lm", tol=1e-6, max_step=np.inf, extra_options=None):
super().__init__(method=method)
+ self.max_step = validate_max_step(max_step)
self.tol = tol
self.extra_options = extra_options or {}
self.name = f"Algebraic solver ({method})"
diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py
index 69de3be968..7917dc4692 100644
--- a/pybamm/solvers/base_solver.py
+++ b/pybamm/solvers/base_solver.py
@@ -16,6 +16,13 @@
from pybamm.expression_tree.binary_operators import _Heaviside
+def validate_max_step(max_step):
+ """Assert that max_step is valid and return it."""
+ if max_step <= 0:
+ raise ValueError("`max_step` must be positive.")
+ return max_step
+
+
class BaseSolver:
"""Solve a discretised model.
@@ -38,6 +45,9 @@ class BaseSolver:
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
output_variables : list[str], optional
List of variables to calculate and return. If none are specified then
the complete state vector is returned (can be very large) (default is [])
@@ -51,6 +61,7 @@ def __init__(
root_method=None,
root_tol=1e-6,
extrap_tol=None,
+ max_step=np.inf,
output_variables=[],
):
self.method = method
@@ -59,6 +70,7 @@ def __init__(
self.root_tol = root_tol
self.root_method = root_method
self.extrap_tol = extrap_tol or -1e-10
+ self.max_step = validate_max_step(max_step)
self.output_variables = output_variables
self._model_set_up = {}
diff --git a/pybamm/solvers/casadi_algebraic_solver.py b/pybamm/solvers/casadi_algebraic_solver.py
index ec7305906a..72808e0486 100644
--- a/pybamm/solvers/casadi_algebraic_solver.py
+++ b/pybamm/solvers/casadi_algebraic_solver.py
@@ -4,6 +4,7 @@
import casadi
import pybamm
import numpy as np
+from .base_solver import validate_max_step
class CasadiAlgebraicSolver(pybamm.BaseSolver):
@@ -17,6 +18,9 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver):
----------
tol : float, optional
The tolerance for the solver (default is 1e-6).
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the CasADi rootfinder.
Please consult `CasADi documentation `_ for
@@ -24,10 +28,11 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver):
"""
- def __init__(self, tol=1e-6, extra_options=None):
+ def __init__(self, tol=1e-6, max_step=np.inf, extra_options=None):
super().__init__()
self.tol = tol
self.name = "CasADi algebraic solver"
+ self.max_step = validate_max_step(max_step)
self.algebraic_solver = True
self.extra_options = extra_options or {}
pybamm.citations.register("Andersson2019")
diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py
index 02ff4a2cd9..d68152dc3c 100644
--- a/pybamm/solvers/casadi_solver.py
+++ b/pybamm/solvers/casadi_solver.py
@@ -7,6 +7,7 @@
import warnings
from scipy.interpolate import interp1d
from .lrudict import LRUDict
+from .base_solver import validate_max_step
class CasadiSolver(pybamm.BaseSolver):
@@ -23,7 +24,7 @@ class CasadiSolver(pybamm.BaseSolver):
- "fast with events": perform direct integration of the whole timespan, \
then go back and check where events were crossed. Experimental only.
- "safe": perform step-and-check integration in global steps of size \
- dt_max, checking whether events have been triggered. Recommended for \
+ dt_event, checking whether events have been triggered. Recommended for \
simulations of a full charge or discharge.
- "safe without grid": perform step-and-check integration step-by-step. \
Takes more steps than "safe" mode, but doesn't require creating the grid \
@@ -41,10 +42,13 @@ class CasadiSolver(pybamm.BaseSolver):
specified by 'root_method' (e.g. "lm", "hybr", ...)
root_tol : float, optional
The tolerance for root-finding. Default is 1e-6.
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
max_step_decrease_count : float, optional
The maximum number of times step size can be decreased before an error is
raised. Default is 5.
- dt_max : float, optional
+ dt_event : float, optional
The maximum global step size (in seconds) used in "safe" mode. If None
the default value is 600 seconds.
extrap_tol : float, optional
@@ -81,8 +85,9 @@ def __init__(
atol=1e-6,
root_method="casadi",
root_tol=1e-6,
+ max_step=np.inf,
max_step_decrease_count=5,
- dt_max=None,
+ dt_event=None,
extrap_tol=None,
extra_options_setup=None,
extra_options_call=None,
@@ -97,6 +102,7 @@ def __init__(
root_method,
root_tol,
extrap_tol,
+ max_step,
)
if mode in ["safe", "fast", "fast with events", "safe without grid"]:
self.mode = mode
@@ -106,8 +112,9 @@ def __init__(
"'fast', for solving quickly without events, or 'safe without grid' or "
"'fast with events' (both experimental)"
)
+ self.max_step = validate_max_step(max_step)
self.max_step_decrease_count = max_step_decrease_count
- self.dt_max = dt_max or 600
+ self.dt_event = dt_event or 600
self.extra_options_setup = extra_options_setup or {}
self.extra_options_call = extra_options_call or {}
@@ -205,24 +212,24 @@ def _integrate(self, model, t_eval, inputs_dict=None):
solution = None
use_grid = True
- # Try to integrate in global steps of size dt_max. Note: dt_max must
+ # Try to integrate in global steps of size dt_event. Note: dt_event must
# be at least as big as the the biggest step in t_eval (multiplied
# by some tolerance, here 1.01) to avoid an empty integration window below
- dt_max = self.dt_max
+ dt_event = self.dt_event
dt_eval_max = np.max(np.diff(t_eval)) * 1.01
- if dt_max < dt_eval_max:
+ if dt_event < dt_eval_max:
pybamm.logger.debug(
- "Setting dt_max to be as big as the largest step in "
+ "Setting dt_event to be as big as the largest step in "
f"t_eval ({dt_eval_max})"
)
- dt_max = dt_eval_max
+ dt_event = dt_eval_max
termination_due_to_small_dt = False
first_ts_solved = False
while t < t_f:
# Step
solved = False
count = 0
- dt = dt_max
+ dt = dt_event
while not solved:
# Get window of time to integrate over (so that we return
# all the points in t_eval, not just t and t+dt)
@@ -266,13 +273,13 @@ def _integrate(self, model, t_eval, inputs_dict=None):
# needed, but this won't affect the global timesteps. The
# global timestep will only be reduced after the first timestep.
if first_ts_solved:
- dt_max = dt
+ dt_event = dt
if count > self.max_step_decrease_count:
message = (
"Maximum number of decreased steps occurred at "
f"t={t} (final SolverError: '{error}'). "
- "For a full solution try reducing dt_max (currently, "
- f"dt_max={dt_max}) and/or reducing the size of the "
+ "For a full solution try reducing dt_event (currently, "
+ f"dt_event={dt_event}) and/or reducing the size of the "
"time steps or period of the experiment."
)
if first_ts_solved and self.return_solution_if_failed_early:
diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py
index e9976fc28c..872e50d300 100644
--- a/pybamm/solvers/idaklu_solver.py
+++ b/pybamm/solvers/idaklu_solver.py
@@ -7,6 +7,7 @@
import numpy as np
import numbers
import scipy.sparse as sparse
+from .base_solver import validate_max_step
import importlib
@@ -45,6 +46,9 @@ class IDAKLUSolver(pybamm.BaseSolver):
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not (default is 0).
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
output_variables : list[str], optional
List of variables to calculate and return. If none are specified then
the complete state vector is returned (can be very large) (default is [])
@@ -89,6 +93,7 @@ def __init__(
root_method="casadi",
root_tol=1e-6,
extrap_tol=None,
+ max_step=np.inf,
output_variables=[],
options=None,
):
@@ -112,6 +117,8 @@ def __init__(
options[key] = value
self._options = options
+ self.max_step = validate_max_step(max_step)
+
self.output_variables = output_variables
if idaklu_spec is None: # pragma: no cover
@@ -124,6 +131,7 @@ def __init__(
root_method,
root_tol,
extrap_tol,
+ max_step,
output_variables,
)
self.name = "IDA KLU solver"
diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py
index 6c89bed4dd..cbaec16b42 100644
--- a/pybamm/solvers/jax_solver.py
+++ b/pybamm/solvers/jax_solver.py
@@ -5,6 +5,7 @@
import asyncio
import pybamm
+from .base_solver import validate_max_step
if pybamm.have_jax():
import jax
@@ -43,6 +44,9 @@ class JaxSolver(pybamm.BaseSolver):
The absolute tolerance for the solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not (default is 0).
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the solver.
Please consult `JAX documentation
@@ -57,6 +61,7 @@ def __init__(
rtol=1e-6,
atol=1e-6,
extrap_tol=None,
+ max_step=onp.inf,
extra_options=None,
):
if not pybamm.have_jax():
@@ -67,7 +72,12 @@ def __init__(
# note: bdf solver itself calculates consistent initial conditions so can set
# root_method to none, allow user to override this behavior
super().__init__(
- method, rtol, atol, root_method=root_method, extrap_tol=extrap_tol
+ method,
+ rtol,
+ atol,
+ root_method=root_method,
+ extrap_tol=extrap_tol,
+ max_step=max_step,
)
method_options = ["RK45", "BDF"]
if method not in method_options:
@@ -77,6 +87,7 @@ def __init__(
self.ode_solver = True
self.extra_options = extra_options or {}
self.name = f"JAX solver ({method})"
+ self.max_step = validate_max_step(max_step)
self._cached_solves = dict()
pybamm.citations.register("jax2018")
diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py
index c942e8ccd7..8a9eebaf58 100644
--- a/pybamm/solvers/scikits_dae_solver.py
+++ b/pybamm/solvers/scikits_dae_solver.py
@@ -9,6 +9,8 @@
import importlib
import scipy.sparse as sparse
+from .base_solver import validate_max_step
+
scikits_odes_spec = importlib.util.find_spec("scikits")
if scikits_odes_spec is not None:
scikits_odes_spec = importlib.util.find_spec("scikits.odes")
@@ -39,6 +41,9 @@ class ScikitsDaeSolver(pybamm.BaseSolver):
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not (default is 0).
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the solver.
Please consult `scikits.odes documentation
@@ -56,15 +61,19 @@ def __init__(
root_method="casadi",
root_tol=1e-6,
extrap_tol=None,
+ max_step=np.inf,
extra_options=None,
):
if scikits_odes_spec is None:
raise ImportError("scikits.odes is not installed")
- super().__init__(method, rtol, atol, root_method, root_tol, extrap_tol)
+ super().__init__(
+ method, rtol, atol, root_method, root_tol, extrap_tol, max_step
+ )
self.name = f"Scikits DAE solver ({method})"
self.extra_options = extra_options or {}
+ self.max_step = validate_max_step(max_step)
pybamm.citations.register("Malengier2018")
pybamm.citations.register("Hindmarsh2000")
diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py
index f3a4232da9..83d1c78312 100644
--- a/pybamm/solvers/scikits_ode_solver.py
+++ b/pybamm/solvers/scikits_ode_solver.py
@@ -9,6 +9,8 @@
import importlib
import scipy.sparse as sparse
+from .base_solver import validate_max_step
+
scikits_odes_spec = importlib.util.find_spec("scikits")
if scikits_odes_spec is not None:
scikits_odes_spec = importlib.util.find_spec("scikits.odes")
@@ -34,6 +36,9 @@ class ScikitsOdeSolver(pybamm.BaseSolver):
The absolute tolerance for the solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not (default is 0).
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the solver.
Please consult `scikits.odes documentation
@@ -50,15 +55,17 @@ def __init__(
rtol=1e-6,
atol=1e-6,
extrap_tol=None,
+ max_step=np.inf,
extra_options=None,
):
if scikits_odes_spec is None: # pragma: no cover
raise ImportError("scikits.odes is not installed")
- super().__init__(method, rtol, atol, extrap_tol=extrap_tol)
+ super().__init__(method, rtol, atol, extrap_tol=extrap_tol, max_step=max_step)
self.extra_options = extra_options or {}
self.ode_solver = True
self.name = f"Scikits ODE solver ({method})"
+ self.max_step = validate_max_step(max_step)
pybamm.citations.register("Malengier2018")
pybamm.citations.register("Hindmarsh2000")
diff --git a/pybamm/solvers/scipy_solver.py b/pybamm/solvers/scipy_solver.py
index fb320f558d..61b7c459c7 100644
--- a/pybamm/solvers/scipy_solver.py
+++ b/pybamm/solvers/scipy_solver.py
@@ -7,6 +7,8 @@
import scipy.integrate as it
import numpy as np
+from .base_solver import validate_max_step
+
class ScipySolver(pybamm.BaseSolver):
"""Solve a discretised model, using scipy.integrate.solve_ivp.
@@ -21,6 +23,9 @@ class ScipySolver(pybamm.BaseSolver):
The absolute tolerance for the solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not (default is 0).
+ max_step : float, optional
+ Maximum allowed step size. Default is np.inf, i.e., the step size is not
+ bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the solver.
Please consult `SciPy documentation `_ for
@@ -33,6 +38,7 @@ def __init__(
rtol=1e-6,
atol=1e-6,
extrap_tol=None,
+ max_step=np.inf,
extra_options=None,
):
super().__init__(
@@ -40,10 +46,12 @@ def __init__(
rtol=rtol,
atol=atol,
extrap_tol=extrap_tol,
+ max_step=max_step,
)
self.ode_solver = True
self.extra_options = extra_options or {}
self.name = f"Scipy solver ({method})"
+ self.max_step = validate_max_step(max_step)
pybamm.citations.register("Virtanen2020")
def _integrate(self, model, t_eval, inputs_dict=None):
diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py
index 577e50e68b..cd2e587923 100644
--- a/tests/unit/test_solvers/test_base_solver.py
+++ b/tests/unit/test_solvers/test_base_solver.py
@@ -385,6 +385,10 @@ def exact_diff_b(y, a, b):
sens_b, exact_diff_b(y, inputs["a"], inputs["b"])
)
+ def test_validate_max_step(self):
+ with self.assertRaisesRegex(ValueError, "`max_step` must be positive."):
+ pybamm.validate_max_step(-1)
+
if __name__ == "__main__":
print("Add -v for more debug output")
diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py
index 3030f80af0..dca550ab94 100644
--- a/tests/unit/test_solvers/test_casadi_solver.py
+++ b/tests/unit/test_solvers/test_casadi_solver.py
@@ -141,7 +141,9 @@ def test_model_solver_failure(self):
disc = pybamm.Discretisation()
model_disc = disc.process_model(model, inplace=False)
solver = pybamm.CasadiSolver(
- dt_max=1e-3, return_solution_if_failed_early=True, max_step_decrease_count=2
+ dt_event=1e-3,
+ return_solution_if_failed_early=True,
+ max_step_decrease_count=2,
)
# Solve with failure at t=2
# Solution fails early but manages to take some steps so we return it anyway
@@ -152,7 +154,9 @@ def test_model_solver_failure(self):
self.assertLess(solution.t[-1], 20)
# Solve with failure at t=0
solver = pybamm.CasadiSolver(
- dt_max=1e-3, return_solution_if_failed_early=True, max_step_decrease_count=2
+ dt_event=1e-3,
+ return_solution_if_failed_early=True,
+ max_step_decrease_count=2,
)
model.initial_conditions = {var: 0, var2: 1}
model_disc = disc.process_model(model, inplace=False)
@@ -195,7 +199,7 @@ def test_model_solver_events(self):
# Solve using "safe" mode with debug off
pybamm.settings.debug_mode = False
- solver = pybamm.CasadiSolver(mode="safe", rtol=1e-8, atol=1e-8, dt_max=1)
+ solver = pybamm.CasadiSolver(mode="safe", rtol=1e-8, atol=1e-8, dt_event=1)
t_eval = np.linspace(0, 5, 100)
solution = solver.solve(model, t_eval)
np.testing.assert_array_less(solution.y.full()[0], 1.5)
@@ -210,8 +214,8 @@ def test_model_solver_events(self):
)
pybamm.settings.debug_mode = True
- # Try dt_max=0 to enforce using all timesteps
- solver = pybamm.CasadiSolver(dt_max=0, rtol=1e-8, atol=1e-8)
+ # Try dt_event=0 to enforce using all timesteps
+ solver = pybamm.CasadiSolver(dt_event=0, rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 100)
solution = solver.solve(model, t_eval)
np.testing.assert_array_less(solution.y.full()[0], 1.5)