diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index ac549d79f3..c5668391e1 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -430,11 +430,16 @@ class EvaluatorPython: symbol : :class:`pybamm.Symbol` The symbol to convert to python code - + is_event: bool + Indicates this symbol is an event expression + is_matrix: bool + Indicates the evaluation of this symbol results in a matrix (otherwise result is a vector) """ - def __init__(self, symbol: pybamm.Symbol): + def __init__( + self, symbol: pybamm.Symbol, is_event: bool = False, is_matrix: bool = False + ): constants, python_str = pybamm.to_python(symbol, debug=False) # extract constants in generated function @@ -472,6 +477,8 @@ def __init__(self, symbol: pybamm.Symbol): self._python_str = python_str self._result_var = result_var self._symbol = symbol + self._is_event = is_event + self._is_matrix = is_matrix # compile and run the generated python code, compiled_function = compile(python_str, result_var, "exec") @@ -486,21 +493,41 @@ def __call__(self, t=None, y=None, inputs=None): y = y.reshape(-1, 1) if isinstance(inputs, list): - result = self._evaluate(self._constants, t, y, inputs[0]) - if len(inputs) > 1: - if isinstance(result, numbers.Number): - result = np.array([result]) - ny = result.shape[0] - ni = len(inputs) - results = np.zeros((ni * ny, 1)) - results[:ny] = result - i = ny - for input in inputs[1:]: - results[i : i + ny] += self._evaluate( - self._constants, t, y[i : i + ny], input - ) - i += ny - result = results + if len(inputs) == 1: + # nothing to do for a single input + result = self._evaluate(self._constants, t, y, inputs[0]) + elif self._is_event: + # if an event do a soft max on the results to combine events from multiple + # inputs + results = np.array( + [self._evaluate(self._constants, t, y, input) for input in inputs] + ) + margin = 1e-4 + alpha = np.log(len(inputs)) / margin + result = scipy.special.logsumexp(alpha * results) / alpha + elif self._is_matrix: + # if a matrix output, concatenate the results in a block diagonal matrix + results = [ + self._evaluate(self._constants, t, y, input) for input in inputs + ] + result = scipy.sparse.block_diag(*results, format="csr") + else: + # otherwise concatenate the results in a column vector + result = self._evaluate(self._constants, t, y, inputs[0]) + if len(inputs) > 1: + if isinstance(result, numbers.Number): + result = np.array([result]) + ny = result.shape[0] + ni = len(inputs) + results = np.zeros((ni * ny, 1)) + results[:ny] = result + i = ny + for input in inputs[1:]: + results[i : i + ny] += self._evaluate( + self._constants, t, y[i : i + ny], input + ) + i += ny + result = results else: result = self._evaluate(self._constants, t, y, inputs) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 434e1b91e7..65fa2fd9fc 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1390,7 +1390,76 @@ def _set_up_model_inputs(model, inputs): return ordered_inputs -def map_func_over_inputs(name, f, vars_for_processing, ninputs): +def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs): + """ + This takes a casadi function f and returns a new casadi function that maps f over + the provided number of inputs. Some functions (e.g. jacobian action) require an additional + vector input v, which is why add_v is provided. + + Parameters + ---------- + name: str + name of the new function. This must end in the string "_action" for jacobian action functions, + "_jac" for jacobian functions, or "_jacp" for jacp functions. + f: casadi.Function + function to map + vars_for_processing: dict + dictionary of variables for processing + ninputs: int + number of inputs to map over + """ + if f is None: + return None + + is_event = "event" in name + add_v = name.endswith("_action") + matrix_output = name.endswith("_jac") or name.endswith("_jacp") + + nstates = vars_for_processing["y_and_S"].shape[0] + nparams = vars_for_processing["p_casadi_stacked"].shape[0] + + parallelisation = "thread" + y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs) + p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs) + v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs) + + y_and_S_2d = y_and_S_inputs_stacked.reshape((nstates, ninputs)) + p_casadi_2d = p_casadi_inputs_stacked.reshape((nparams, ninputs)) + v_2d = v_inputs_stacked.reshape((nstates, ninputs)) + + t_casadi = vars_for_processing["t_casadi"] + + if add_v: + inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d, v_2d] + inputs_stacked = [ + t_casadi, + y_and_S_inputs_stacked, + p_casadi_inputs_stacked, + v_inputs_stacked, + ] + else: + inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d] + inputs_stacked = [t_casadi, y_and_S_inputs_stacked, p_casadi_inputs_stacked] + + mapped_f = f.map(ninputs, parallelisation)(*inputs_2d) + if matrix_output: + # for matrix output we need to stack the outputs in a block diagonal matrix + splits = [i * nstates for i in range(ninputs + 1)] + split = casadi.horzsplit(mapped_f, splits) + stack = casadi.diagcat(*split) + elif is_event: + # Events need to return a scalar, so we combine the vector output + # of the mapped function into a scalar output by calculating a smooth max of the vector output. + stack = casadi.logsumexp(casadi.transpose(mapped_f), 1e-4) + else: + # for vector outputs we need to stack them vertically in a single column vector + splits = [i for i in range(ninputs + 1)] + split = casadi.horzsplit(mapped_f, splits) + stack = casadi.vertcat(*split) + return casadi.Function(name, inputs_stacked, [stack]) + + +def map_func_over_inputs_jax(name, f, vars_for_processing, ninputs): """ This takes a casadi function f and returns a new casadi function that maps f over the provided number of inputs. Some functions (e.g. jacobian action) require an additional @@ -1574,7 +1643,7 @@ def jacp(*args, **kwargs): jac_action = None report(f"Converting {name} to python") - func = pybamm.EvaluatorPython(symbol) + func = pybamm.EvaluatorPython(symbol, is_event="event" in name) else: t_casadi = vars_for_processing["t_casadi"] @@ -1703,12 +1772,14 @@ def jacp(*args, **kwargs): ) if ninputs > 1: - func = map_func_over_inputs(name, func, vars_for_processing, ninputs) - jac = map_func_over_inputs(name + "_jac", jac, vars_for_processing, ninputs) - jacp = map_func_over_inputs( + func = map_func_over_inputs_casadi(name, func, vars_for_processing, ninputs) + jac = map_func_over_inputs_casadi( + name + "_jac", jac, vars_for_processing, ninputs + ) + jacp = map_func_over_inputs_casadi( name + "_jacp", jacp, vars_for_processing, ninputs ) - jac_action = map_func_over_inputs( + jac_action = map_func_over_inputs_casadi( name + "_jac_action", jac_action, vars_for_processing, ninputs ) diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index ed152a0140..603bc9689f 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -5,7 +5,6 @@ import casadi import pybamm import numpy as np -import numbers import scipy.sparse as sparse import importlib @@ -154,14 +153,28 @@ def _check_atol_type(self, atol, size): return atol - def set_up(self, model, inputs=None, t_eval=None, ics_only=False): - base_set_up_return = super().set_up(model, inputs, t_eval, ics_only) + def set_up(self, model, input_list=None, t_eval=None, ics_only=False): + base_set_up_return = super().set_up(model, input_list, t_eval, ics_only) + + inputs_list = input_list or [{}] + nparams = sum( + len(np.array(v).reshape(-1, 1)) for _, v in inputs_list[0].items() + ) + ninputs = len(inputs_list) + nstates = model.y0.shape[0] + print("nstates", nstates) + print("nparams", nparams) + print("ninputs", ninputs) - inputs_dict = inputs or {} # stack inputs - if inputs_dict: - arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] + if inputs_list and len(inputs_list) > 0 and len(inputs_list[0]) > 0: + arrays_to_stack = [ + np.array(x).reshape(-1, 1) + for inputs in inputs_list + for x in inputs.values() + ] inputs_sizes = [len(array) for array in arrays_to_stack] + print(arrays_to_stack) inputs = np.vstack(arrays_to_stack) else: inputs_sizes = [] @@ -169,10 +182,14 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): def inputs_to_dict(inputs): index = 0 - for n, key in zip(inputs_sizes, inputs_dict.keys()): - inputs_dict[key] = inputs[index : (index + n)] - index += n - return inputs_dict + new_inputs_list = [] + for inputs in inputs_list: + inputs_dict = {} + for n, key in zip(inputs_sizes, inputs.keys()): + inputs_dict[key] = inputs[index : (index + n)] + index += n + new_inputs_list.append(inputs_dict) + return inputs_list y0 = model.y0 if isinstance(y0, casadi.DM): @@ -212,9 +229,12 @@ def inputs_to_dict(inputs): if model.convert_to_format == "casadi": # TODO: do we need densify here? rhs_algebraic = model.rhs_algebraic_eval + print("rhs_algebraic_eval", rhs_algebraic(0, y0, inputs)) else: def resfn(t, y, inputs, ydot): + print("resfn", y, inputs) + print(model.rhs_algebraic_eval(t, y, inputs_to_dict(inputs))) return ( model.rhs_algebraic_eval(t, y, inputs_to_dict(inputs)).flatten() - mass_matrix @ ydot @@ -226,15 +246,10 @@ def resfn(t, y, inputs, ydot): # need to provide jacobian_rhs_alg - cj * mass_matrix if model.convert_to_format == "casadi": t_casadi = casadi.MX.sym("t") - y_casadi = casadi.MX.sym("y", model.len_rhs_and_alg) + print("nstates", nstates) + y_casadi = casadi.MX.sym("y", nstates) cj_casadi = casadi.MX.sym("cj") - p_casadi = {} - for name, value in inputs_dict.items(): - if isinstance(value, numbers.Number): - p_casadi[name] = casadi.MX.sym(name) - else: - p_casadi[name] = casadi.MX.sym(name, value.shape[0]) - p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()]) + p_casadi_stacked = casadi.MX.sym("p_stacked", nparams * ninputs) jac_times_cjmass = casadi.Function( "jac_times_cjmass", @@ -244,6 +259,7 @@ def resfn(t, y, inputs, ydot): - cj_casadi * mass_matrix ], ) + print("jac_times_cjmass", jac_times_cjmass(0, y0, inputs, 1)) jac_times_cjmass_sparsity = jac_times_cjmass.sparsity_out(0) jac_bw_lower = jac_times_cjmass_sparsity.bw_lower() @@ -256,7 +272,7 @@ def resfn(t, y, inputs, ydot): jac_times_cjmass_sparsity.row(), dtype=np.int64 ) - v_casadi = casadi.MX.sym("v", model.len_rhs_and_alg) + v_casadi = casadi.MX.sym("v", nstates) jac_rhs_algebraic_action = model.jac_rhs_algebraic_action_eval @@ -295,19 +311,23 @@ def resfn(t, y, inputs, ydot): else: t0 = 0 if t_eval is None else t_eval[0] - jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_dict) + jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_list) if sparse.issparse(jac_y0_t0): def jacfn(t, y, inputs, cj): + print("calling jacfn", y, inputs, inputs_to_dict(inputs)) + print(model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs))) j = ( model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs)) - cj * mass_matrix ) + print("jacfn", j) return j else: def jacfn(t, y, inputs, cj): + print("calling jacfn", y, inputs) jac_eval = ( model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs)) - cj * mass_matrix @@ -355,6 +375,7 @@ def get_jac_col_ptrs(self): ) ], ) + print("rootfn", rootfn(0, y0, inputs)) else: def rootfn(t, y, inputs): @@ -369,9 +390,10 @@ def rootfn(t, y, inputs): if model.convert_to_format == "casadi": rhs_ids = np.ones(model.rhs_eval(0, y0, inputs).shape[0]) else: - rhs_ids = np.ones(model.rhs_eval(0, y0, inputs_dict).shape[0]) + rhs_ids = np.ones(model.rhs_eval(0, y0, inputs_list).shape[0]) alg_ids = np.zeros(len(y0) - len(rhs_ids)) ids = np.concatenate((rhs_ids, alg_ids)) + print("ids", ids) number_of_sensitivity_parameters = 0 if model.jacp_rhs_algebraic_eval is not None: @@ -526,15 +548,16 @@ def _integrate(self, model, t_eval, inputs_list=None): """ inputs_list = inputs_list or [] # stack inputs - if inputs_list: + if inputs_list and len(inputs_list) > 0 and len(inputs_list[0]) > 0: arrays_to_stack = [ np.array(x).reshape(-1, 1) - for inputs in inputs_list.values() - for x in inputs + for inputs in inputs_list + for x in inputs.values() ] inputs = np.vstack(arrays_to_stack) else: inputs = np.array([[]]) + print("inputs", inputs) # do this here cause y0 is set after set_up (calc consistent conditions) y0 = model.y0 @@ -593,7 +616,7 @@ def _integrate(self, model, t_eval, inputs_list=None): self._setup["jac_class"].nnz, self._setup["rootfn"], self._setup["num_of_events"], - self._setup["use_jac"], + 1, self._setup["ids"], atol, rtol, diff --git a/pybamm/solvers/scipy_solver.py b/pybamm/solvers/scipy_solver.py index 5685967ffa..8a72f716ae 100644 --- a/pybamm/solvers/scipy_solver.py +++ b/pybamm/solvers/scipy_solver.py @@ -56,7 +56,7 @@ def _integrate(self, model, t_eval, inputs_list=None): The model whose solution to calculate. t_eval : :class:`numpy.array`, size (k,) The times at which to compute the solution - inputs_dict : dict, optional + inputs_list : list of dict, optional Any input parameters to pass to the model when solving Returns @@ -67,7 +67,7 @@ def _integrate(self, model, t_eval, inputs_list=None): """ # Save inputs dictionary, and if necessary convert inputs to a casadi vector - inputs_list = inputs_list or {} + inputs_list = inputs_list or [{}] if model.convert_to_format == "casadi": inputs = casadi.vertcat( *[x for inputs in inputs_list for x in inputs.values()] diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 1697623486..2eacc97462 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -18,7 +18,7 @@ def test_ida_roberts_klu(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax"]: + for form in ["casadi", "jax"]: if form == "jax" and not pybamm.have_jax(): continue if form == "casadi":