Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #3441 call processed variables with custom spatial coordinates #3472

16 changes: 12 additions & 4 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,16 +366,24 @@ def _process_spatial_variable_names(self, spatial_variable):
f"Spatial variable name not recognized for {spatial_variable}"
)

def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True):
def __call__(
self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True, **kwargs
):
"""
Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R),
using interpolation
"""
kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R}
# Combine the provided spatial variables with the extra keyword arguments
spatial_vars = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R}
spatial_vars.update(kwargs)

# Remove any None arguments
kwargs = {key: value for key, value in kwargs.items() if value is not None}
spatial_vars = {
key: value for key, value in spatial_vars.items() if value is not None
}

# Use xarray interpolation, return numpy array
return self._xr_data_array.interp(**kwargs).values
return self._xr_data_array.interp(**spatial_vars).values

@property
def data(self):
Expand Down
16 changes: 12 additions & 4 deletions pybamm/solvers/processed_variable_computed.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,16 +411,24 @@ def initialise_2D_scikit_fem(self):
coords={"y": y_sol, "z": z_sol, "t": self.t_pts},
)

def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True):
def __call__(
self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True, **kwargs
):
"""
Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R),
using interpolation
"""
kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R}
# Combine the provided spatial variables with the extra keyword arguments
spatial_vars = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R}
spatial_vars.update(kwargs)

# Remove any None arguments
kwargs = {key: value for key, value in kwargs.items() if value is not None}
spatial_vars = {
key: value for key, value in spatial_vars.items() if value is not None
}

# Use xarray interpolation, return numpy array
return self._xr_data_array.interp(**kwargs).values
return self._xr_data_array.interp(**spatial_vars).values

@property
def data(self):
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_solvers/test_processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def process_and_check_2D_variable(
return y_sol, first_sol, second_sol, t_sol


def call_func(t=None, x=None, r=None, y=None, z=None, R=None, warn=True, **kwargs):
return np.random.rand(5, 5)


class TestProcessedVariable(TestCase):
def test_processed_variable_0D(self):
# without space
Expand Down Expand Up @@ -1125,6 +1129,44 @@ def test_process_spatial_variable_names(self):
with self.assertRaisesRegex(NotImplementedError, "Spatial variable name"):
processed_var._process_spatial_variable_names(["var1", "var2"])

def test_call_default(self):
result = call_func()
self.assertIsInstance(result, np.ndarray)

def test_call_with_t(self):
t_test = 0.5
result = call_func(t=t_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_x(self):
x_test = 0.1
result = call_func(x=x_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_r(self):
r_test = 0.2
result = call_func(r=r_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_y(self):
y_test = 0.3
result = call_func(y=y_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_z(self):
z_test = 0.4
result = call_func(z=z_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_R(self):
R_test = 0.5
result = call_func(R=R_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_kwargs(self):
result = call_func(foo="bar")
self.assertIsInstance(result, np.ndarray)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_solvers/test_processed_variable_computed.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def process_and_check_2D_variable(
return y_sol, first_sol, second_sol, t_sol


def call_func(t=None, x=None, r=None, y=None, z=None, R=None, warn=True, **kwargs):
return np.random.rand(5, 5)
Comment on lines +69 to +70
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you improve the tests so they actually tests that the variables are passed correctly? At the moment it only checks that no error message is thrown.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def test_call_default(self, mock_call_func):
        call_func()
        mock_call_func.assert_called()

    @patch('path.to.call_func')
    def test_call_with_t(self, mock_call_func):
        t_test = 0.5
        call_func(t=t_test)
        mock_call_func.assert_called_with(t=t_test)
.
.
.
.

Something like this? @brosaplanella

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking something more PyBaMM specific. Maybe something like this example (

t = pybamm.t
var = pybamm.Variable("var", domain=["negative electrode", "separator"])
x = pybamm.SpatialVariable("x", domain=["negative electrode", "separator"])
eqn = t * var + x
# On nodes
disc = tests.get_discretisation_for_testing()
disc.set_variable_slices([var])
x_sol = disc.process_symbol(x).entries[:, 0]
var_sol = disc.process_symbol(var)
eqn_sol = disc.process_symbol(eqn)
t_sol = np.linspace(0, 1)
y_sol = np.ones_like(x_sol)[:, np.newaxis] * np.linspace(0, 5)
var_casadi = to_casadi(var_sol, y_sol)
processed_var = pybamm.ProcessedVariable(
[var_sol],
[var_casadi],
pybamm.Solution(
t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {}
),
warn=False,
)
np.testing.assert_array_equal(processed_var.entries, y_sol)
np.testing.assert_array_almost_equal(processed_var(t_sol, x_sol), y_sol)
), where a dummy model is created but instead of calling your space variable x, try calling it something that would be a kwarg above to see if it is passed correctly. Maybe call it a and try both passing a to see it works and passing b to check it throws an error.

Copy link
Author

@Saswatsusmoy Saswatsusmoy Nov 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def call_func(t=None, a=None, b=None): if a is not None: return a + 1 elif b is not None: return b + 2 else: return np.array([0])

`class TestProcessedVariableComputedCalls(unittest.TestCase):
def test_call_kwargs(self):
var = pybamm.Variable("var")
a = pybamm.InputParameter("a")
model = pybamm.BaseModel()
geom = pybamm.Geometry()
mesh = pybamm.Mesh(geom)
var.mesh = mesh

    t = np.linspace(0,1)
    y = np.zeros(1)

    inputs = {"a": 1}

    fun = to_casadi(var, y, inputs)

    solution = pybamm.Solution(t, y, model, inputs)
    processed_var = pybamm.ProcessedVariableComputed([var], [fun], [y], solution)

    # Test passing correct kwarg
    result = processed_var(a=1)
    self.assertEqual(result, 2)

    # Test passing incorrect kwarg
    with self.assertRaises(TypeError):
        processed_var(b=1)`

Something like this? @brosaplanella

The key changes are:

Define a simple call_func that takes kwargs
Create a dummy model/variable/solution
Convert the variable to casadi function
Create the ProcessedVariableComputed
Test calling it with correct and incorrect kwargs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. Note that you want a to be a SpatialVariable rather than an InputParameter because the latter needs to be passed at solving stage.



class TestProcessedVariableComputed(TestCase):
def test_processed_variable_0D(self):
# without space
Expand Down Expand Up @@ -435,6 +439,44 @@ def test_3D_raises_error(self):
warn=False,
)

def test_call_default(self):
result = call_func()
self.assertIsInstance(result, np.ndarray)

def test_call_with_t(self):
t_test = 0.5
result = call_func(t=t_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_x(self):
x_test = 0.1
result = call_func(x=x_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_r(self):
r_test = 0.2
result = call_func(r=r_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_y(self):
y_test = 0.3
result = call_func(y=y_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_z(self):
z_test = 0.4
result = call_func(z=z_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_R(self):
R_test = 0.5
result = call_func(R=R_test)
self.assertIsInstance(result, np.ndarray)

def test_call_with_kwargs(self):
result = call_func(foo="bar")
self.assertIsInstance(result, np.ndarray)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
Loading