Skip to content

Commit

Permalink
Merge pull request #3306 from pybamm-team/issue-3298-heaviside-shape
Browse files Browse the repository at this point in the history
Issue 3298 heaviside shape
  • Loading branch information
rtimms authored Sep 4, 2023
2 parents 840fb13 + 460e540 commit ea4cc7b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

## Bug fixes

- Fixed a bug with `_Heaviside._evaluate_for_shape` which meant some expressions involving heaviside function and subtractions did not work ([#3306](https://github.com/pybamm-team/PyBaMM/pull/3306))
- The `OneDimensionalX` thermal model has been updated to account for edge/tab cooling and account for the current collector volumetric heat capacity. It now gives the correct behaviour compared with a lumped model with the correct total heat transfer coefficient and surface area for cooling. ([#3042](https://github.com/pybamm-team/PyBaMM/pull/3042))
- Fixed a bug where the "basic" lithium-ion models gave incorrect results when using nonlinear particle diffusivity ([#3207](https://github.com/pybamm-team/PyBaMM/pull/3207))
- Particle size distributions now work with SPMe and NewmanTobias models ([#3207](https://github.com/pybamm-team/PyBaMM/pull/3207))
Expand Down
11 changes: 11 additions & 0 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,17 @@ def _binary_jac(self, left_jac, right_jac):
# need to worry about shape
return pybamm.Scalar(0)

def _evaluate_for_shape(self):
"""
Returns an array of NaNs of the correct shape.
See :meth:`pybamm.Symbol.evaluate_for_shape()`.
"""
left = self.children[0].evaluate_for_shape()
right = self.children[1].evaluate_for_shape()
# _binary_evaluate will return an array of bools, so we multiply by NaN to get
# an array of NaNs
return self._binary_evaluate(left, right) * np.nan


class EqualHeaviside(_Heaviside):
"""A heaviside function with equality (return 1 when left = right)"""
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,12 @@ def test_heaviside(self):
self.assertEqual(1 < b + 2, -1 < b)
self.assertEqual(b + 1 > 2, b > 1)

# expression with a subtract
expr = 2 * (b < 1) - (b > 3)
self.assertEqual(expr.evaluate(y=np.array([0])), 2)
self.assertEqual(expr.evaluate(y=np.array([2])), 0)
self.assertEqual(expr.evaluate(y=np.array([4])), -1)

def test_equality(self):
a = pybamm.Scalar(1)
b = pybamm.StateVector(slice(0, 1))
Expand Down

0 comments on commit ea4cc7b

Please sign in to comment.