Skip to content

Commit

Permalink
jax and python evaluator api now match casadi, sensitivities still br…
Browse files Browse the repository at this point in the history
…oken #4087
  • Loading branch information
martinjrobins committed Jul 5, 2024
1 parent 64a4a3b commit 88d56bd
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 111 deletions.
21 changes: 5 additions & 16 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def find_symbols(
symbol_str = "t"

elif isinstance(symbol, pybamm.InputParameter):
symbol_str = f'inputs["{input_slices[symbol.name]}"]'

input_slice = input_slices[symbol.name]
symbol_str = f"inputs[{input_slice}]"
else:
raise NotImplementedError(
f"Conversion to python not implemented for a symbol of type '{type(symbol)}'"
Expand Down Expand Up @@ -407,17 +407,7 @@ def to_python(
"""
constant_values: OrderedDict = OrderedDict()
variable_symbols: OrderedDict = OrderedDict()
input_slices = {}
i = 0
for input_dict in inputs:
for key, value in input_dict.items():
if isinstance(value, np.ndarray):
inc = value.shape[0]
input_slices[key] = slice(i, i + inc)
else:
inc = 1
input_slices[key] = i
i += inc
input_slices = pybamm.BaseSolver._input_dict_to_slices(inputs[0])
find_symbols(symbol, constant_values, variable_symbols, input_slices, output_jax)

line_format = "{} = {}"
Expand Down Expand Up @@ -526,6 +516,8 @@ def __call__(self, t=None, y=None, inputs=None):
else:
nstates = y.shape[0] // self._ninputs
nparams = len(inputs) // self._ninputs
print("nstates:", nstates)
print("nparams:", nparams)

results = [
self._evaluate(
Expand Down Expand Up @@ -809,8 +801,5 @@ def __call__(self, t=None, y=None, inputs=None):

# execute code
result = self._jac_evaluate(*self._constants, t, y, inputs)
result = {
key: value.reshape(value.shape[0], -1) for key, value in result.items()
}

return result
9 changes: 3 additions & 6 deletions pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,9 @@ def _integrate(self, model, t_eval, inputs_list=None):
Any input parameters to pass to the model when solving
"""
inputs_list = inputs_list or [{}]
if model.convert_to_format == "casadi":
inputs = casadi.vertcat(
*[x for inputs in inputs_list for x in inputs.values()]
)
else:
inputs = inputs_list
inputs = pybamm.BaseSolver._inputs_to_stacked_vect(
inputs_list, model.convert_to_format
)

y0 = model.y0
if isinstance(y0, casadi.DM):
Expand Down
92 changes: 62 additions & 30 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model.concatenated_initial_conditions,
"initial_conditions",
vars_for_processing,
inputs,
use_jacobian=False,
ninputs=len(inputs),
)
model.initial_conditions_eval = initial_conditions
model.jacp_initial_conditions_eval = jacp_ic
Expand All @@ -153,30 +153,27 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model.len_rhs + model.len_rhs_sens + model.len_alg + model.len_alg_sens
)
y_zero = np.zeros((y0_total_size, 1))

stacked_inputs = self._inputs_to_stacked_vect(inputs, model.convert_to_format)
if model.convert_to_format == "casadi":
# stack inputs
inputs_casadi = casadi.vertcat(
*[x for inpt in inputs for x in inpt.values()]
)
model.y0 = initial_conditions(0.0, y_zero, inputs_casadi)
model.y0 = initial_conditions(0.0, y_zero, stacked_inputs)
if jacp_ic is None:
model.y0S = None
else:
model.y0S = jacp_ic(0.0, y_zero, inputs_casadi)
model.y0S = jacp_ic(0.0, y_zero, stacked_inputs)
else:
model.y0 = initial_conditions(0.0, y_zero, inputs)
model.y0 = initial_conditions(0.0, y_zero, stacked_inputs)
if jacp_ic is None:
model.y0S = None
else:
# we are calculating the derivative wrt the inputs
# so need to make sure we convert int -> float
# This is to satisfy JAX jacfwd function which requires
# float inputs
inputs_float = {
key: float(value) if isinstance(value, int) else value
for key, value in inputs.items()
}
model.y0S = jacp_ic(0.0, y_zero, inputs_float)
if stacked_inputs.dtype == int:
stacked_inputs = stacked_inputs.astype(float)
model.y0S = jacp_ic(0.0, y_zero, stacked_inputs)

if ics_only:
pybamm.logger.info("Finish solver set-up")
Expand All @@ -192,14 +189,14 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
else:
rhs = model.concatenated_rhs
rhs, jac_rhs, jacp_rhs, jac_rhs_action = process(
rhs, "RHS", vars_for_processing, ninputs=len(inputs)
rhs, "RHS", vars_for_processing, inputs
)

algebraic, jac_algebraic, jacp_algebraic, jac_algebraic_action = process(
model.concatenated_algebraic,
"algebraic",
vars_for_processing,
ninputs=len(inputs),
inputs,
)

# combine rhs and algebraic functions
Expand All @@ -218,7 +215,10 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
jacp_rhs_algebraic,
jac_rhs_algebraic_action,
) = process(
rhs_algebraic, "rhs_algebraic", vars_for_processing, ninputs=len(inputs)
rhs_algebraic,
"rhs_algebraic",
vars_for_processing,
inputs,
)

(
Expand Down Expand Up @@ -286,9 +286,9 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model.variables_and_events[key],
BaseSolver._wrangle_name(key),
vars_for_processing,
inputs,
use_jacobian=True,
return_jacp_stacked=True,
ninputs=len(inputs),
)

pybamm.logger.info("Finish solver set-up")
Expand Down Expand Up @@ -554,8 +554,8 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
event_sigmoid,
f"event_{n}",
vars_for_processing,
inputs,
use_jacobian=False,
ninputs=len(inputs),
)[0]
# use the actual casadi object as this will go into the rhs
casadi_switch_events.append(event_casadi)
Expand All @@ -565,8 +565,8 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
event.expression,
f"event_{n}",
vars_for_processing,
inputs,
use_jacobian=False,
ninputs=len(inputs),
)[0]
if event.event_type == pybamm.EventType.TERMINATION:
terminate_events.append(event_call)
Expand Down Expand Up @@ -1393,6 +1393,38 @@ def _set_up_model_inputs(model, inputs):

return ordered_inputs

@staticmethod
def _inputs_to_stacked_vect(inputs_list: list[dict], convert_to_format: str):
if len(inputs_list) == 0 or len(inputs_list[0]) == 0:
return np.array([[]])
if convert_to_format == "casadi":
inputs = casadi.vertcat(
*[x for inputs in inputs_list for x in inputs.values()]
)
else:
arrays_to_stack = [
np.array(x).reshape(-1, 1)
for inputs in inputs_list
for x in inputs.values()
]
print(inputs_list, arrays_to_stack)
inputs = np.vstack(arrays_to_stack)
return inputs

@staticmethod
def _input_dict_to_slices(input_dict: dict):
input_slices = {}
i = 0
for key, value in input_dict.items():
if isinstance(value, np.ndarray):
inc = value.shape[0]
input_slices[key] = slice(i, i + inc)
else:
inc = 1
input_slices[key] = i
i += inc
return input_slices


def map_func_over_inputs_casadi(name, f, vars_for_processing, ninputs):
"""
Expand Down Expand Up @@ -1536,9 +1568,9 @@ def process(
symbol,
name,
vars_for_processing,
inputs: list[dict],
use_jacobian=None,
return_jacp_stacked=None,
ninputs=1,
):
"""
Parameters
Expand All @@ -1547,6 +1579,10 @@ def process(
expression tree to convert
name: str
function evaluators created will have this base name
vars_for_processing: dict
dictionary of variables for processing
inputs: list of dict
list of input parameters to pass to the model when solving
use_jacobian: bool, optional
whether to return Jacobian functions
return_jacp_stacked: bool, optional
Expand Down Expand Up @@ -1579,6 +1615,8 @@ def process(
i.e. $\\frac{\\partial f}{\\partial y} * v$
"""
ninputs = len(inputs)
is_event = "event" in name

def report(string):
# don't log event conversion
Expand All @@ -1592,7 +1630,7 @@ def report(string):

if model.convert_to_format == "jax":
report(f"Converting {name} to jax")
func = pybamm.EvaluatorJax(symbol)
func = pybamm.EvaluatorJax(symbol, inputs, is_event=is_event)
jacp = None
if model.calculate_sensitivities:
report(
Expand Down Expand Up @@ -1622,32 +1660,26 @@ def report(string):
p: symbol.diff(pybamm.InputParameter(p))
for p in model.calculate_sensitivities
}
jacp = pybamm.NumpyConcatenation(*[v for v in jacp_dict.values()])

report(f"Converting sensitivities for {name} to python")
jacp_dict = {
p: pybamm.EvaluatorPython(jacp) for p, jacp in jacp_dict.items()
}

# jacp should be a function that returns a dict of sensitivities
def jacp(*args, **kwargs):
return {k: v(*args, **kwargs) for k, v in jacp_dict.items()}

jacp = pybamm.EvaluatorPython(jacp, inputs)
else:
jacp = None

if use_jacobian:
report(f"Calculating jacobian for {name}")
jac = jacobian.jac(symbol, y)
report(f"Converting jacobian for {name} to python")
jac = pybamm.EvaluatorPython(jac)
jac = pybamm.EvaluatorPython(jac, inputs)
# cannot do jacobian action efficiently for now
jac_action = None
else:
jac = None
jac_action = None

report(f"Converting {name} to python")
func = pybamm.EvaluatorPython(symbol, is_event="event" in name)
func = pybamm.EvaluatorPython(symbol, inputs, is_event=is_event)

else:
t_casadi = vars_for_processing["t_casadi"]
Expand Down
Loading

0 comments on commit 88d56bd

Please sign in to comment.