Skip to content

Commit

Permalink
Improve testing / code-coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
jsbrittain committed Aug 31, 2023
1 parent 58cee45 commit 2fb4e55
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
33 changes: 17 additions & 16 deletions pybamm/solvers/processed_variable_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,22 @@ def _unroll_nnz(self, realdata=None):
# unroll in nnz != numel, otherwise copy
if realdata is None:
realdata = self.base_variables_data
sp = self.base_variables_casadi[0](0, 0, 0).sparsity()
if sp.nnz() != sp.numel():
data = [None] * len(realdata)
for datak in range(len(realdata)):
data[datak] = np.zeros(self.base_eval_shape[0] * len(self.t_pts))
var_data = realdata[0].flatten()
k = 0
for t_i in range(len(self.t_pts)):
base = t_i * sp.numel()
for r in sp.row():
data[datak][base + r] = var_data[k]
k = k + 1
else:
data = realdata
return data
# sp = self.base_variables_casadi[0](0, 0, 0).sparsity()
# if sp.nnz() != sp.numel():
# data = [None] * len(realdata)
# for datak in range(len(realdata)):
# data[datak] = np.zeros(self.base_eval_shape[0] * len(self.t_pts))
# var_data = realdata[0].flatten()
# k = 0
# for t_i in range(len(self.t_pts)):
# base = t_i * sp.numel()
# for r in sp.row():
# data[datak][base + r] = var_data[k]
# k = k + 1
# else:
# data = realdata
# return data
return realdata

def unroll_0D(self, realdata=None):
if realdata is None:
Expand Down Expand Up @@ -172,7 +173,7 @@ def unroll(self, realdata=None):
return self.unroll_2D(realdata=realdata)
else:
# Raise error for 3D variable
raise NotImplementedError("Unsupported data dimension: {self.dimensions}")
raise NotImplementedError(f"Unsupported data dimension: {self.dimensions}")

def initialise_0D(self):
entries = self.unroll_0D()
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def test_model_events(self):
# create discretisation
disc = pybamm.Discretisation()
model_disc = disc.process_model(model, inplace=False)
# Invalid atol (dict) raises error, valid options are float or ndarray
self.assertRaises(
pybamm.SolverError,
pybamm.IDAKLUSolver(
rtol=1e-8, atol={'key': 'value'},
root_method=root_method)
)
# output_variables only valid with convert_to_format=='casadi'
if form == "python" or form == "jax":
self.assertRaises(
pybamm.SolverError,
pybamm.IDAKLUSolver(
rtol=1e-8, atol=1e-8,
output_variables=['var'],
root_method=root_method)
)
# Solve
solver = pybamm.IDAKLUSolver(rtol=1e-8, atol=1e-8, root_method=root_method)
t_eval = np.linspace(0, 1, 100)
Expand Down
48 changes: 47 additions & 1 deletion tests/unit/test_solvers/test_processed_variable_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,19 @@ def test_processed_variable_0D(self):
pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
warn=False,
)
# Assert that the processed variable is the same as the solution
np.testing.assert_array_equal(processed_var.entries, y_sol[0])
# Check that 'data' produces the same output as 'entries'
np.testing.assert_array_equal(processed_var.entries, processed_var.data)

# check empty sensitivity works
# Check unroll function
np.testing.assert_array_equal(processed_var.unroll(), y_sol[0])

# Check cumtrapz workflow produces no errors
processed_var.cumtrapz_ic = 1
processed_var.initialise_0D()

# check empty sensitivity works
def test_processed_variable_0D_no_sensitivity(self):
# without space
t = pybamm.t
Expand Down Expand Up @@ -153,8 +162,33 @@ def test_processed_variable_1D(self):
# the full solver
y_sol = y_sol.reshape((y_sol.shape[1], y_sol.shape[0])).transpose()
np.testing.assert_array_equal(processed_var.entries, y_sol)
np.testing.assert_array_equal(processed_var.entries, processed_var.data)
np.testing.assert_array_almost_equal(processed_var(t_sol, x_sol), y_sol)

# Check unroll function
np.testing.assert_array_equal(processed_var.unroll(), y_sol)

# Check no error when data dimension is transposed vs node/edge
processed_var.mesh.nodes, processed_var.mesh.edges = \
processed_var.mesh.edges, processed_var.mesh.nodes
processed_var.initialise_1D()
processed_var.mesh.nodes, processed_var.mesh.edges = \
processed_var.mesh.edges, processed_var.mesh.nodes

# Check no errors with domain-specific attributes
# (see ProcessedVariableVar.initialise_2D() for details)
domain_list = [
["particle", "electrode"],
["separator", "current collector"],
["particle", "particle size"],
["particle size", "electrode"],
["particle size", "current collector"]
]
for domain, secondary in domain_list:
processed_var.domain[0] = domain
processed_var.domains["secondary"] = [secondary]
processed_var.initialise_1D()

def test_processed_variable_1D_unknown_domain(self):
x = pybamm.SpatialVariable("x", domain="SEI layer", coord_sys="cartesian")
geometry = pybamm.Geometry(
Expand Down Expand Up @@ -218,6 +252,18 @@ def test_processed_variable_2D_space_only(self):
processed_var.entries,
np.reshape(y_sol, [len(r_sol), len(x_sol), len(t_sol)]),
)
np.testing.assert_array_equal(
processed_var.entries,
processed_var.data,
)

# Check unroll function (2D)
np.testing.assert_array_equal(processed_var.unroll(), y_sol.reshape(10, 40, 1))

# Check unroll function (3D)
with self.assertRaises(NotImplementedError):
processed_var.dimensions = 3
processed_var.unroll()

def test_processed_variable_2D_fixed_t_scikit(self):
var = pybamm.Variable("var", domain=["current collector"])
Expand Down

0 comments on commit 2fb4e55

Please sign in to comment.