From 9ce091705b4b47ec58a06b9c1e6028a352abc07a Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 7 Jun 2024 16:51:47 +0000 Subject: [PATCH] casadi and scipy solver working with multiple inputs #4087 --- pybamm/solvers/base_solver.py | 170 +++++++++++------- pybamm/solvers/casadi_solver.py | 91 +++++----- pybamm/solvers/solution.py | 108 +++++++++++ tests/unit/test_solvers/test_casadi_solver.py | 22 +++ tests/unit/test_solvers/test_scipy_solver.py | 18 +- 5 files changed, 285 insertions(+), 124 deletions(-) diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index f42633cd1f..81e74f89c9 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -1,6 +1,5 @@ import copy import itertools -from scipy.sparse import block_diag import numbers import sys import warnings @@ -184,8 +183,15 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # Process rhs, algebraic, residual and event expressions # and wrap in callables + is_casadi_solver = isinstance( + self.root_method, pybamm.CasadiAlgebraicSolver + ) or isinstance(self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver)) + if is_casadi_solver and len(model.rhs) > 0: + rhs = model.mass_matrix_inv @ model.concatenated_rhs + else: + rhs = model.concatenated_rhs rhs, jac_rhs, jacp_rhs, jac_rhs_action = process( - model.concatenated_rhs, "RHS", vars_for_processing, ninputs=len(inputs) + rhs, "RHS", vars_for_processing, ninputs=len(inputs) ) algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action = process( @@ -245,23 +251,11 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # Save CasADi functions for the CasADi solver # Save CasADi functions for solvers that use CasADi # Note: when we pass to casadi the ode part of the problem must be in - if isinstance(self.root_method, pybamm.CasadiAlgebraicSolver) or isinstance( - self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver) - ): + if is_casadi_solver: # can use DAE solver to solve model with algebraic equations only + # todo: do I need this check? if len(model.rhs) > 0: - t_casadi = vars_for_processing["t_casadi"] - y_and_S = vars_for_processing["y_and_S"] - p_casadi_stacked = vars_for_processing["p_casadi_stacked"] - mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries) - explicit_rhs = mass_matrix_inv @ rhs( - t_casadi, y_and_S, p_casadi_stacked - ) - model.casadi_rhs = casadi.Function( - "rhs", [t_casadi, y_and_S, p_casadi_stacked], [explicit_rhs] - ) - if len(inputs) > 1: - model.casadi_rhs = model.casadi_rhs.map(len(inputs), "openmp") + model.casadi_rhs = rhs model.casadi_switch_events = casadi_switch_events model.casadi_algebraic = algebraic model.casadi_sensitivities = jacp_rhs_algebraic @@ -459,19 +453,6 @@ def _set_up_model_sensitivities_inplace( model.bounds[0][: model.len_rhs_and_alg], model.bounds[1][: model.len_rhs_and_alg], ) - if ( - model.mass_matrix is not None - and model.mass_matrix.shape[0] > model.len_rhs_and_alg - ): - if model.mass_matrix_inv is not None: - model.mass_matrix_inv = pybamm.Matrix( - model.mass_matrix_inv.entries[: model.len_rhs, : model.len_rhs] - ) - model.mass_matrix = pybamm.Matrix( - model.mass_matrix.entries[ - : model.len_rhs_and_alg, : model.len_rhs_and_alg - ] - ) # now we can extend them by the number of sensitivity parameters # if needed @@ -485,22 +466,6 @@ def _set_up_model_sensitivities_inplace( np.repeat(model.bounds[0], n_inputs + 1), np.repeat(model.bounds[1], n_inputs + 1), ) - if ( - model.mass_matrix is not None - and model.mass_matrix.shape[0] == model.len_rhs_and_alg - ): - if model.mass_matrix_inv is not None: - model.mass_matrix_inv = pybamm.Matrix( - block_diag( - [model.mass_matrix_inv.entries] * (n_inputs + 1), - format="csr", - ) - ) - model.mass_matrix = pybamm.Matrix( - block_diag( - [model.mass_matrix.entries] * (n_inputs + 1), format="csr" - ) - ) def _set_up_events(self, model, t_eval, inputs, vars_for_processing): # Check for heaviside and modulo functions in rhs and algebraic and add @@ -632,7 +597,7 @@ def _set_initial_conditions(self, model, time, inputs_list, update_rhs): """ - y0_total_size = ( + y0_total_size = len(inputs_list) * ( model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens ) y_zero = np.zeros((y0_total_size, 1)) @@ -1057,14 +1022,19 @@ def _check_events_with_initial_conditions(t_eval, model, inputs_list): event_eval = event(t=t_eval[0], y=model.y0, inputs=inputs_list) events_eval[idx] = event_eval - events_eval = np.array(events_eval) + if model.convert_to_format == "casadi": + events_eval = casadi.vertcat(*events_eval).toarray().flatten() + else: + events_eval = np.vstack(events_eval).flatten() if any(events_eval < 0): # find the events that were triggered by initial conditions termination_events = [ x for x in model.events if x.event_type == pybamm.EventType.TERMINATION ] idxs = np.where(events_eval < 0)[0] - event_names = [termination_events[idx].name for idx in idxs] + event_names = [ + termination_events[idx / len(inputs_list)].name for idx in idxs + ] raise pybamm.SolverError( f"Events {event_names} are non-positive at initial conditions" ) @@ -1311,15 +1281,14 @@ def get_termination_reason(solution, events): ) termination_event = min(final_event_values, key=final_event_values.get) - # Check that it's actually an event - if final_event_values[termination_event] > 0.1: # pragma: no cover - # Hard to test this - raise pybamm.SolverError( - "Could not determine which event was triggered " - "(possibly due to NaNs)" - ) # Add the event to the solution object - solution.termination = f"event: {termination_event}" + # Check that it's actually an event for this solution (might be from another input set) + if final_event_values[termination_event] > 0.1: + solution.termination = ( + f"event: {termination_event} from another input set" + ) + else: + solution.termination = f"event: {termination_event}" # Update t, y and inputs to include event time and state # Note: if the final entry of t is equal to the event time we skip # this (having duplicate entries causes an error later in ProcessedVariable) @@ -1333,8 +1302,8 @@ def get_termination_reason(solution, events): solution.y_event, solution.termination, ) - event_sol.solve_time = 0 - event_sol.integration_time = 0 + event_sol.solve_time = 0.0 + event_sol.integration_time = 0.0 solution = solution + event_sol pybamm.logger.debug("Finish post-processing events") @@ -1421,6 +1390,72 @@ def _set_up_model_inputs(model, inputs): return ordered_inputs +def map_func_over_inputs(name, f, vars_for_processing, ninputs): + """ + This takes a casadi function f and returns a new casadi function that maps f over + the provided number of inputs. Some functions (e.g. jacobian action) require an additional + vector input v, which is why add_v is provided. + + Parameters + ---------- + name: str + name of the new function. This must end in the string "_action" for jacobian action functions, + "_jac" for jacobian functions, or "_jacp" for jacp functions. + f: casadi.Function + function to map + vars_for_processing: dict + dictionary of variables for processing + ninputs: int + number of inputs to map over + add_v: bool + whether to add a vector v to the inputs + """ + if f is None: + return None + + add_v = name.endswith("_action") + matrix_output = name.endswith("_jac") or name.endswith("_jacp") + + nstates = vars_for_processing["y_and_S"].shape[0] + nparams = vars_for_processing["p_casadi_stacked"].shape[0] + + parallelisation = "thread" + y_and_S_inputs_stacked = casadi.MX.sym("y_and_S_stacked", nstates * ninputs) + p_casadi_inputs_stacked = casadi.MX.sym("p_stacked", nparams * ninputs) + v_inputs_stacked = casadi.MX.sym("v_stacked", nstates * ninputs) + + y_and_S_2d = y_and_S_inputs_stacked.reshape((nstates, ninputs)) + p_casadi_2d = p_casadi_inputs_stacked.reshape((nparams, ninputs)) + v_2d = v_inputs_stacked.reshape((nstates, ninputs)) + + t_casadi = vars_for_processing["t_casadi"] + + if add_v: + inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d, v_2d] + inputs_stacked = [ + t_casadi, + y_and_S_inputs_stacked, + p_casadi_inputs_stacked, + v_inputs_stacked, + ] + else: + inputs_2d = [t_casadi, y_and_S_2d, p_casadi_2d] + inputs_stacked = [t_casadi, y_and_S_inputs_stacked, p_casadi_inputs_stacked] + + mapped_f = f.map(ninputs, parallelisation)(*inputs_2d) + if matrix_output: + # for matrix output we need to stack the outputs in a block diagonal matrix + splits = [i * nstates for i in range(ninputs + 1)] + split = casadi.horzsplit(mapped_f, splits) + stack = casadi.diagcat(*split) + else: + # for vector outputs we need to stack them vertically in a single column vector + splits = [i for i in range(ninputs + 1)] + split = casadi.horzsplit(mapped_f, splits) + stack = casadi.vertcat(*split) + return casadi.Function(name, inputs_stacked, [stack]) + + def process( symbol, name, @@ -1466,6 +1501,7 @@ def process( :class:`casadi.Function` evaluator for product of the Jacobian with a vector $v$, i.e. $\\frac{\\partial f}{\\partial y} * v$ + """ def report(string): @@ -1664,13 +1700,13 @@ def jacp(*args, **kwargs): ) if ninputs > 1: - parallelisation = "openmp" - func = func.map(ninputs, parallelisation) - if jac is not None: - jac = jac.map(ninputs, parallelisation) - if jacp is not None: - jacp = jacp.map(ninputs, parallelisation) - if jac_action is not None: - jac_action = jac_action.map(ninputs, parallelisation) + func = map_func_over_inputs(name, func, vars_for_processing, ninputs) + jac = map_func_over_inputs(name + "_jac", jac, vars_for_processing, ninputs) + jacp = map_func_over_inputs( + name + "_jacp", jacp, vars_for_processing, ninputs + ) + jac_action = map_func_over_inputs( + name + "_jac_action", jac_action, vars_for_processing, ninputs + ) return func, jac, jacp, jac_action diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index cf86c20821..5c2e720f65 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -166,16 +166,16 @@ def _integrate(self, model, t_eval, inputs_list=None): self.create_integrator( model, inputs, t_eval, use_event_switch=use_event_switch ) - solutions = self._run_integrator( + solution = self._run_integrator( model, model.y0, inputs_list, inputs, t_eval ) # Check if the sign of an event changes, if so find an accurate # termination point and exit # Note: this is only done for the first solution, is this correct? - solutions[0] = self._solve_for_event(solutions[0]) - for s in solutions: - s.check_ys_are_not_too_large() - return solutions + solution = self._solve_for_event(solution, inputs_list) + solution.check_ys_are_not_too_large() + + return solution.split(model.len_rhs, model.len_alg, inputs_list) elif self.mode in ["safe", "safe without grid"]: y0 = model.y0 # Step-and-check @@ -190,19 +190,18 @@ def _integrate(self, model, t_eval, inputs_list=None): # to avoid having to create several times self.create_integrator(model, inputs) # Initialize solution - solutions = pybamm.Solution.from_concatenated_state( + solution = pybamm.Solution( np.array([t]), y0, model, - inputs_list, + inputs_list[0], sensitivities=False, ) - for s in solutions: - s.integration_time = 0 - s.solve_time = 0 + solution.integration_time = 0.0 + solution.solve_time = 0.0 use_grid = False else: - solutions = None + solution = None use_grid = True # Try to integrate in global steps of size dt_max. Note: dt_max must @@ -245,7 +244,7 @@ def _integrate(self, model, t_eval, inputs_list=None): "Running integrator for " f"{t_window[0]:.2f} < t < {t_window[-1]:.2f}" ) - current_step_sols = self._run_integrator( + current_step_sol = self._run_integrator( model, y0, inputs_list, @@ -293,19 +292,15 @@ def _integrate(self, model, t_eval, inputs_list=None): # Check if the sign of an event changes, if so find an accurate # termination point and exit # Note: this is only done for the first solution, is this correct? - current_step_sols[0] = self._solve_for_event(current_step_sols[0]) - if solutions is None: - for s in current_step_sols: - s.solve_time = np.nan - solutions = current_step_sols + current_step_sol = self._solve_for_event(current_step_sol, inputs_list) + if solution is None: + current_step_sol.solve_time = np.nan + solution = current_step_sol else: - for i, current_step_sol in enumerate(current_step_sols): - # assign temporary solve time - current_step_sol.solve_time = np.nan - # append solution from the current step to solution - solutions[i] = solutions[i] + current_step_sol + current_step_sol.solve_time = np.nan + solution = solution + current_step_sol - if current_step_sols[0].termination == "event": + if current_step_sol.termination == "event": break else: # update time as time @@ -313,16 +308,15 @@ def _integrate(self, model, t_eval, inputs_list=None): t = t_window[-1] # update y0 as initial_values # from which to start the new casadi integrator - y0 = casadi.vertcat(*[s.all_ys[-1][:, -1] for s in solutions]) + y0 = solution.all_ys[-1][:, -1] # now we extract sensitivities from the solution - for s in solutions: - if bool(model.calculate_sensitivities): - s.sensitivities = True - s.check_ys_are_not_too_large() - return solutions + if bool(model.calculate_sensitivities): + solution.sensitivities = True + solution.check_ys_are_not_too_large() + return solution.split(model.len_rhs, model.len_alg, inputs_list) - def _solve_for_event(self, coarse_solution): + def _solve_for_event(self, coarse_solution, inputs_list): """ Check if the sign of an event changes, if so find an accurate termination point and exit @@ -334,8 +328,7 @@ def _solve_for_event(self, coarse_solution): pybamm.logger.debug("Solving for events") model = coarse_solution.all_models[-1] - inputs_dict = coarse_solution.all_inputs[-1] - inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) + inputs = casadi.vertcat(*[x for inputs in inputs_list for x in inputs.values()]) def find_t_event(sol, typ): # Check most recent y to see if any events have been crossed @@ -344,7 +337,7 @@ def find_t_event(sol, typ): crossed_events = np.sign( np.concatenate( [ - event(sol.t[-1], y_last, inputs) + casadi.mmin(event(sol.t[-1], y_last, inputs)) for event in model.terminate_events_eval ] ) @@ -377,7 +370,9 @@ def f(idx, f_eval=f_eval, event=event): # We take away 1e-5 to deal with the case where the event sits # exactly on zero, as can happen when the event switch is used # (fast with events mode) - f_eval[idx] = event(sol.t[idx], sol.y[:, idx], inputs) - 1e-5 + f_eval[idx] = ( + casadi.mmin(event(sol.t[idx], sol.y[:, idx], inputs)) - 1e-5 + ) return f_eval[idx] def integer_bisect(): @@ -455,10 +450,10 @@ def integer_bisect(): use_grid = True y0 = coarse_solution.y[:, event_idx_lower] - [dense_step_sol] = self._run_integrator( + dense_step_sol = self._run_integrator( model, y0, - [inputs_dict], + inputs_list, inputs, t_window_event_dense, use_grid=use_grid, @@ -485,7 +480,7 @@ def integer_bisect(): t_sol, y_sol, model, - inputs_dict, + inputs_list[0], np.array([t_event]), y_event[:, np.newaxis], "event", @@ -656,8 +651,8 @@ def _run_integrator( else: integrator = self.integrators[model]["no grid"] - len_rhs = model.concatenated_rhs.size - len_alg = model.concatenated_algebraic.size + len_rhs = model.concatenated_rhs.size * len(inputs_list) + len_alg = model.concatenated_algebraic.size * len(inputs_list) # Check y0 to see if it includes sensitivities if explicit_sensitivities: @@ -702,17 +697,16 @@ def _run_integrator( y_sol = casadi.vertcat(x_sol, z_sol) else: y_sol = x_sol - sols = pybamm.Solution.from_concatenated_state( + sol = pybamm.Solution( t_eval, y_sol, model, - inputs_list, + inputs_list[0], sensitivities=extract_sensitivities_in_solution, check_solution=False, ) - for s in sols: - s.integration_time = integration_time - return sols + sol.integration_time = integration_time + return sol else: # Repeated calls to the integrator x = y0_diff @@ -743,14 +737,13 @@ def _run_integrator( else: y_sol = casadi.vertcat(y_diff, y_alg) - sols = pybamm.Solution.from_concatenated_state( + sol = pybamm.Solution( t_eval, y_sol, model, - inputs_list, + inputs_list[0], sensitivities=extract_sensitivities_in_solution, check_solution=False, ) - for s in sols: - s.integration_time = integration_time - return sols + sol.integration_time = integration_time + return sol diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 8b253dfdb2..fe1bbc6e49 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -138,6 +138,95 @@ def __init__( # Solution now uses CasADi pybamm.citations.register("Andersson2019") + def split(self, rhs_len, alg_len, inputs_list): + """ + split up the concatenated solution into a list of solutions for each input + the state vector is assumed to have the form: + [rhs0p0, rhs1p0, ..., rhs0p1, rhs1p1, ..., alg0p0, alg1p0, ..., alg0p1, alg1p1, ...] + """ + if not isinstance(self, Solution): + raise TypeError("split can only be called on a Solution object") + + ninputs = len(inputs_list) + if ninputs == 1: + return [self] + + if isinstance(self.all_ys[0], (casadi.DM, casadi.MX)): + all_ys_split = [ + [ + casadi.vertcat( + self.all_ys[i][(p * rhs_len) : (p * rhs_len + rhs_len), :], + self.all_ys[i][ + (p * alg_len + ninputs * rhs_len) : ( + p * alg_len + ninputs * rhs_len + alg_len + ), + :, + ], + ) + for i in range(len(self.all_ys)) + ] + for p in range(ninputs) + ] + y_events = [ + casadi.vertcat( + self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)], + self.y_event[ + (p * alg_len + ninputs * rhs_len) : ( + p * alg_len + ninputs * rhs_len + alg_len + ) + ], + ) + for p in range(ninputs) + ] + else: + all_ys_split = [ + [ + np.vstack( + [ + self.all_ys[i][(p * rhs_len) : (p * rhs_len + rhs_len)], + self.all_ys[i][ + (p * alg_len + ninputs * rhs_len) : ( + p * alg_len + ninputs * rhs_len + alg_len + ) + ], + ] + ) + for i in range(len(self.all_ys)) + ] + for p in range(ninputs) + ] + y_events = [ + np.vstack( + [ + self.y_event[(p * rhs_len) : (p * rhs_len + rhs_len)], + self.y_event[ + (p * alg_len + ninputs * rhs_len) : ( + p * alg_len + ninputs * rhs_len + alg_len + ) + ], + ] + ) + for p in range(ninputs) + ] + + ret = [ + type(self)( + self.all_ts, + all_ys, + self.all_models, + inputs, + self.t_event, + y_event, + self.termination, + self.sensitivities, + False, + ) + for all_ys, inputs, y_event in zip(all_ys_split, inputs_list, y_events) + ] + for sol in ret: + sol.integration_time = self.integration_time + return ret + @classmethod def from_concatenated_state( cls, @@ -184,6 +273,25 @@ def from_concatenated_state( for i in range(ninputs) ] + # @classmethod + # def to_concatenated_state(cls, solutions): + # solution = solutions[0] + # all_ys = [ + # [np.vstack([sol.all_ys[i] for sol in solutions])] + # for i in range(len(solution.all_ys)) + # ] + # return cls( + # solution._all_ts, + # all_ys, + # solution._all_models, + # solution._all_inputs, + # solution._t_event, + # solution._y_event, + # solution._termination, + # solution.sensitivities, + # False, + # ) + def extract_explicit_sensitivities(self): # if we got here, we haven't set y yet self.set_y() diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 3030f80af0..6a79d639f2 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -390,6 +390,7 @@ def test_model_solver_with_inputs(self): spatial_methods = {"macroscale": pybamm.FiniteVolume()} disc = pybamm.Discretisation(mesh, spatial_methods) disc.process_model(model) + # Solve solver = pybamm.CasadiSolver(rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 10, 100) @@ -399,6 +400,17 @@ def test_model_solver_with_inputs(self): solution.y.full()[0], np.exp(-0.1 * solution.t), rtol=1e-04 ) + # multiple inputs + solver = pybamm.CasadiSolver(rtol=1e-8, atol=1e-8) + t_eval = np.linspace(0, 10, 100) + solutions = solver.solve(model, t_eval, inputs=[{"rate": 0.1}, {"rate": 0.2}]) + self.assertEqual(len(solutions), 2) + for solution, rate in zip(solutions, [0.1, 0.2]): + self.assertLess(len(solution.t), len(t_eval)) + np.testing.assert_allclose( + solution.y.full()[0], np.exp(-rate * solution.t), rtol=1e-04 + ) + # Without grid solver = pybamm.CasadiSolver(mode="safe without grid", rtol=1e-8, atol=1e-8) t_eval = np.linspace(0, 10, 100) @@ -413,6 +425,16 @@ def test_model_solver_with_inputs(self): solution.y.full()[0], np.exp(-1.1 * solution.t), rtol=1e-04 ) + # multiple inputs and without grid + solver = pybamm.CasadiSolver(mode="safe without grid", rtol=1e-8, atol=1e-8) + t_eval = np.linspace(0, 10, 100) + solutions = solver.solve(model, t_eval, inputs=[{"rate": 0.1}, {"rate": 0.2}]) + self.assertEqual(len(solutions), 2) + for solution, rate in zip(solutions, [0.1, 0.2]): + np.testing.assert_allclose( + solution.y.full()[0], np.exp(-rate * solution.t), rtol=1e-04 + ) + def test_model_solver_dae_inputs_in_initial_conditions(self): # Create model model = pybamm.BaseModel() diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index fad6651d55..e5647288dc 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -313,14 +313,15 @@ def test_model_solver_multiple_inputs_discontinuity_error(self): ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - def test_model_solver_multiple_inputs_initial_conditions_error(self): + def test_model_solver_multiple_inputs_initial_conditions(self): # Create model model = pybamm.BaseModel() model.convert_to_format = "casadi" domain = ["negative electrode", "separator", "positive electrode"] var = pybamm.Variable("var", domain=domain) - model.rhs = {var: -pybamm.InputParameter("rate") * var} - model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")} + rate = pybamm.InputParameter("rate") + model.rhs = {var: -rate * var} + model.initial_conditions = {var: 2 * rate} # create discretisation mesh = get_mesh_for_testing() spatial_methods = {"macroscale": pybamm.FiniteVolume()} @@ -332,11 +333,12 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): ninputs = 8 inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] - with self.assertRaisesRegex( - pybamm.SolverError, - ("Input parameters cannot appear in expression " "for initial conditions."), - ): - solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) + for inputs, solution in zip(inputs_list, solutions): + np.testing.assert_array_equal(solution.t, t_eval) + np.testing.assert_allclose( + solution.y[0], 2 * inputs["rate"] * np.exp(-inputs["rate"] * solution.t) + ) def test_model_solver_multiple_inputs_jax_format(self): if pybamm.have_jax():