diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b42fbb69c..2ec5e39b8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ ## Bug fixes +- Fixed a bug with `_Heaviside._evaluate_for_shape` which meant some expressions involving heaviside function and subtractions did not work ([#3306](https://github.com/pybamm-team/PyBaMM/pull/3306)) - The `OneDimensionalX` thermal model has been updated to account for edge/tab cooling and account for the current collector volumetric heat capacity. It now gives the correct behaviour compared with a lumped model with the correct total heat transfer coefficient and surface area for cooling. ([#3042](https://github.com/pybamm-team/PyBaMM/pull/3042)) - Fixed a bug where the "basic" lithium-ion models gave incorrect results when using nonlinear particle diffusivity ([#3207](https://github.com/pybamm-team/PyBaMM/pull/3207)) - Particle size distributions now work with SPMe and NewmanTobias models ([#3207](https://github.com/pybamm-team/PyBaMM/pull/3207)) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 6794d201af..749384e9bc 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -508,6 +508,17 @@ def _binary_jac(self, left_jac, right_jac): # need to worry about shape return pybamm.Scalar(0) + def _evaluate_for_shape(self): + """ + Returns an array of NaNs of the correct shape. + See :meth:`pybamm.Symbol.evaluate_for_shape()`. + """ + left = self.children[0].evaluate_for_shape() + right = self.children[1].evaluate_for_shape() + # _binary_evaluate will return an array of bools, so we multiply by NaN to get + # an array of NaNs + return self._binary_evaluate(left, right) * np.nan + class EqualHeaviside(_Heaviside): """A heaviside function with equality (return 1 when left = right)""" diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 4e4bbb80cc..6acd7c41b0 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -324,6 +324,12 @@ def test_heaviside(self): self.assertEqual(1 < b + 2, -1 < b) self.assertEqual(b + 1 > 2, b > 1) + # expression with a subtract + expr = 2 * (b < 1) - (b > 3) + self.assertEqual(expr.evaluate(y=np.array([0])), 2) + self.assertEqual(expr.evaluate(y=np.array([2])), 0) + self.assertEqual(expr.evaluate(y=np.array([4])), -1) + def test_equality(self): a = pybamm.Scalar(1) b = pybamm.StateVector(slice(0, 1))