Skip to content

Commit

Permalink
casadi and scipy solver working with multiple inputs #4087
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jun 7, 2024
1 parent d09bf18 commit 9ce0917
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 124 deletions.
170 changes: 103 additions & 67 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import itertools
from scipy.sparse import block_diag
import numbers
import sys
import warnings
Expand Down Expand Up @@ -184,8 +183,15 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):

# Process rhs, algebraic, residual and event expressions
# and wrap in callables
is_casadi_solver = isinstance(
self.root_method, pybamm.CasadiAlgebraicSolver
) or isinstance(self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver))
if is_casadi_solver and len(model.rhs) > 0:
rhs = model.mass_matrix_inv @ model.concatenated_rhs
else:
rhs = model.concatenated_rhs
rhs, jac_rhs, jacp_rhs, jac_rhs_action = process(
model.concatenated_rhs, "RHS", vars_for_processing, ninputs=len(inputs)
rhs, "RHS", vars_for_processing, ninputs=len(inputs)
)

algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action = process(
Expand Down Expand Up @@ -245,23 +251,11 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
# Save CasADi functions for the CasADi solver
# Save CasADi functions for solvers that use CasADi
# Note: when we pass to casadi the ode part of the problem must be in
if isinstance(self.root_method, pybamm.CasadiAlgebraicSolver) or isinstance(
self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver)
):
if is_casadi_solver:
# can use DAE solver to solve model with algebraic equations only
# todo: do I need this check?
if len(model.rhs) > 0:
t_casadi = vars_for_processing["t_casadi"]
y_and_S = vars_for_processing["y_and_S"]
p_casadi_stacked = vars_for_processing["p_casadi_stacked"]
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
explicit_rhs = mass_matrix_inv @ rhs(
t_casadi, y_and_S, p_casadi_stacked
)
model.casadi_rhs = casadi.Function(
"rhs", [t_casadi, y_and_S, p_casadi_stacked], [explicit_rhs]
)
if len(inputs) > 1:
model.casadi_rhs = model.casadi_rhs.map(len(inputs), "openmp")
model.casadi_rhs = rhs
model.casadi_switch_events = casadi_switch_events
model.casadi_algebraic = algebraic
model.casadi_sensitivities = jacp_rhs_algebraic
Expand Down Expand Up @@ -459,19 +453,6 @@ def _set_up_model_sensitivities_inplace(
model.bounds[0][: model.len_rhs_and_alg],
model.bounds[1][: model.len_rhs_and_alg],
)
if (
model.mass_matrix is not None
and model.mass_matrix.shape[0] > model.len_rhs_and_alg
):
if model.mass_matrix_inv is not None:
model.mass_matrix_inv = pybamm.Matrix(
model.mass_matrix_inv.entries[: model.len_rhs, : model.len_rhs]
)
model.mass_matrix = pybamm.Matrix(
model.mass_matrix.entries[
: model.len_rhs_and_alg, : model.len_rhs_and_alg
]
)

# now we can extend them by the number of sensitivity parameters
# if needed
Expand All @@ -485,22 +466,6 @@ def _set_up_model_sensitivities_inplace(
np.repeat(model.bounds[0], n_inputs + 1),
np.repeat(model.bounds[1], n_inputs + 1),
)
if (
model.mass_matrix is not None
and model.mass_matrix.shape[0] == model.len_rhs_and_alg
):
if model.mass_matrix_inv is not None:
model.mass_matrix_inv = pybamm.Matrix(
block_diag(
[model.mass_matrix_inv.entries] * (n_inputs + 1),
format="csr",
)
)
model.mass_matrix = pybamm.Matrix(
block_diag(
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
)
)

def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
# Check for heaviside and modulo functions in rhs and algebraic and add
Expand Down Expand Up @@ -632,7 +597,7 @@ def _set_initial_conditions(self, model, time, inputs_list, update_rhs):
"""

y0_total_size = (
y0_total_size = len(inputs_list) * (
model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens
)
y_zero = np.zeros((y0_total_size, 1))
Expand Down Expand Up @@ -1057,14 +1022,19 @@ def _check_events_with_initial_conditions(t_eval, model, inputs_list):
event_eval = event(t=t_eval[0], y=model.y0, inputs=inputs_list)
events_eval[idx] = event_eval

events_eval = np.array(events_eval)
if model.convert_to_format == "casadi":
events_eval = casadi.vertcat(*events_eval).toarray().flatten()
else:
events_eval = np.vstack(events_eval).flatten()
if any(events_eval < 0):
# find the events that were triggered by initial conditions
termination_events = [
x for x in model.events if x.event_type == pybamm.EventType.TERMINATION
]
idxs = np.where(events_eval < 0)[0]
event_names = [termination_events[idx].name for idx in idxs]
event_names = [
termination_events[idx / len(inputs_list)].name for idx in idxs
]
raise pybamm.SolverError(
f"Events {event_names} are non-positive at initial conditions"
)
Expand Down Expand Up @@ -1311,15 +1281,14 @@ def get_termination_reason(solution, events):
)
termination_event = min(final_event_values, key=final_event_values.get)

# Check that it's actually an event
if final_event_values[termination_event] > 0.1: # pragma: no cover
# Hard to test this
raise pybamm.SolverError(
"Could not determine which event was triggered "
"(possibly due to NaNs)"
)
# Add the event to the solution object
solution.termination = f"event: {termination_event}"
# Check that it's actually an event for this solution (might be from another input set)
if final_event_values[termination_event] > 0.1:
solution.termination = (
f"event: {termination_event} from another input set"
)
else:
solution.termination = f"event: {termination_event}"
# Update t, y and inputs to include event time and state
# Note: if the final entry of t is equal to the event time we skip
# this (having duplicate entries causes an error later in ProcessedVariable)
Expand All @@ -1333,8 +1302,8 @@ def get_termination_reason(solution, events):
solution.y_event,
solution.termination,
)
event_sol.solve_time = 0
event_sol.integration_time = 0
event_sol.solve_time = 0.0
event_sol.integration_time = 0.0
solution = solution + event_sol

pybamm.logger.debug("Finish post-processing events")
Expand Down Expand Up @@ -1421,6 +1390,72 @@ def _set_up_model_inputs(model, inputs):
return ordered_inputs


def map_func_over_inputs(name, f, vars_for_processing, ninputs):
    """
    Take a casadi function ``f`` and return a new casadi function that maps ``f``
    over ``ninputs`` input sets, taking the states and parameters of all input
    sets stacked into single column vectors.

    The kind of function being mapped is inferred from the suffix of ``name``:
    jacobian action functions take an additional vector input ``v``, and
    jacobian / jacp functions return a matrix rather than a vector.

    Parameters
    ----------
    name: str
        name of the new function. This must end in the string "_action" for
        jacobian action functions, "_jac" for jacobian functions, or "_jacp"
        for jacp functions.
    f: casadi.Function or None
        function to map
    vars_for_processing: dict
        dictionary of variables for processing. The entries "t_casadi",
        "y_and_S" and "p_casadi_stacked" are read.
    ninputs: int
        number of inputs to map over

    Returns
    -------
    casadi.Function or None
        ``None`` if ``f`` is ``None``. Otherwise a function of
        ``(t, y_and_S_stacked, p_stacked)`` (plus ``v_stacked`` for jacobian
        action functions). Vector outputs of the individual input sets are
        stacked vertically into one column vector; matrix outputs are arranged
        in a block diagonal matrix.
    """
    if f is None:
        return None

    # the function kind is encoded in the name suffix
    add_v = name.endswith("_action")
    matrix_output = name.endswith(("_jac", "_jacp"))

    t_casadi = vars_for_processing["t_casadi"]
    nstates = vars_for_processing["y_and_S"].shape[0]
    nparams = vars_for_processing["p_casadi_stacked"].shape[0]

    y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs)
    p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs)

    # the new function takes the stacked vectors, while the mapped function is
    # called with one column per input set
    inputs_stacked = [t_casadi, y_and_S_inputs_stacked, p_casadi_inputs_stacked]
    inputs_2d = [
        t_casadi,
        y_and_S_inputs_stacked.reshape((nstates, ninputs)),
        p_casadi_inputs_stacked.reshape((nparams, ninputs)),
    ]
    if add_v:
        # only jacobian action functions take the extra vector v
        v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs)
        inputs_stacked.append(v_inputs_stacked)
        inputs_2d.append(v_inputs_stacked.reshape((nstates, ninputs)))

    parallelisation = "thread"
    mapped_f = f.map(ninputs, parallelisation)(*inputs_2d)
    if matrix_output:
        # for matrix output we need to stack the outputs in a block diagonal
        # matrix. The width of each block is taken from the mapped output rather
        # than assumed to be nstates, since a jacp block has one column per
        # parameter and not one column per state
        block_width = mapped_f.shape[1] // ninputs
        splits = [i * block_width for i in range(ninputs + 1)]
        split = casadi.horzsplit(mapped_f, splits)
        stack = casadi.diagcat(*split)
    else:
        # for vector outputs we need to stack them vertically in a single
        # column vector (one column of the mapped output per input set)
        splits = list(range(ninputs + 1))
        split = casadi.horzsplit(mapped_f, splits)
        stack = casadi.vertcat(*split)
    return casadi.Function(name, inputs_stacked, [stack])


def process(
symbol,
name,
Expand Down Expand Up @@ -1466,6 +1501,7 @@ def process(
:class:`casadi.Function`
evaluator for product of the Jacobian with a vector $v$,
i.e. $\\frac{\\partial f}{\\partial y} * v$
"""

def report(string):
Expand Down Expand Up @@ -1664,13 +1700,13 @@ def jacp(*args, **kwargs):
)

if ninputs > 1:
parallelisation = "openmp"
func = func.map(ninputs, parallelisation)
if jac is not None:
jac = jac.map(ninputs, parallelisation)
if jacp is not None:
jacp = jacp.map(ninputs, parallelisation)
if jac_action is not None:
jac_action = jac_action.map(ninputs, parallelisation)
func = map_func_over_inputs(name, func, vars_for_processing, ninputs)
jac = map_func_over_inputs(name + "_jac", jac, vars_for_processing, ninputs)
jacp = map_func_over_inputs(
name + "_jacp", jacp, vars_for_processing, ninputs
)
jac_action = map_func_over_inputs(
name + "_jac_action", jac_action, vars_for_processing, ninputs
)

return func, jac, jacp, jac_action
Loading

0 comments on commit 9ce0917

Please sign in to comment.