From c5f8b8a53048c1ee38d18c054b639d27d6775cdd Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 5 Jul 2023 04:23:34 +0530 Subject: [PATCH 01/14] Add `max_step` arg in `basesolver` --- pybamm/solvers/base_solver.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 7740006310..750894b72b 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -48,6 +48,7 @@ def __init__( root_method=None, root_tol=1e-6, extrap_tol=None, + max_step=None ): self.method = method self.rtol = rtol @@ -55,6 +56,7 @@ def __init__( self.root_tol = root_tol self.root_method = root_method self.extrap_tol = extrap_tol or -1e-10 + self.max_step = max_step self._model_set_up = {} # Defaults, can be overwritten by specific solver From 9567e786afe1ef9501e284b2530029e7daf25e5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Jul 2023 22:54:57 +0000 Subject: [PATCH 02/14] style: pre-commit fixes --- pybamm/solvers/base_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 750894b72b..c8f58105a5 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -48,7 +48,7 @@ def __init__( root_method=None, root_tol=1e-6, extrap_tol=None, - max_step=None + max_step=None, ): self.method = method self.rtol = rtol From 1d706908061f8e8333acd55c958a44a5fdfa87b4 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Fri, 16 Feb 2024 11:54:57 +0530 Subject: [PATCH 03/14] Add docstring & `validate_max_step` --- pybamm/solvers/base_solver.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index b9a0285990..93d8d7e6a9 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -16,6 +16,13 @@ from pybamm.expression_tree.binary_operators import _Heaviside +def validate_max_step(max_step): + """Assert that max_Step is valid and return it.""" + if max_step <= 0: + raise ValueError("`max_step` must be positive.") + return max_step + + class BaseSolver: """Solve a discretised model. @@ -38,6 +45,9 @@ class BaseSolver: The tolerance for the initial-condition solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not. Default is 0. + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. output_variables : list[str], optional List of variables to calculate and return. If none are specified then the complete state vector is returned (can be very large) (default is []) @@ -51,7 +61,7 @@ def __init__( root_method=None, root_tol=1e-6, extrap_tol=None, - max_step=None, + max_step=np.inf, output_variables=[], ): self.method = method @@ -60,7 +70,7 @@ def __init__( self.root_tol = root_tol self.root_method = root_method self.extrap_tol = extrap_tol or -1e-10 - self.max_step = max_step + self.max_step = validate_max_step(max_step) self.output_variables = output_variables self._model_set_up = {} From 433cafbe9f748262b104911921c39428fc69b861 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 19 Feb 2024 14:37:26 +0530 Subject: [PATCH 04/14] Pass `max_step` to `idaklu_solver` --- pybamm/solvers/base_solver.py | 2 +- pybamm/solvers/idaklu_solver.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 93d8d7e6a9..7917dc4692 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -17,7 +17,7 @@ def validate_max_step(max_step): - """Assert that max_Step is valid and return it.""" + """Assert that max_step is valid and return it.""" if max_step <= 0: raise ValueError("`max_step` must be positive.") return max_step diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 6c81bf91e7..6cde765d58 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -43,6 +43,9 @@ class IDAKLUSolver(pybamm.BaseSolver): The tolerance for the initial-condition solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not (default is 0). + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. output_variables : list[str], optional List of variables to calculate and return. If none are specified then the complete state vector is returned (can be very large) (default is []) @@ -87,6 +90,7 @@ def __init__( root_method="casadi", root_tol=1e-6, extrap_tol=None, + max_step=np.inf, output_variables=[], options=None, ): @@ -122,6 +126,7 @@ def __init__( root_method, root_tol, extrap_tol, + max_step, output_variables, ) self.name = "IDA KLU solver" From 71c9007cf485e4fd0d35d53a44352f447a936896 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 19 Feb 2024 15:17:47 +0530 Subject: [PATCH 05/14] Pass `max_step` to `algebric_solver` --- pybamm/solvers/algebraic_solver.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pybamm/solvers/algebraic_solver.py b/pybamm/solvers/algebraic_solver.py index d241d5b24c..e1da33ab5b 100644 --- a/pybamm/solvers/algebraic_solver.py +++ b/pybamm/solvers/algebraic_solver.py @@ -6,6 +6,7 @@ import numpy as np from scipy import optimize from scipy.sparse import issparse +from base_solver import validate_max_step class AlgebraicSolver(pybamm.BaseSolver): @@ -24,14 +25,18 @@ class AlgebraicSolver(pybamm.BaseSolver): specified in the form "lsq_methodname" tol : float, optional The tolerance for the solver (default is 1e-6). + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the rootfinder. Vary depending on which method is chosen. Please consult `SciPy documentation `_ for details. """ - def __init__(self, method="lm", tol=1e-6, extra_options=None): + def __init__(self, method="lm", tol=1e-6, max_step=np.inf, extra_options=None): super().__init__(method=method) + self.max_step = validate_max_step(max_step) self.tol = tol self.extra_options = extra_options or {} self.name = f"Algebraic solver ({method})" From 87f5593653e613a780eb882ad71e9492979a6941 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 19 Feb 2024 15:18:46 +0530 Subject: [PATCH 06/14] Pass `max_step` to `casadi_algebric_solver` --- pybamm/solvers/casadi_algebraic_solver.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pybamm/solvers/casadi_algebraic_solver.py b/pybamm/solvers/casadi_algebraic_solver.py index ec7305906a..54a61baedd 100644 --- a/pybamm/solvers/casadi_algebraic_solver.py +++ b/pybamm/solvers/casadi_algebraic_solver.py @@ -4,6 +4,7 @@ import casadi import pybamm import numpy as np +from base_solver import validate_max_step class CasadiAlgebraicSolver(pybamm.BaseSolver): @@ -17,6 +18,9 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver): ---------- tol : float, optional The tolerance for the solver (default is 1e-6). + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the CasADi rootfinder. Please consult `CasADi documentation `_ for @@ -24,10 +28,11 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver): """ - def __init__(self, tol=1e-6, extra_options=None): + def __init__(self, tol=1e-6, max_step=np.inf, extra_options=None): super().__init__() self.tol = tol self.name = "CasADi algebraic solver" + self.max_step = validate_max_step(max_step) self.algebraic_solver = True self.extra_options = extra_options or {} pybamm.citations.register("Andersson2019") From f5615b2f5b802e2fac1cd158ef23a1810db70f55 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 19 Feb 2024 15:28:14 +0530 Subject: [PATCH 07/14] Pass `max_step` to `casadi_solver` --- pybamm/solvers/casadi_solver.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index 02ff4a2cd9..5c054270a9 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -7,6 +7,7 @@ import warnings from scipy.interpolate import interp1d from .lrudict import LRUDict +from base_solver import validate_max_step class CasadiSolver(pybamm.BaseSolver): @@ -41,6 +42,9 @@ class CasadiSolver(pybamm.BaseSolver): specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for root-finding. Default is 1e-6. + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. max_step_decrease_count : float, optional The maximum number of times step size can be decreased before an error is raised. Default is 5. @@ -81,6 +85,7 @@ def __init__( atol=1e-6, root_method="casadi", root_tol=1e-6, + max_step=np.inf, max_step_decrease_count=5, dt_max=None, extrap_tol=None, @@ -96,6 +101,7 @@ def __init__( atol, root_method, root_tol, + max_step, extrap_tol, ) if mode in ["safe", "fast", "fast with events", "safe without grid"]: @@ -106,6 +112,7 @@ def __init__( "'fast', for solving quickly without events, or 'safe without grid' or " "'fast with events' (both experimental)" ) + self.max_step = validate_max_step(max_step) self.max_step_decrease_count = max_step_decrease_count self.dt_max = dt_max or 600 From 6cff7ae4ff5d63afbb2e5f446bc249f118703067 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 19 Feb 2024 15:49:03 +0530 Subject: [PATCH 08/14] Validate import --- pybamm/solvers/algebraic_solver.py | 2 +- pybamm/solvers/casadi_algebraic_solver.py | 2 +- pybamm/solvers/casadi_solver.py | 2 +- pybamm/solvers/idaklu_solver.py | 3 +++ 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pybamm/solvers/algebraic_solver.py b/pybamm/solvers/algebraic_solver.py index e1da33ab5b..e3ff430a61 100644 --- a/pybamm/solvers/algebraic_solver.py +++ b/pybamm/solvers/algebraic_solver.py @@ -6,7 +6,7 @@ import numpy as np from scipy import optimize from scipy.sparse import issparse -from base_solver import validate_max_step +from .base_solver import validate_max_step class AlgebraicSolver(pybamm.BaseSolver): diff --git a/pybamm/solvers/casadi_algebraic_solver.py b/pybamm/solvers/casadi_algebraic_solver.py index 54a61baedd..72808e0486 100644 --- a/pybamm/solvers/casadi_algebraic_solver.py +++ b/pybamm/solvers/casadi_algebraic_solver.py @@ -4,7 +4,7 @@ import casadi import pybamm import numpy as np -from base_solver import validate_max_step +from .base_solver import validate_max_step class CasadiAlgebraicSolver(pybamm.BaseSolver): diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index 5c054270a9..dfebdbb99d 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -7,7 +7,7 @@ import warnings from scipy.interpolate import interp1d from .lrudict import LRUDict -from base_solver import validate_max_step +from .base_solver import validate_max_step class CasadiSolver(pybamm.BaseSolver): diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 6cde765d58..2d2a52b439 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -6,6 +6,7 @@ import numpy as np import numbers import scipy.sparse as sparse +from .base_solver import validate_max_step import importlib @@ -114,6 +115,8 @@ def __init__( options[key] = value self._options = options + self.max_step = validate_max_step(max_step) + self.output_variables = output_variables if idaklu_spec is None: # pragma: no cover From f80bee74c76206cd713b7a7b02bff324acf27d4f Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 19 Feb 2024 19:36:24 +0530 Subject: [PATCH 09/14] Fix casadi solver passage & pass `max_step` to `jax_solver` --- pybamm/solvers/casadi_solver.py | 2 +- pybamm/solvers/jax_solver.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index dfebdbb99d..0cef290d8b 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -101,8 +101,8 @@ def __init__( atol, root_method, root_tol, - max_step, extrap_tol, + max_step, ) if mode in ["safe", "fast", "fast with events", "safe without grid"]: self.mode = mode diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index 6c89bed4dd..cbaec16b42 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -5,6 +5,7 @@ import asyncio import pybamm +from .base_solver import validate_max_step if pybamm.have_jax(): import jax @@ -43,6 +44,9 @@ class JaxSolver(pybamm.BaseSolver): The absolute tolerance for the solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not (default is 0). + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the solver. Please consult `JAX documentation @@ -57,6 +61,7 @@ def __init__( rtol=1e-6, atol=1e-6, extrap_tol=None, + max_step=onp.inf, extra_options=None, ): if not pybamm.have_jax(): @@ -67,7 +72,12 @@ def __init__( # note: bdf solver itself calculates consistent initial conditions so can set # root_method to none, allow user to override this behavior super().__init__( - method, rtol, atol, root_method=root_method, extrap_tol=extrap_tol + method, + rtol, + atol, + root_method=root_method, + extrap_tol=extrap_tol, + max_step=max_step, ) method_options = ["RK45", "BDF"] if method not in method_options: @@ -77,6 +87,7 @@ def __init__( self.ode_solver = True self.extra_options = extra_options or {} self.name = f"JAX solver ({method})" + self.max_step = validate_max_step(max_step) self._cached_solves = dict() pybamm.citations.register("jax2018") From 0cc6fb1664dc466d3cab4380cccf5bf471e2941c Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Tue, 20 Feb 2024 14:22:07 +0530 Subject: [PATCH 10/14] Pass `max_step` to `scikits_ode_solver`, `scikits_dae_solver` & `scipy_solver` --- pybamm/solvers/scikits_dae_solver.py | 11 ++++++++++- pybamm/solvers/scikits_ode_solver.py | 9 ++++++++- pybamm/solvers/scipy_solver.py | 8 ++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py index a5bf1e5a4f..571412cd31 100644 --- a/pybamm/solvers/scikits_dae_solver.py +++ b/pybamm/solvers/scikits_dae_solver.py @@ -8,6 +8,8 @@ import importlib import scipy.sparse as sparse +from .base_solver import validate_max_step + scikits_odes_spec = importlib.util.find_spec("scikits") if scikits_odes_spec is not None: scikits_odes_spec = importlib.util.find_spec("scikits.odes") @@ -38,6 +40,9 @@ class ScikitsDaeSolver(pybamm.BaseSolver): The tolerance for the initial-condition solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not (default is 0). + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the solver. Please consult `scikits.odes documentation @@ -55,15 +60,19 @@ def __init__( root_method="casadi", root_tol=1e-6, extrap_tol=None, + max_step=np.inf, extra_options=None, ): if scikits_odes_spec is None: raise ImportError("scikits.odes is not installed") - super().__init__(method, rtol, atol, root_method, root_tol, extrap_tol) + super().__init__( + method, rtol, atol, root_method, root_tol, extrap_tol, max_step + ) self.name = f"Scikits DAE solver ({method})" self.extra_options = extra_options or {} + self.max_step = validate_max_step(max_step) pybamm.citations.register("Malengier2018") pybamm.citations.register("Hindmarsh2000") diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py index 9f5ee67604..d7282afa63 100644 --- a/pybamm/solvers/scikits_ode_solver.py +++ b/pybamm/solvers/scikits_ode_solver.py @@ -8,6 +8,8 @@ import importlib import scipy.sparse as sparse +from .base_solver import validate_max_step + scikits_odes_spec = importlib.util.find_spec("scikits") if scikits_odes_spec is not None: scikits_odes_spec = importlib.util.find_spec("scikits.odes") @@ -33,6 +35,9 @@ class ScikitsOdeSolver(pybamm.BaseSolver): The absolute tolerance for the solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not (default is 0). + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the solver. Please consult `scikits.odes documentation @@ -49,15 +54,17 @@ def __init__( rtol=1e-6, atol=1e-6, extrap_tol=None, + max_step=np.inf, extra_options=None, ): if scikits_odes_spec is None: # pragma: no cover raise ImportError("scikits.odes is not installed") - super().__init__(method, rtol, atol, extrap_tol=extrap_tol) + super().__init__(method, rtol, atol, extrap_tol=extrap_tol, max_step=max_step) self.extra_options = extra_options or {} self.ode_solver = True self.name = f"Scikits ODE solver ({method})" + self.max_step = validate_max_step(max_step) pybamm.citations.register("Malengier2018") pybamm.citations.register("Hindmarsh2000") diff --git a/pybamm/solvers/scipy_solver.py b/pybamm/solvers/scipy_solver.py index fb320f558d..61b7c459c7 100644 --- a/pybamm/solvers/scipy_solver.py +++ b/pybamm/solvers/scipy_solver.py @@ -7,6 +7,8 @@ import scipy.integrate as it import numpy as np +from .base_solver import validate_max_step + class ScipySolver(pybamm.BaseSolver): """Solve a discretised model, using scipy.integrate.solve_ivp. @@ -21,6 +23,9 @@ class ScipySolver(pybamm.BaseSolver): The absolute tolerance for the solver (default is 1e-6). extrap_tol : float, optional The tolerance to assert whether extrapolation occurs or not (default is 0). + max_step : float, optional + Maximum allowed step size. Default is np.inf, i.e., the step size is not + bounded and determined solely by the solver. extra_options : dict, optional Any options to pass to the solver. Please consult `SciPy documentation `_ for @@ -33,6 +38,7 @@ def __init__( rtol=1e-6, atol=1e-6, extrap_tol=None, + max_step=np.inf, extra_options=None, ): super().__init__( @@ -40,10 +46,12 @@ def __init__( rtol=rtol, atol=atol, extrap_tol=extrap_tol, + max_step=max_step, ) self.ode_solver = True self.extra_options = extra_options or {} self.name = f"Scipy solver ({method})" + self.max_step = validate_max_step(max_step) pybamm.citations.register("Virtanen2020") def _integrate(self, model, t_eval, inputs_dict=None): From da018274b34ef3ab850b659e32e9d3070c5c5ad9 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Tue, 20 Feb 2024 16:47:11 +0530 Subject: [PATCH 11/14] Rename `dt_max` to `dt_event` --- .github/workflows/work_precision_sets.yml | 2 +- benchmarks/release-work-precision-sets.md | 4 +- .../work_precision_sets/time_vs_dt_max.py | 14 +++--- .../notebooks/models/composite_particle.ipynb | 4 +- .../notebooks/models/rate-capability.ipynb | 2 +- .../models/submodel_cracking_DFN_or_SPM.ipynb | 2 +- .../notebooks/solvers/speed-up-solver.ipynb | 48 +++++++++---------- pybamm/solvers/casadi_solver.py | 26 +++++----- tests/unit/test_solvers/test_casadi_solver.py | 14 ++++-- 9 files changed, 60 insertions(+), 56 deletions(-) diff --git a/.github/workflows/work_precision_sets.yml b/.github/workflows/work_precision_sets.yml index fafc5b1738..5bd8806c78 100644 --- a/.github/workflows/work_precision_sets.yml +++ b/.github/workflows/work_precision_sets.yml @@ -21,7 +21,7 @@ jobs: run: python -m pip install pybamm==${{ env.VERSION }} - name: Run time_vs_* benchmarks for PyBaMM v${{ env.VERSION }} run: | - python benchmarks/work_precision_sets/time_vs_dt_max.py + python benchmarks/work_precision_sets/time_vs_dt_event.py python benchmarks/work_precision_sets/time_vs_mesh_size.py python benchmarks/work_precision_sets/time_vs_no_of_states.py python benchmarks/work_precision_sets/time_vs_reltols.py diff --git a/benchmarks/release-work-precision-sets.md b/benchmarks/release-work-precision-sets.md index 12ce444a98..b424f02366 100644 --- a/benchmarks/release-work-precision-sets.md +++ b/benchmarks/release-work-precision-sets.md @@ -12,9 +12,9 @@ -## Solve Time vs dt_max +## Solve Time vs dt_event - + ## Solve Time vs Number of states diff --git a/benchmarks/work_precision_sets/time_vs_dt_max.py b/benchmarks/work_precision_sets/time_vs_dt_max.py index a1f8ca06bc..16068259cd 100644 --- a/benchmarks/work_precision_sets/time_vs_dt_max.py +++ b/benchmarks/work_precision_sets/time_vs_dt_max.py @@ -12,7 +12,7 @@ models = {"SPM": pybamm.lithium_ion.SPM(), "DFN": pybamm.lithium_ion.DFN()} -dt_max = [ +dt_event = [ 10, 20, 50, @@ -70,8 +70,8 @@ disc = pybamm.Discretisation(mesh, model.default_spatial_methods) disc.process_model(model) - for t in dt_max: - solver = pybamm.CasadiSolver(dt_max=t) + for t in dt_event: + solver = pybamm.CasadiSolver(dt_event=t) solver.solve(model, t_eval=t_eval) time = 0 @@ -85,20 +85,20 @@ ax.set_xscale("log") ax.set_yscale("log") - ax.set_xlabel("dt_max") + ax.set_xlabel("dt_event") ax.set_ylabel("time(s)") ax.set_title(f"{model_name}") - ax.plot(dt_max, time_points) + ax.plot(dt_event, time_points) plt.tight_layout() plt.gca().legend( parameters, loc="upper right", ) -plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_max_{pybamm.__version__}.png") +plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_event_{pybamm.__version__}.png") -content = f"## Solve Time vs dt_max\n\n" +content = f"## Solve Time vs dt_event\n\n" with open("./benchmarks/release_work_precision_sets.md") as original: data = original.read() diff --git a/docs/source/examples/notebooks/models/composite_particle.ipynb b/docs/source/examples/notebooks/models/composite_particle.ipynb index 387ffd7512..c41f020e75 100644 --- a/docs/source/examples/notebooks/models/composite_particle.ipynb +++ b/docs/source/examples/notebooks/models/composite_particle.ipynb @@ -263,7 +263,7 @@ " sim = pybamm.Simulation(\n", " model,\n", " parameter_values=param,\n", - " solver=pybamm.CasadiSolver(dt_max=5),\n", + " solver=pybamm.CasadiSolver(dt_event=5),\n", " )\n", " solution.append(sim.solve(t_eval=t_eval))\n", "stop = timeit.default_timer()\n", @@ -887,7 +887,7 @@ " model,\n", " experiment=experiment,\n", " parameter_values=param,\n", - " solver=pybamm.CasadiSolver(dt_max=5),\n", + " solver=pybamm.CasadiSolver(dt_event=5),\n", " )\n", " solution.append(sim.solve(calc_esoh=False))\n", "stop = timeit.default_timer()\n", diff --git a/docs/source/examples/notebooks/models/rate-capability.ipynb b/docs/source/examples/notebooks/models/rate-capability.ipynb index ef09a37909..43ef82aee2 100644 --- a/docs/source/examples/notebooks/models/rate-capability.ipynb +++ b/docs/source/examples/notebooks/models/rate-capability.ipynb @@ -100,7 +100,7 @@ " [f\"Discharge at {C_rate:.4f}C until 3.2V\"], period=f\"{10 / C_rate:.4f} seconds\"\n", " )\n", " sim = pybamm.Simulation(\n", - " model, experiment=experiment, solver=pybamm.CasadiSolver(dt_max=120)\n", + " model, experiment=experiment, solver=pybamm.CasadiSolver(dt_event=120)\n", " )\n", " sim.solve()\n", "\n", diff --git a/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb b/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb index ac92c06d15..96cdf9556f 100644 --- a/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb +++ b/docs/source/examples/notebooks/models/submodel_cracking_DFN_or_SPM.ipynb @@ -144,7 +144,7 @@ "sim = pybamm.Simulation(\n", " model,\n", " parameter_values=param,\n", - " solver=pybamm.CasadiSolver(dt_max=600),\n", + " solver=pybamm.CasadiSolver(dt_event=600),\n", " var_pts=var_pts,\n", ")\n", "solution = sim.solve(t_eval=[0, 3600], inputs={\"C-rate\": 1})\n", diff --git a/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb b/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb index c49c8926fb..5f2967ce30 100644 --- a/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb +++ b/docs/source/examples/notebooks/solvers/speed-up-solver.ipynb @@ -366,7 +366,7 @@ "metadata": {}, "outputs": [], "source": [ - "safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_max=30)\n", + "safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_event=30)\n", "safe_sol_2 = sim.solve([0, 160], solver=safe_solver_2, inputs={\"Crate\": 10})" ] }, @@ -374,14 +374,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Choosing dt_max to speed up the safe mode" + "### Choosing dt_event to speed up the safe mode" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The parameter `dt_max` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events." + "The parameter `dt_event` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events." ] }, { @@ -393,24 +393,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "With dt_max=10, took 575.783 ms (integration time: 508.473 ms)\n", - "With dt_max=20, took 575.500 ms (integration time: 510.705 ms)\n", - "With dt_max=100, took 316.721 ms (integration time: 275.459 ms)\n", - "With dt_max=1000, took 76.646 ms (integration time: 49.294 ms)\n", - "With dt_max=3700, took 48.773 ms (integration time: 32.436 ms)\n", + "With dt_event=10, took 575.783 ms (integration time: 508.473 ms)\n", + "With dt_event=20, took 575.500 ms (integration time: 510.705 ms)\n", + "With dt_event=100, took 316.721 ms (integration time: 275.459 ms)\n", + "With dt_event=1000, took 76.646 ms (integration time: 49.294 ms)\n", + "With dt_event=3700, took 48.773 ms (integration time: 32.436 ms)\n", "With 'fast' mode, took 42.224 ms (integration time: 32.177 ms)\n" ] } ], "source": [ - "for dt_max in [10, 20, 100, 1000, 3700]:\n", + "for dt_event in [10, 20, 100, 1000, 3700]:\n", " safe_sol = sim.solve(\n", " [0, 3600],\n", - " solver=pybamm.CasadiSolver(mode=\"safe\", dt_max=dt_max),\n", + " solver=pybamm.CasadiSolver(mode=\"safe\", dt_event=dt_event),\n", " inputs={\"Crate\": 1},\n", " )\n", " print(\n", - " f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n", + " f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n", " + f\"(integration time: {safe_sol.integration_time})\"\n", " )\n", "\n", @@ -425,9 +425,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In general, a larger value of `dt_max` gives a faster solution, since fewer integrator creations and calls are required.\n", + "In general, a larger value of `dt_event` gives a faster solution, since fewer integrator creations and calls are required.\n", "\n", - "Below the solution time interval of 36s, the value of `dt_max` does not affect the solve time, since steps must be at least 36s large.\n", + "Below the solution time interval of 36s, the value of `dt_event` does not affect the solve time, since steps must be at least 36s large.\n", "The discrepancy between the solve time and integration time is due to the extra operations recorded by \"solve time\", such as creating the integrator. The \"fast\" solver does not need to do this (it reuses the first one it had already created), so the solve time is much closer to the integration time." ] }, @@ -435,7 +435,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The example above was a case where no events are triggered, so the largest `dt_max` works well. If we step over events, then it is possible to makes `dt_max` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:" + "The example above was a case where no events are triggered, so the largest `dt_event` works well. If we step over events, then it is possible to makes `dt_event` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:" ] }, { @@ -447,10 +447,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "With dt_max=10, took 504.163 ms (integration time: 419.740 ms)\n", - "With dt_max=20, took 504.691 ms (integration time: 421.396 ms)\n", - "With dt_max=100, took 286.620 ms (integration time: 238.390 ms)\n", - "With dt_max=1000, took 98.500 ms (integration time: 60.880 ms)\n" + "With dt_event=10, took 504.163 ms (integration time: 419.740 ms)\n", + "With dt_event=20, took 504.691 ms (integration time: 421.396 ms)\n", + "With dt_event=100, took 286.620 ms (integration time: 238.390 ms)\n", + "With dt_event=1000, took 98.500 ms (integration time: 60.880 ms)\n" ] }, { @@ -466,22 +466,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "With dt_max=3600, took 645.118 ms (integration time: 32.601 ms)\n" + "With dt_event=3600, took 645.118 ms (integration time: 32.601 ms)\n" ] } ], "source": [ - "for dt_max in [10, 20, 100, 1000, 3600]:\n", + "for dt_event in [10, 20, 100, 1000, 3600]:\n", " # Reduce max_num_steps to fail faster\n", " safe_sol = sim.solve(\n", " [0, 4500],\n", " solver=pybamm.CasadiSolver(\n", - " mode=\"safe\", dt_max=dt_max, extra_options_setup={\"max_num_steps\": 1000}\n", + " mode=\"safe\", dt_event=dt_event, extra_options_setup={\"max_num_steps\": 1000}\n", " ),\n", " inputs={\"Crate\": 1},\n", " )\n", " print(\n", - " f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n", + " f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n", " + f\"(integration time: {safe_sol.integration_time})\"\n", " )" ] @@ -490,7 +490,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The integration time with `dt_max=3600` remains the fastest, but the solve time is the slowest due to all the failed steps." + "The integration time with `dt_event=3600` remains the fastest, but the solve time is the slowest due to all the failed steps." ] }, { @@ -504,7 +504,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_max`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve." + "The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_event`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve." ] }, { diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index 0cef290d8b..d68152dc3c 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -24,7 +24,7 @@ class CasadiSolver(pybamm.BaseSolver): - "fast with events": perform direct integration of the whole timespan, \ then go back and check where events were crossed. Experimental only. - "safe": perform step-and-check integration in global steps of size \ - dt_max, checking whether events have been triggered. Recommended for \ + dt_event, checking whether events have been triggered. Recommended for \ simulations of a full charge or discharge. - "safe without grid": perform step-and-check integration step-by-step. \ Takes more steps than "safe" mode, but doesn't require creating the grid \ @@ -48,7 +48,7 @@ class CasadiSolver(pybamm.BaseSolver): max_step_decrease_count : float, optional The maximum number of times step size can be decreased before an error is raised. Default is 5. - dt_max : float, optional + dt_event : float, optional The maximum global step size (in seconds) used in "safe" mode. If None the default value is 600 seconds. extrap_tol : float, optional @@ -87,7 +87,7 @@ def __init__( root_tol=1e-6, max_step=np.inf, max_step_decrease_count=5, - dt_max=None, + dt_event=None, extrap_tol=None, extra_options_setup=None, extra_options_call=None, @@ -114,7 +114,7 @@ def __init__( ) self.max_step = validate_max_step(max_step) self.max_step_decrease_count = max_step_decrease_count - self.dt_max = dt_max or 600 + self.dt_event = dt_event or 600 self.extra_options_setup = extra_options_setup or {} self.extra_options_call = extra_options_call or {} @@ -212,24 +212,24 @@ def _integrate(self, model, t_eval, inputs_dict=None): solution = None use_grid = True - # Try to integrate in global steps of size dt_max. Note: dt_max must + # Try to integrate in global steps of size dt_event. Note: dt_event must # be at least as big as the the biggest step in t_eval (multiplied # by some tolerance, here 1.01) to avoid an empty integration window below - dt_max = self.dt_max + dt_event = self.dt_event dt_eval_max = np.max(np.diff(t_eval)) * 1.01 - if dt_max < dt_eval_max: + if dt_event < dt_eval_max: pybamm.logger.debug( - "Setting dt_max to be as big as the largest step in " + "Setting dt_event to be as big as the largest step in " f"t_eval ({dt_eval_max})" ) - dt_max = dt_eval_max + dt_event = dt_eval_max termination_due_to_small_dt = False first_ts_solved = False while t < t_f: # Step solved = False count = 0 - dt = dt_max + dt = dt_event while not solved: # Get window of time to integrate over (so that we return # all the points in t_eval, not just t and t+dt) @@ -273,13 +273,13 @@ def _integrate(self, model, t_eval, inputs_dict=None): # needed, but this won't affect the global timesteps. The # global timestep will only be reduced after the first timestep. if first_ts_solved: - dt_max = dt + dt_event = dt if count > self.max_step_decrease_count: message = ( "Maximum number of decreased steps occurred at " f"t={t} (final SolverError: '{error}'). " - "For a full solution try reducing dt_max (currently, " - f"dt_max={dt_max}) and/or reducing the size of the " + "For a full solution try reducing dt_event (currently, " + f"dt_event={dt_event}) and/or reducing the size of the " "time steps or period of the experiment." ) if first_ts_solved and self.return_solution_if_failed_early: diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 3030f80af0..dca550ab94 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -141,7 +141,9 @@ def test_model_solver_failure(self): disc = pybamm.Discretisation() model_disc = disc.process_model(model, inplace=False) solver = pybamm.CasadiSolver( - dt_max=1e-3, return_solution_if_failed_early=True, max_step_decrease_count=2 + dt_event=1e-3, + return_solution_if_failed_early=True, + max_step_decrease_count=2, ) # Solve with failure at t=2 # Solution fails early but manages to take some steps so we return it anyway @@ -152,7 +154,9 @@ def test_model_solver_failure(self): self.assertLess(solution.t[-1], 20) # Solve with failure at t=0 solver = pybamm.CasadiSolver( - dt_max=1e-3, return_solution_if_failed_early=True, max_step_decrease_count=2 + dt_event=1e-3, + return_solution_if_failed_early=True, + max_step_decrease_count=2, ) model.initial_conditions = {var: 0, var2: 1} model_disc = disc.process_model(model, inplace=False) @@ -195,7 +199,7 @@ def test_model_solver_events(self): # Solve using "safe" mode with debug off pybamm.settings.debug_mode = False - solver = pybamm.CasadiSolver(mode="safe", rtol=1e-8, atol=1e-8, dt_max=1) + solver = pybamm.CasadiSolver(mode="safe", rtol=1e-8, atol=1e-8, dt_event=1) t_eval = np.linspace(0, 5, 100) solution = solver.solve(model, t_eval) np.testing.assert_array_less(solution.y.full()[0], 1.5) @@ -210,8 +214,8 @@ def test_model_solver_events(self): ) pybamm.settings.debug_mode = True - # Try dt_max=0 to enforce using all timesteps - solver = pybamm.CasadiSolver(dt_max=0, rtol=1e-8, atol=1e-8) + # Try dt_event=0 to enforce using all timesteps + solver = pybamm.CasadiSolver(dt_event=0, rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 5, 100) solution = solver.solve(model, t_eval) np.testing.assert_array_less(solution.y.full()[0], 1.5) From d23891d1dbeca0ed11d2abcb367127a9cd0b2e89 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 21 Feb 2024 12:20:46 +0530 Subject: [PATCH 12/14] Fix coverage for `validate_max_step` --- pybamm/__init__.py | 1 + tests/unit/test_solvers/test_base_solver.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index d9b866ff0d..521c7d5665 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -219,6 +219,7 @@ from .solvers.jax_solver import JaxSolver from .solvers.jax_bdf_solver import jax_bdf_integrate +from .solvers.base_solver import validate_max_step from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 577e50e68b..cd2e587923 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -385,6 +385,10 @@ def exact_diff_b(y, a, b): sens_b, exact_diff_b(y, inputs["a"], inputs["b"]) ) + def test_validate_max_step(self): + with self.assertRaisesRegex(ValueError, "`max_step` must be positive."): + pybamm.validate_max_step(-1) + if __name__ == "__main__": print("Add -v for more debug output") From b3a612b2a0b5bb685f0a2804da274fa39196d10e Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 21 Feb 2024 16:40:58 +0530 Subject: [PATCH 13/14] CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bdacb00ae3..00cc102c5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Renamed "electrode diffusivity" to "particle diffusivity" as a non-breaking change with a deprecation warning ([#3624](https://github.com/pybamm-team/PyBaMM/pull/3624)) - Add support for BPX version 0.4.0 which allows for blended electrodes and user-defined parameters in BPX([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414)) - Added the ability to specify a custom solver tolerance in `get_initial_stoichiometries` and related functions ([#3714](https://github.com/pybamm-team/PyBaMM/pull/3714)) +- Added `max_step` parameter to `BaseSolver` and passed it to dependent solvers ([#3106](https://github.com/pybamm-team/PyBaMM/pull/3106)) ## Bug Fixes From 68c29256797fd20fdef007f54264a1a48925d493 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 21 Feb 2024 16:46:12 +0530 Subject: [PATCH 14/14] Rename `time_vs_dt_max.py` to `time_vs_dt_event.py` --- .../{time_vs_dt_max.py => time_vs_dt_event.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename benchmarks/work_precision_sets/{time_vs_dt_max.py => time_vs_dt_event.py} (100%) diff --git a/benchmarks/work_precision_sets/time_vs_dt_max.py b/benchmarks/work_precision_sets/time_vs_dt_event.py similarity index 100% rename from benchmarks/work_precision_sets/time_vs_dt_max.py rename to benchmarks/work_precision_sets/time_vs_dt_event.py