Skip to content

Commit

Permalink
Allow using IDAKLU(output_variables=...) with Experiments (#4534)
Browse files Browse the repository at this point in the history
* Add test for idaklu+output_variables+experiment

* edit Solution.last_state to pull y_event if all_ys is empty

* Ensure ProcessedVariableComputed variables are passed through Solution copies during an Experiment
Don't compute 'Change in x' summary variables if output_variables are specified

* populate first_state using the initial condition if output_variables used
Remove warnings about 'Change in x' summary variables

* Add to computed processed variable tests

* Add test for solution::add with computed variables

* add test for solution::copy with computed variables

* add check for idaklu on copy test

* Add 'variables_returned' attribute to Solution
Indicates if 'output_variables' are specified in solver
and therefore empty state vector

* Use `variables_returned` in `_update_variable()`, update test

* Update CHANGELOG

* Add test
  • Loading branch information
pipliggins authored Oct 28, 2024
1 parent 9560875 commit 3864e5d
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Adds support to `pybamm.Experiment` for the `output_variables` option in the `IDAKLUSolver`. ([#4534](https://github.com/pybamm-team/PyBaMM/pull/4534))
- Adds an option "voltage as a state" that can be "false" (default) or "true". If "true" adds an explicit algebraic equation for the voltage. ([#4507](https://github.com/pybamm-team/PyBaMM/pull/4507))
- Improved `QuickPlot` accuracy for simulations with Hermite interpolation. ([#4483](https://github.com/pybamm-team/PyBaMM/pull/4483))
- Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464))
Expand Down
1 change: 1 addition & 0 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1452,6 +1452,7 @@ def get_termination_reason(solution, events):
solution.t_event,
solution.y_event,
solution.termination,
variables_returned=solution.variables_returned,
)
event_sol.solve_time = 0
event_sol.integration_time = 0
Expand Down
1 change: 1 addition & 0 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,7 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict):
termination,
all_sensitivities=yS_out,
all_yps=yp,
variables_returned=bool(save_outputs_only),
)

newsol.integration_time = integration_time
Expand Down
27 changes: 26 additions & 1 deletion src/pybamm/solvers/processed_variable_computed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Processed Variable class
# Processed Variable Computed class
#
from __future__ import annotations
import casadi
import numpy as np
import pybamm
Expand Down Expand Up @@ -450,3 +451,27 @@ def sensitivities(self):
if len(self.all_inputs[0]) == 0:
return {}
return self._sensitivities

def _update(
self, other: pybamm.ProcessedVariableComputed, new_sol: pybamm.Solution
) -> pybamm.ProcessedVariableComputed:
"""
Returns a new ProcessedVariableComputed object that is the result of appending
the data from other to this object. Used exclusively in running experiments, to
append the data from one cycle to the next.
Parameters
----------
other : :class:`pybamm.ProcessedVariableComputed`
The other ProcessedVariableComputed object to append to this one
new_sol : :class:`pybamm.Solution`
The new solution object to be used to create the processed variables
"""

bv = self.base_variables + other.base_variables
bvc = self.base_variables_casadi + other.base_variables_casadi
bvd = self.base_variables_data + other.base_variables_data

new_var = self.__class__(bv, bvc, bvd, new_sol)

return new_var
60 changes: 49 additions & 11 deletions src/pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class Solution:
True if sensitivities included as the solution of the explicit forwards
equations. False if no sensitivities included/wanted. Dict if sensitivities are
provided as a dict of {parameter: [sensitivities]} pairs.
variables_returned: bool
Bool to indicate if `all_ys` contains the full state vector, or is empty because
only requested variables have been returned. True if `output_variables` is used
with a solver, otherwise False.
"""

Expand All @@ -76,6 +80,7 @@ def __init__(
termination="final time",
all_sensitivities=False,
all_yps=None,
variables_returned=False,
check_solution=True,
):
if not isinstance(all_ts, list):
Expand All @@ -93,6 +98,8 @@ def __init__(
all_yps = [all_yps]
self._all_yps = all_yps

self.variables_returned = variables_returned

# Set up inputs
if not isinstance(all_inputs, list):
all_inputs_copy = dict(all_inputs)
Expand Down Expand Up @@ -460,9 +467,15 @@ def first_state(self):
else:
all_yps = self.all_yps[0][:, :1]

if not self.variables_returned:
all_ys = self.all_ys[0][:, :1]
else:
# Get first state from initial conditions as all_ys is empty
all_ys = self.all_models[0].y0full.reshape(-1, 1)

new_sol = Solution(
self.all_ts[0][:1],
self.all_ys[0][:, :1],
all_ys,
self.all_models[:1],
self.all_inputs[:1],
None,
Expand Down Expand Up @@ -500,9 +513,15 @@ def last_state(self):
else:
all_yps = self.all_yps[-1][:, -1:]

if not self.variables_returned:
all_ys = self.all_ys[-1][:, -1:]
else:
# Get last state from y_event as all_ys is empty
all_ys = self.y_event.reshape(len(self.y_event), 1)

new_sol = Solution(
self.all_ts[-1][-1:],
self.all_ys[-1][:, -1:],
all_ys,
self.all_models[-1:],
self.all_inputs[-1:],
self.t_event,
Expand Down Expand Up @@ -580,15 +599,11 @@ def _update_variable(self, variable):
# Iterate through all models, some may be in the list several times and
# therefore only get set up once
vars_casadi = []
for i, (model, ts, ys, inputs, var_pybamm) in enumerate(
zip(self.all_models, self.all_ts, self.all_ys, self.all_inputs, vars_pybamm)
for i, (model, ys, inputs, var_pybamm) in enumerate(
zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm)
):
if (
ys.size == 0
and var_pybamm.has_symbol_of_classes(
pybamm.expression_tree.state_vector.StateVector
)
and not ts.size == 0
if self.variables_returned and var_pybamm.has_symbol_of_classes(
pybamm.expression_tree.state_vector.StateVector
):
raise KeyError(
f"Cannot process variable '{variable}' as it was not part of the "
Expand Down Expand Up @@ -682,7 +697,7 @@ def __getitem__(self, key):
Returns
-------
:class:`pybamm.ProcessedVariable`
:class:`pybamm.ProcessedVariable` or :class:`pybamm.ProcessedVariableComputed`
A variable that can be evaluated at any time or spatial point. The
underlying data for this variable is available in its attribute ".data"
"""
Expand Down Expand Up @@ -950,6 +965,7 @@ def __add__(self, other):
other.termination,
all_sensitivities=all_sensitivities,
all_yps=all_yps,
variables_returned=other.variables_returned,
)

new_sol.closest_event_idx = other.closest_event_idx
Expand All @@ -966,6 +982,19 @@ def __add__(self, other):
# Set sub_solutions
new_sol._sub_solutions = self.sub_solutions + other.sub_solutions

# update variables which were derived at the solver stage
if other._variables and all(
isinstance(v, pybamm.ProcessedVariableComputed)
for v in other._variables.values()
):
if not self._variables:
new_sol._variables = other._variables.copy()
else:
new_sol._variables = {
v: self._variables[v]._update(other._variables[v], new_sol)
for v in self._variables.keys()
}

return new_sol

def __radd__(self, other):
Expand All @@ -983,6 +1012,7 @@ def copy(self):
self.termination,
self._all_sensitivities,
self.all_yps,
self.variables_returned,
)
new_sol._all_inputs_casadi = self.all_inputs_casadi
new_sol._sub_solutions = self.sub_solutions
Expand All @@ -992,6 +1022,13 @@ def copy(self):
new_sol.integration_time = self.integration_time
new_sol.set_up_time = self.set_up_time

# copy over variables which were derived at the solver stage
if self._variables and all(
isinstance(v, pybamm.ProcessedVariableComputed)
for v in self._variables.values()
):
new_sol._variables = self._variables.copy()

return new_sol

def plot_voltage_components(
Expand Down Expand Up @@ -1094,6 +1131,7 @@ def make_cycle_solution(
sum_sols.termination,
sum_sols._all_sensitivities,
sum_sols.all_yps,
sum_sols.variables_returned,
)
cycle_solution._all_inputs_casadi = sum_sols.all_inputs_casadi
cycle_solution._sub_solutions = sum_sols.sub_solutions
Expand Down
52 changes: 52 additions & 0 deletions tests/integration/test_solvers/test_idaklu.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,55 @@ def test_interpolation(self):
# test that y[1:3] = to true solution
true_solution = b_value * sol.t
np.testing.assert_array_almost_equal(sol.y[1:3], true_solution)

def test_with_experiments(self):
summary_vars = []
sols = []
for out_vars in [True, False]:
model = pybamm.lithium_ion.SPM()

if out_vars:
output_variables = [
"Discharge capacity [A.h]", # 0D variables
"Time [s]",
"Current [A]",
"Voltage [V]",
"Pressure [Pa]", # 1D variable
"Positive particle effective diffusivity [m2.s-1]", # 2D variable
]
else:
output_variables = None

solver = pybamm.IDAKLUSolver(output_variables=output_variables)

experiment = pybamm.Experiment(
[
(
"Charge at 1C until 4.2 V",
"Hold at 4.2 V until C/50",
"Rest for 1 hour",
)
]
)

sim = pybamm.Simulation(
model,
experiment=experiment,
solver=solver,
)

sol = sim.solve()
sols.append(sol)
summary_vars.append(sol.summary_variables)

# check computed variables are propegated sucessfully
np.testing.assert_array_equal(
sols[0]["Pressure [Pa]"].data, sols[1]["Pressure [Pa]"].data
)
np.testing.assert_array_almost_equal(
sols[0]["Voltage [V]"].data, sols[1]["Voltage [V]"].data
)

# check summary variables are the same if output variables are specified
for var in summary_vars[0].keys():
assert summary_vars[0][var] == summary_vars[1][var]
3 changes: 3 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,9 @@ def construct_model():
with pytest.raises(KeyError):
sol[varname].data

# Check Solution is marked
assert sol.variables_returned is True

# Mock a 1D current collector and initialise (none in the model)
sol["x_s [m]"].domain = ["current collector"]
sol["x_s [m]"].entries
Expand Down
75 changes: 73 additions & 2 deletions tests/unit/test_solvers/test_processed_variable_computed.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Tests for the Processed Variable Computed class
#
# This class forms a container for variables (and sensitivities) calculted
# This class forms a container for variables (and sensitivities) calculated
# by the idaklu solver, and does not possesses any capability to calculate
# values itself since it does not have access to the full state vector
#
Expand Down Expand Up @@ -76,11 +76,12 @@ def test_processed_variable_0D(self):
t_sol = np.array([0])
y_sol = np.array([1])[:, np.newaxis]
var_casadi = to_casadi(var, y_sol)
sol = pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {})
processed_var = pybamm.ProcessedVariableComputed(
[var],
[var_casadi],
[y_sol],
pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
sol,
)
# Assert that the processed variable is the same as the solution
np.testing.assert_array_equal(processed_var.entries, y_sol[0])
Expand All @@ -94,6 +95,22 @@ def test_processed_variable_0D(self):
processed_var.cumtrapz_ic = 1
processed_var.entries

# check _update
t_sol2 = np.array([1])
y_sol2 = np.array([2])[:, np.newaxis]
var_casadi = to_casadi(var, y_sol2)
sol_2 = pybamm.Solution(t_sol2, y_sol2, pybamm.BaseModel(), {})
processed_var2 = pybamm.ProcessedVariableComputed(
[var],
[var_casadi],
[y_sol2],
sol_2,
)

comb_sol = sol + sol_2
comb_var = processed_var._update(processed_var2, comb_sol)
np.testing.assert_array_equal(comb_var.entries, np.append(y_sol, y_sol2))

# check empty sensitivity works
def test_processed_variable_0D_no_sensitivity(self):
# without space
Expand Down Expand Up @@ -217,6 +234,60 @@ def test_processed_variable_1D_unknown_domain(self):
c_casadi = to_casadi(c, y_sol)
pybamm.ProcessedVariableComputed([c], [c_casadi], [y_sol], solution)

def test_processed_variable_1D_update(self):
# variable 1
var = pybamm.Variable("var", domain=["negative electrode", "separator"])
x = pybamm.SpatialVariable("x", domain=["negative electrode", "separator"])

disc = tests.get_discretisation_for_testing()
disc.set_variable_slices([var])
x_sol1 = disc.process_symbol(x).entries[:, 0]
var_sol1 = disc.process_symbol(var)
t_sol1 = np.linspace(0, 1)
y_sol1 = np.ones_like(x_sol1)[:, np.newaxis] * np.linspace(0, 5)

var_casadi1 = to_casadi(var_sol1, y_sol1)
sol1 = pybamm.Solution(t_sol1, y_sol1, pybamm.BaseModel(), {})
processed_var1 = pybamm.ProcessedVariableComputed(
[var_sol1],
[var_casadi1],
[y_sol1],
sol1,
)

# variable 2 -------------------
var2 = pybamm.Variable("var2", domain=["negative electrode", "separator"])
z = pybamm.SpatialVariable("z", domain=["negative electrode", "separator"])

disc = tests.get_discretisation_for_testing()
disc.set_variable_slices([var2])
z_sol2 = disc.process_symbol(z).entries[:, 0]
var_sol2 = disc.process_symbol(var2)
t_sol2 = np.linspace(2, 3)
y_sol2 = np.ones_like(z_sol2)[:, np.newaxis] * np.linspace(5, 1)

var_casadi2 = to_casadi(var_sol2, y_sol2)
sol2 = pybamm.Solution(t_sol2, y_sol2, pybamm.BaseModel(), {})
var_2 = pybamm.ProcessedVariableComputed(
[var_sol2],
[var_casadi2],
[y_sol2],
sol2,
)

comb_sol = sol1 + sol2
comb_var = processed_var1._update(var_2, comb_sol)

# Ordering from idaklu with output_variables set is different to
# the full solver
y_sol1 = y_sol1.reshape((y_sol1.shape[1], y_sol1.shape[0])).transpose()
y_sol2 = y_sol2.reshape((y_sol2.shape[1], y_sol2.shape[0])).transpose()

np.testing.assert_array_equal(
comb_var.entries, np.concatenate((y_sol1, y_sol2), axis=1)
)
np.testing.assert_array_equal(comb_var.entries, comb_var.data)

def test_processed_variable_2D_x_r(self):
var = pybamm.Variable(
"var",
Expand Down
Loading

0 comments on commit 3864e5d

Please sign in to comment.