Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max_step arg in basesolver #3106

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/work_precision_sets.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
run: python -m pip install pybamm==${{ env.VERSION }}
- name: Run time_vs_* benchmarks for PyBaMM v${{ env.VERSION }}
run: |
python benchmarks/work_precision_sets/time_vs_dt_max.py
python benchmarks/work_precision_sets/time_vs_dt_event.py
arjxn-py marked this conversation as resolved.
Show resolved Hide resolved
python benchmarks/work_precision_sets/time_vs_mesh_size.py
python benchmarks/work_precision_sets/time_vs_no_of_states.py
python benchmarks/work_precision_sets/time_vs_reltols.py
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Renamed "electrode diffusivity" to "particle diffusivity" as a non-breaking change with a deprecation warning ([#3624](https://github.com/pybamm-team/PyBaMM/pull/3624))
- Add support for BPX version 0.4.0 which allows for blended electrodes and user-defined parameters in BPX([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414))
- Added the ability to specify a custom solver tolerance in `get_initial_stoichiometries` and related functions ([#3714](https://github.com/pybamm-team/PyBaMM/pull/3714))
- Added `max_step` parameter to `BaseSolver` and passed it to dependent solvers ([#3106](https://github.com/pybamm-team/PyBaMM/pull/3106))

## Bug Fixes

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/release-work-precision-sets.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

<img src='./benchmark_images/time_vs_mesh_size_22.7.png'>

## Solve Time vs dt_max
## Solve Time vs dt_event

<img src='./benchmark_images/time_vs_dt_max_22.7.png'>
<img src='./benchmark_images/time_vs_dt_event_22.7.png'>

## Solve Time vs Number of states

Expand Down
14 changes: 7 additions & 7 deletions benchmarks/work_precision_sets/time_vs_dt_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

models = {"SPM": pybamm.lithium_ion.SPM(), "DFN": pybamm.lithium_ion.DFN()}

dt_max = [
dt_event = [
10,
20,
50,
Expand Down Expand Up @@ -70,8 +70,8 @@
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

for t in dt_max:
solver = pybamm.CasadiSolver(dt_max=t)
for t in dt_event:
solver = pybamm.CasadiSolver(dt_event=t)

solver.solve(model, t_eval=t_eval)
time = 0
Expand All @@ -85,20 +85,20 @@

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("dt_max")
ax.set_xlabel("dt_event")
ax.set_ylabel("time(s)")
ax.set_title(f"{model_name}")
ax.plot(dt_max, time_points)
ax.plot(dt_event, time_points)

plt.tight_layout()
plt.gca().legend(
parameters,
loc="upper right",
)
plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_max_{pybamm.__version__}.png")
plt.savefig(f"benchmarks/benchmark_images/time_vs_dt_event_{pybamm.__version__}.png")


content = f"## Solve Time vs dt_max\n<img src='./benchmark_images/time_vs_dt_max_{pybamm.__version__}.png'>\n"
content = f"## Solve Time vs dt_event\n<img src='./benchmark_images/time_vs_dt_event_{pybamm.__version__}.png'>\n"

with open("./benchmarks/release_work_precision_sets.md") as original:
data = original.read()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@
" sim = pybamm.Simulation(\n",
" model,\n",
" parameter_values=param,\n",
" solver=pybamm.CasadiSolver(dt_max=5),\n",
" solver=pybamm.CasadiSolver(dt_event=5),\n",
" )\n",
" solution.append(sim.solve(t_eval=t_eval))\n",
"stop = timeit.default_timer()\n",
Expand Down Expand Up @@ -887,7 +887,7 @@
" model,\n",
" experiment=experiment,\n",
" parameter_values=param,\n",
" solver=pybamm.CasadiSolver(dt_max=5),\n",
" solver=pybamm.CasadiSolver(dt_event=5),\n",
" )\n",
" solution.append(sim.solve(calc_esoh=False))\n",
"stop = timeit.default_timer()\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
" [f\"Discharge at {C_rate:.4f}C until 3.2V\"], period=f\"{10 / C_rate:.4f} seconds\"\n",
" )\n",
" sim = pybamm.Simulation(\n",
" model, experiment=experiment, solver=pybamm.CasadiSolver(dt_max=120)\n",
" model, experiment=experiment, solver=pybamm.CasadiSolver(dt_event=120)\n",
" )\n",
" sim.solve()\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
"sim = pybamm.Simulation(\n",
" model,\n",
" parameter_values=param,\n",
" solver=pybamm.CasadiSolver(dt_max=600),\n",
" solver=pybamm.CasadiSolver(dt_event=600),\n",
" var_pts=var_pts,\n",
")\n",
"solution = sim.solve(t_eval=[0, 3600], inputs={\"C-rate\": 1})\n",
Expand Down
48 changes: 24 additions & 24 deletions docs/source/examples/notebooks/solvers/speed-up-solver.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -366,22 +366,22 @@
"metadata": {},
"outputs": [],
"source": [
"safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_max=30)\n",
"safe_solver_2 = pybamm.CasadiSolver(mode=\"safe\", dt_event=30)\n",
"safe_sol_2 = sim.solve([0, 160], solver=safe_solver_2, inputs={\"Crate\": 10})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Choosing dt_max to speed up the safe mode"
"### Choosing dt_event to speed up the safe mode"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter `dt_max` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events."
"The parameter `dt_event` controls how large the steps taken by the `CasadiSolver` with \"safe\" mode are when looking for events."
]
},
{
Expand All @@ -393,24 +393,24 @@
"name": "stdout",
"output_type": "stream",
"text": [
"With dt_max=10, took 575.783 ms (integration time: 508.473 ms)\n",
"With dt_max=20, took 575.500 ms (integration time: 510.705 ms)\n",
"With dt_max=100, took 316.721 ms (integration time: 275.459 ms)\n",
"With dt_max=1000, took 76.646 ms (integration time: 49.294 ms)\n",
"With dt_max=3700, took 48.773 ms (integration time: 32.436 ms)\n",
"With dt_event=10, took 575.783 ms (integration time: 508.473 ms)\n",
"With dt_event=20, took 575.500 ms (integration time: 510.705 ms)\n",
"With dt_event=100, took 316.721 ms (integration time: 275.459 ms)\n",
"With dt_event=1000, took 76.646 ms (integration time: 49.294 ms)\n",
"With dt_event=3700, took 48.773 ms (integration time: 32.436 ms)\n",
"With 'fast' mode, took 42.224 ms (integration time: 32.177 ms)\n"
]
}
],
"source": [
"for dt_max in [10, 20, 100, 1000, 3700]:\n",
"for dt_event in [10, 20, 100, 1000, 3700]:\n",
" safe_sol = sim.solve(\n",
" [0, 3600],\n",
" solver=pybamm.CasadiSolver(mode=\"safe\", dt_max=dt_max),\n",
" solver=pybamm.CasadiSolver(mode=\"safe\", dt_event=dt_event),\n",
" inputs={\"Crate\": 1},\n",
" )\n",
" print(\n",
" f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n",
" f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n",
" + f\"(integration time: {safe_sol.integration_time})\"\n",
" )\n",
"\n",
Expand All @@ -425,17 +425,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In general, a larger value of `dt_max` gives a faster solution, since fewer integrator creations and calls are required.\n",
"In general, a larger value of `dt_event` gives a faster solution, since fewer integrator creations and calls are required.\n",
"\n",
"Below the solution time interval of 36s, the value of `dt_max` does not affect the solve time, since steps must be at least 36s large.\n",
"Below the solution time interval of 36s, the value of `dt_event` does not affect the solve time, since steps must be at least 36s large.\n",
"The discrepancy between the solve time and integration time is due to the extra operations recorded by \"solve time\", such as creating the integrator. The \"fast\" solver does not need to do this (it reuses the first one it had already created), so the solve time is much closer to the integration time."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The example above was a case where no events are triggered, so the largest `dt_max` works well. If we step over events, then it is possible to makes `dt_max` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:"
"The example above was a case where no events are triggered, so the largest `dt_event` works well. If we step over events, then it is possible to makes `dt_event` too large, so that the solver will attempt (and fail) to take large steps past the event, iteratively reducing the step size until it works. For example:"
]
},
{
Expand All @@ -447,10 +447,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"With dt_max=10, took 504.163 ms (integration time: 419.740 ms)\n",
"With dt_max=20, took 504.691 ms (integration time: 421.396 ms)\n",
"With dt_max=100, took 286.620 ms (integration time: 238.390 ms)\n",
"With dt_max=1000, took 98.500 ms (integration time: 60.880 ms)\n"
"With dt_event=10, took 504.163 ms (integration time: 419.740 ms)\n",
"With dt_event=20, took 504.691 ms (integration time: 421.396 ms)\n",
"With dt_event=100, took 286.620 ms (integration time: 238.390 ms)\n",
"With dt_event=1000, took 98.500 ms (integration time: 60.880 ms)\n"
]
},
{
Expand All @@ -466,22 +466,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
"With dt_max=3600, took 645.118 ms (integration time: 32.601 ms)\n"
"With dt_event=3600, took 645.118 ms (integration time: 32.601 ms)\n"
]
}
],
"source": [
"for dt_max in [10, 20, 100, 1000, 3600]:\n",
"for dt_event in [10, 20, 100, 1000, 3600]:\n",
" # Reduce max_num_steps to fail faster\n",
" safe_sol = sim.solve(\n",
" [0, 4500],\n",
" solver=pybamm.CasadiSolver(\n",
" mode=\"safe\", dt_max=dt_max, extra_options_setup={\"max_num_steps\": 1000}\n",
" mode=\"safe\", dt_event=dt_event, extra_options_setup={\"max_num_steps\": 1000}\n",
" ),\n",
" inputs={\"Crate\": 1},\n",
" )\n",
" print(\n",
" f\"With dt_max={dt_max}, took {safe_sol.solve_time} \"\n",
" f\"With dt_event={dt_event}, took {safe_sol.solve_time} \"\n",
" + f\"(integration time: {safe_sol.integration_time})\"\n",
" )"
]
Expand All @@ -490,7 +490,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The integration time with `dt_max=3600` remains the fastest, but the solve time is the slowest due to all the failed steps."
"The integration time with `dt_event=3600` remains the fastest, but the solve time is the slowest due to all the failed steps."
]
},
{
Expand All @@ -504,7 +504,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_max`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve."
"The \"period\" argument of the experiments also affects how long the simulations take, for a similar reason to `dt_event`. Therefore, this argument can be manually tuned to speed up how long an experiment takes to solve."
]
},
{
Expand Down
1 change: 1 addition & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@

from .solvers.jax_solver import JaxSolver
from .solvers.jax_bdf_solver import jax_bdf_integrate
from .solvers.base_solver import validate_max_step

from .solvers.idaklu_solver import IDAKLUSolver, have_idaklu

Expand Down
7 changes: 6 additions & 1 deletion pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from scipy import optimize
from scipy.sparse import issparse
from .base_solver import validate_max_step


class AlgebraicSolver(pybamm.BaseSolver):
Expand All @@ -24,14 +25,18 @@ class AlgebraicSolver(pybamm.BaseSolver):
specified in the form "lsq_methodname"
tol : float, optional
The tolerance for the solver (default is 1e-6).
max_step : float, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the algebraic solver doesn't step in time so it doesn't make sense for it to be an argument here

Maximum allowed step size. Default is np.inf, i.e., the step size is not
bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the rootfinder. Vary depending on which method is chosen.
Please consult `SciPy documentation <https://tinyurl.com/ybr6cfqs>`_ for
details.
"""

def __init__(self, method="lm", tol=1e-6, extra_options=None):
def __init__(self, method="lm", tol=1e-6, max_step=np.inf, extra_options=None):
super().__init__(method=method)
self.max_step = validate_max_step(max_step)
self.tol = tol
self.extra_options = extra_options or {}
self.name = f"Algebraic solver ({method})"
Expand Down
12 changes: 12 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from pybamm.expression_tree.binary_operators import _Heaviside


def validate_max_step(max_step):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't validate other arguments such as tolerances. my guess is that all the solvers will have their own check to make sure max step makes sense, so maybe we should remove this?

"""Assert that max_step is valid and return it."""
if max_step <= 0:
raise ValueError("`max_step` must be positive.")
return max_step


class BaseSolver:
"""Solve a discretised model.

Expand All @@ -38,6 +45,9 @@ class BaseSolver:
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
max_step : float, optional
Maximum allowed step size. Default is np.inf, i.e., the step size is not
bounded and determined solely by the solver.
output_variables : list[str], optional
List of variables to calculate and return. If none are specified then
the complete state vector is returned (can be very large) (default is [])
Expand All @@ -51,6 +61,7 @@ def __init__(
root_method=None,
root_tol=1e-6,
extrap_tol=None,
max_step=np.inf,
output_variables=[],
):
self.method = method
Expand All @@ -59,6 +70,7 @@ def __init__(
self.root_tol = root_tol
self.root_method = root_method
self.extrap_tol = extrap_tol or -1e-10
self.max_step = validate_max_step(max_step)
self.output_variables = output_variables
self._model_set_up = {}

Expand Down
7 changes: 6 additions & 1 deletion pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import casadi
import pybamm
import numpy as np
from .base_solver import validate_max_step


class CasadiAlgebraicSolver(pybamm.BaseSolver):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as above

Expand All @@ -17,17 +18,21 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver):
----------
tol : float, optional
The tolerance for the solver (default is 1e-6).
max_step : float, optional
Maximum allowed step size. Default is np.inf, i.e., the step size is not
bounded and determined solely by the solver.
extra_options : dict, optional
Any options to pass to the CasADi rootfinder.
Please consult `CasADi documentation <https://tinyurl.com/y7hrxm7d>`_ for
details.

"""

def __init__(self, tol=1e-6, extra_options=None):
def __init__(self, tol=1e-6, max_step=np.inf, extra_options=None):
super().__init__()
self.tol = tol
self.name = "CasADi algebraic solver"
self.max_step = validate_max_step(max_step)
self.algebraic_solver = True
self.extra_options = extra_options or {}
pybamm.citations.register("Andersson2019")
Expand Down
Loading
Loading