Skip to content

Commit

Permalink
make a start on idaklu solver #4087
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jun 20, 2024
1 parent 7727aef commit f8b3209
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 51 deletions.
61 changes: 44 additions & 17 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,16 @@ class EvaluatorPython:
symbol : :class:`pybamm.Symbol`
The symbol to convert to python code
is_event: bool
Indicates this symbol is an event expression
is_matrix: bool
Indicates the evaluation of this symbol results in a matrix (otherwise result is a vector)
"""

def __init__(self, symbol: pybamm.Symbol):
def __init__(
self, symbol: pybamm.Symbol, is_event: bool = False, is_matrix: bool = False
):
constants, python_str = pybamm.to_python(symbol, debug=False)

# extract constants in generated function
Expand Down Expand Up @@ -472,6 +477,8 @@ def __init__(self, symbol: pybamm.Symbol):
self._python_str = python_str
self._result_var = result_var
self._symbol = symbol
self._is_event = is_event
self._is_matrix = is_matrix

# compile and run the generated python code,
compiled_function = compile(python_str, result_var, "exec")
Expand All @@ -486,21 +493,41 @@ def __call__(self, t=None, y=None, inputs=None):
y = y.reshape(-1, 1)

if isinstance(inputs, list):
result = self._evaluate(self._constants, t, y, inputs[0])
if len(inputs) > 1:
if isinstance(result, numbers.Number):
result = np.array([result])
ny = result.shape[0]
ni = len(inputs)
results = np.zeros((ni * ny, 1))
results[:ny] = result
i = ny
for input in inputs[1:]:
results[i : i + ny] += self._evaluate(
self._constants, t, y[i : i + ny], input
)
i += ny
result = results
if len(inputs) == 1:
# nothing to do for a single input
result = self._evaluate(self._constants, t, y, inputs[0])
elif self._is_event:
# if an event do a soft max on the results to combine events from multiple
# inputs
results = np.array(
[self._evaluate(self._constants, t, y, input) for input in inputs]
)
margin = 1e-4
alpha = np.log(len(inputs)) / margin
result = scipy.special.logsumexp(alpha * results) / alpha
elif self._is_matrix:
# if a matrix output, concatenate the results in a block diagonal matrix
results = [
self._evaluate(self._constants, t, y, input) for input in inputs
]
result = scipy.sparse.block_diag(*results, format="csr")
else:
# otherwise concatenate the results in a column vector
result = self._evaluate(self._constants, t, y, inputs[0])
if len(inputs) > 1:
if isinstance(result, numbers.Number):
result = np.array([result])
ny = result.shape[0]
ni = len(inputs)
results = np.zeros((ni * ny, 1))
results[:ny] = result
i = ny
for input in inputs[1:]:
results[i : i + ny] += self._evaluate(
self._constants, t, y[i : i + ny], input
)
i += ny
result = results
else:
result = self._evaluate(self._constants, t, y, inputs)

Expand Down
83 changes: 77 additions & 6 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,7 +1390,76 @@ def _set_up_model_inputs(model, inputs):
return ordered_inputs


def map_func_over_inputs(name, f, vars_for_processing, ninputs):
def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs):
"""
This takes a casadi function f and returns a new casadi function that maps f over
the provided number of inputs. Some functions (e.g. jacobian action) require an additional
vector input v, which is why add_v is provided.
Parameters
----------
name: str
name of the new function. This must end in the string "_action" for jacobian action functions,
"_jac" for jacobian functions, or "_jacp" for jacp functions.
f: casadi.Function
function to map
vars_for_processing: dict
dictionary of variables for processing
ninputs: int
number of inputs to map over
"""
if f is None:
return None

is_event = "event" in name
add_v = name.endswith("_action")
matrix_output = name.endswith("_jac") or name.endswith("_jacp")

nstates = vars_for_processing["y_and_S"].shape[0]
nparams = vars_for_processing["p_casadi_stacked"].shape[0]

parallelisation = "thread"
y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs)
p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs)
v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs)

y_and_S_2d = y_and_S_inputs_stacked.reshape((nstates, ninputs))
p_casadi_2d = p_casadi_inputs_stacked.reshape((nparams, ninputs))
v_2d = v_inputs_stacked.reshape((nstates, ninputs))

t_casadi = vars_for_processing["t_casadi"]

if add_v:
inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d, v_2d]
inputs_stacked = [
t_casadi,
y_and_S_inputs_stacked,
p_casadi_inputs_stacked,
v_inputs_stacked,
]
else:
inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d]
inputs_stacked = [t_casadi, y_and_S_inputs_stacked, p_casadi_inputs_stacked]

mapped_f = f.map(ninputs, parallelisation)(*inputs_2d)
if matrix_output:
# for matrix output we need to stack the outputs in a block diagonal matrix
splits = [i * nstates for i in range(ninputs + 1)]
split = casadi.horzsplit(mapped_f, splits)
stack = casadi.diagcat(*split)
elif is_event:
# Events need to return a scalar, so we combine the vector output
# of the mapped function into a scalar output by calculating a smooth max of the vector output.
stack = casadi.logsumexp(casadi.transpose(mapped_f), 1e-4)
else:
# for vector outputs we need to stack them vertically in a single column vector
splits = [i for i in range(ninputs + 1)]
split = casadi.horzsplit(mapped_f, splits)
stack = casadi.vertcat(*split)
return casadi.Function(name, inputs_stacked, [stack])


def map_func_over_inputs_jax(name, f, vars_for_processing, ninputs):
"""
This takes a casadi function f and returns a new casadi function that maps f over
the provided number of inputs. Some functions (e.g. jacobian action) require an additional
Expand Down Expand Up @@ -1574,7 +1643,7 @@ def jacp(*args, **kwargs):
jac_action = None

report(f"Converting {name} to python")
func = pybamm.EvaluatorPython(symbol)
func = pybamm.EvaluatorPython(symbol, is_event="event" in name)

else:
t_casadi = vars_for_processing["t_casadi"]
Expand Down Expand Up @@ -1703,12 +1772,14 @@ def jacp(*args, **kwargs):
)

if ninputs > 1:
func = map_func_over_inputs(name, func, vars_for_processing, ninputs)
jac = map_func_over_inputs(name + "_jac", jac, vars_for_processing, ninputs)
jacp = map_func_over_inputs(
func = map_func_over_inputs_casadi(name, func, vars_for_processing, ninputs)
jac = map_func_over_inputs_casadi(
name + "_jac", jac, vars_for_processing, ninputs
)
jacp = map_func_over_inputs_casadi(
name + "_jacp", jacp, vars_for_processing, ninputs
)
jac_action = map_func_over_inputs(
jac_action = map_func_over_inputs_casadi(
name + "_jac_action", jac_action, vars_for_processing, ninputs
)

Expand Down
73 changes: 48 additions & 25 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import casadi
import pybamm
import numpy as np
import numbers
import scipy.sparse as sparse

import importlib
Expand Down Expand Up @@ -154,25 +153,43 @@ def _check_atol_type(self, atol, size):

return atol

def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
base_set_up_return = super().set_up(model, inputs, t_eval, ics_only)
def set_up(self, model, input_list=None, t_eval=None, ics_only=False):
base_set_up_return = super().set_up(model, input_list, t_eval, ics_only)

inputs_list = input_list or [{}]
nparams = sum(
len(np.array(v).reshape(-1, 1)) for _, v in inputs_list[0].items()
)
ninputs = len(inputs_list)
nstates = model.y0.shape[0]
print("nstates", nstates)
print("nparams", nparams)
print("ninputs", ninputs)

inputs_dict = inputs or {}
# stack inputs
if inputs_dict:
arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()]
if inputs_list and len(inputs_list) > 0 and len(inputs_list[0]) > 0:
arrays_to_stack = [
np.array(x).reshape(-1, 1)
for inputs in inputs_list
for x in inputs.values()
]
inputs_sizes = [len(array) for array in arrays_to_stack]
print(arrays_to_stack)
inputs = np.vstack(arrays_to_stack)
else:
inputs_sizes = []
inputs = np.array([[]])

def inputs_to_dict(inputs):
index = 0
for n, key in zip(inputs_sizes, inputs_dict.keys()):
inputs_dict[key] = inputs[index : (index + n)]
index += n
return inputs_dict
new_inputs_list = []
for inputs in inputs_list:
inputs_dict = {}
for n, key in zip(inputs_sizes, inputs.keys()):
inputs_dict[key] = inputs[index : (index + n)]
index += n
new_inputs_list.append(inputs_dict)
return inputs_list

y0 = model.y0
if isinstance(y0, casadi.DM):
Expand Down Expand Up @@ -212,9 +229,12 @@ def inputs_to_dict(inputs):
if model.convert_to_format == "casadi":
# TODO: do we need densify here?
rhs_algebraic = model.rhs_algebraic_eval
print("rhs_algebraic_eval", rhs_algebraic(0, y0, inputs))
else:

def resfn(t, y, inputs, ydot):
print("resfn", y, inputs)
print(model.rhs_algebraic_eval(t, y, inputs_to_dict(inputs)))
return (
model.rhs_algebraic_eval(t, y, inputs_to_dict(inputs)).flatten()
- mass_matrix @ ydot
Expand All @@ -226,15 +246,10 @@ def resfn(t, y, inputs, ydot):
# need to provide jacobian_rhs_alg - cj * mass_matrix
if model.convert_to_format == "casadi":
t_casadi = casadi.MX.sym("t")
y_casadi = casadi.MX.sym("y", model.len_rhs_and_alg)
print("nstates", nstates)
y_casadi = casadi.MX.sym("y", nstates)
cj_casadi = casadi.MX.sym("cj")
p_casadi = {}
for name, value in inputs_dict.items():
if isinstance(value, numbers.Number):
p_casadi[name] = casadi.MX.sym(name)
else:
p_casadi[name] = casadi.MX.sym(name, value.shape[0])
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
p_casadi_stacked = casadi.MX.sym("p_stacked", nparams * ninputs)

jac_times_cjmass = casadi.Function(
"jac_times_cjmass",
Expand All @@ -244,6 +259,7 @@ def resfn(t, y, inputs, ydot):
- cj_casadi * mass_matrix
],
)
print("jac_times_cjmass", jac_times_cjmass(0, y0, inputs, 1))

jac_times_cjmass_sparsity = jac_times_cjmass.sparsity_out(0)
jac_bw_lower = jac_times_cjmass_sparsity.bw_lower()
Expand All @@ -256,7 +272,7 @@ def resfn(t, y, inputs, ydot):
jac_times_cjmass_sparsity.row(), dtype=np.int64
)

v_casadi = casadi.MX.sym("v", model.len_rhs_and_alg)
v_casadi = casadi.MX.sym("v", nstates)

jac_rhs_algebraic_action = model.jac_rhs_algebraic_action_eval

Expand Down Expand Up @@ -295,19 +311,23 @@ def resfn(t, y, inputs, ydot):

else:
t0 = 0 if t_eval is None else t_eval[0]
jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_dict)
jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_list)
if sparse.issparse(jac_y0_t0):

def jacfn(t, y, inputs, cj):
print("calling jacfn", y, inputs, inputs_to_dict(inputs))
print(model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs)))
j = (
model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs))
- cj * mass_matrix
)
print("jacfn", j)
return j

else:

def jacfn(t, y, inputs, cj):
print("calling jacfn", y, inputs)
jac_eval = (
model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs))
- cj * mass_matrix
Expand Down Expand Up @@ -355,6 +375,7 @@ def get_jac_col_ptrs(self):
)
],
)
print("rootfn", rootfn(0, y0, inputs))
else:

def rootfn(t, y, inputs):
Expand All @@ -369,9 +390,10 @@ def rootfn(t, y, inputs):
if model.convert_to_format == "casadi":
rhs_ids = np.ones(model.rhs_eval(0, y0, inputs).shape[0])
else:
rhs_ids = np.ones(model.rhs_eval(0, y0, inputs_dict).shape[0])
rhs_ids = np.ones(model.rhs_eval(0, y0, inputs_list).shape[0])
alg_ids = np.zeros(len(y0) - len(rhs_ids))
ids = np.concatenate((rhs_ids, alg_ids))
print("ids", ids)

number_of_sensitivity_parameters = 0
if model.jacp_rhs_algebraic_eval is not None:
Expand Down Expand Up @@ -526,15 +548,16 @@ def _integrate(self, model, t_eval, inputs_list=None):
"""
inputs_list = inputs_list or []
# stack inputs
if inputs_list:
if inputs_list and len(inputs_list) > 0 and len(inputs_list[0]) > 0:
arrays_to_stack = [
np.array(x).reshape(-1, 1)
for inputs in inputs_list.values()
for x in inputs
for inputs in inputs_list
for x in inputs.values()
]
inputs = np.vstack(arrays_to_stack)
else:
inputs = np.array([[]])
print("inputs", inputs)

# do this here cause y0 is set after set_up (calc consistent conditions)
y0 = model.y0
Expand Down Expand Up @@ -593,7 +616,7 @@ def _integrate(self, model, t_eval, inputs_list=None):
self._setup["jac_class"].nnz,
self._setup["rootfn"],
self._setup["num_of_events"],
self._setup["use_jac"],
1,
self._setup["ids"],
atol,
rtol,
Expand Down
4 changes: 2 additions & 2 deletions pybamm/solvers/scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _integrate(self, model, t_eval, inputs_list=None):
The model whose solution to calculate.
t_eval : :class:`numpy.array`, size (k,)
The times at which to compute the solution
inputs_dict : dict, optional
inputs_list : list of dict, optional
Any input parameters to pass to the model when solving
Returns
Expand All @@ -67,7 +67,7 @@ def _integrate(self, model, t_eval, inputs_list=None):
"""
# Save inputs dictionary, and if necessary convert inputs to a casadi vector
inputs_list = inputs_list or {}
inputs_list = inputs_list or [{}]
if model.convert_to_format == "casadi":
inputs = casadi.vertcat(
*[x for inputs in inputs_list for x in inputs.values()]
Expand Down
Loading

0 comments on commit f8b3209

Please sign in to comment.