diff --git a/examples/scripts/minimal_example_of_lookup_tables.py b/examples/scripts/minimal_example_of_lookup_tables.py new file mode 100644 index 0000000000..1c93e311c0 --- /dev/null +++ b/examples/scripts/minimal_example_of_lookup_tables.py @@ -0,0 +1,51 @@ +import pybamm +import pandas as pd +import numpy as np + + +def process_2D(name, data): + data = data.to_numpy() + x1 = np.unique(data[:, 0]) + x2 = np.unique(data[:, 1]) + + value = data[:, 2] + + x = (x1, x2) + + value_data = value.reshape(len(x1), len(x2), order="C") + + formatted_data = (name, (x, value_data)) + + return formatted_data + + +parameter_values = pybamm.ParameterValues(pybamm.parameter_sets.Chen2020) + +# overwrite the diffusion coefficient with a 2D lookup table +D_s_n = parameter_values["Negative electrode diffusivity [m2.s-1]"] +df = pd.DataFrame( + { + "T": [0, 0, 25, 25, 45, 45], + "sto": [0, 1, 0, 1, 0, 1], + "D_s_n": [D_s_n, D_s_n, D_s_n, D_s_n, D_s_n, D_s_n], + } +) +df["T"] = df["T"] + 273.15 +D_s_n_data = process_2D("Negative electrode diffusivity [m2.s-1]", df) + + +def D_s_n(sto, T): + name, (x, y) = D_s_n_data + return pybamm.Interpolant(x, y, [T, sto], name) + + +parameter_values["Negative electrode diffusivity [m2.s-1]"] = D_s_n + +k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"] + +model = pybamm.lithium_ion.DFN() +sim = pybamm.Simulation(model, parameter_values=parameter_values) + +sim.solve([0, 30]) + +sim.plot(["Negative particle surface concentration [mol.m-3]"]) diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index 8296625da0..1cb5e70d05 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -302,10 +302,11 @@ def _function_evaluate(self, evaluated_children): new_evaluated_children, self.function.grid ): nan_children.append(np.ones_like(child) * interp_range.mean()) - return self.function(np.transpose(nan_children)) * np.nan + nan_eval = self.function(np.transpose(nan_children)) + return np.reshape(nan_eval, shape) else: res = self.function(np.transpose(new_evaluated_children)) - return res[:, np.newaxis] + return np.reshape(res, shape) else: # pragma: no cover raise ValueError("Invalid dimension: {0}".format(self.dimension)) diff --git a/tests/unit/test_expression_tree/test_interpolant.py b/tests/unit/test_expression_tree/test_interpolant.py index 5fa078cffc..1d1a55e85c 100644 --- a/tests/unit/test_expression_tree/test_interpolant.py +++ b/tests/unit/test_expression_tree/test_interpolant.py @@ -131,7 +131,7 @@ def f(x, y): value = interp.evaluate(y=np.array([[1, 1, x[1]], [5, 4, y[1]]])) np.testing.assert_array_equal( - value, np.array([[f(1, 5)], [f(1, 4)], [f(x[1], y[1])]]) + value, np.array([[f(1, 5), f(1, 4), f(x[1], y[1])]]) ) # check also works for cubic @@ -192,6 +192,17 @@ def f(x, y): evaluated_children = [1, 4] value = interp._function_evaluate(evaluated_children) + # Test that the interpolant shape is the same as the input data shape + interp = pybamm.Interpolant(x_in, data, (var1, var2), interpolator="linear") + + evaluated_children = [np.array([[1, 1]]), np.array([[7, 7]])] + value = interp._function_evaluate(evaluated_children) + self.assertEqual(value.shape, evaluated_children[0].shape) + + evaluated_children = [np.array([[1, 1], [1, 1]]), np.array([[7, 7], [7, 7]])] + value = interp._function_evaluate(evaluated_children) + self.assertEqual(value.shape, evaluated_children[0].shape) + def test_interpolation_3_x(self): def f(x, y, z): return 2 * x**3 + 3 * y**2 - z @@ -216,7 +227,7 @@ def f(x, y, z): value = interp.evaluate(y=np.array([[1, 1, 1], [5, 4, 4], [8, 7, 7]])) np.testing.assert_array_equal( - value, np.array([[f(1, 5, 8)], [f(1, 4, 7)], [f(1, 4, 7)]]) + value, np.array([[f(1, 5, 8), f(1, 4, 7), f(1, 4, 7)]]) ) # check also works for cubic diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 37ec89068f..3610b53424 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -559,7 +559,7 @@ def test_process_interpolant_2d(self): processed_func = parameter_values.process_symbol(func) self.assertIsInstance(processed_func, pybamm.Interpolant) self.assertAlmostEqual( - processed_func.evaluate(inputs={"a": 3.01, "b": 4.4})[0][0], 14.82 + processed_func.evaluate(inputs={"a": 3.01, "b": 4.4}), 14.82 ) # process differentiated function parameter