diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 5ec4f47aa6..cdc0535c0c 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -1211,13 +1211,14 @@ def minimum(left, right): if out is not None: return out - k = pybamm.settings.min_smoothing + mode = pybamm.settings.min_max_mode + k = pybamm.settings.min_max_smoothing # Return exact approximation if that is the setting or the outcome is a constant # (i.e. no need for smoothing) - if k == "exact" or (left.is_constant() and right.is_constant()): + if mode == "exact" or (left.is_constant() and right.is_constant()): out = Minimum(left, right) - elif k == "smooth": - out = pybamm.smooth_minus(left, right, pybamm.settings.min_max_smoothing) + elif mode == "smooth": + out = pybamm.smooth_minus(left, right, k) else: out = pybamm.softminus(left, right, k) return pybamm.simplify_if_constant(out) @@ -1234,13 +1235,14 @@ def maximum(left, right): if out is not None: return out - k = pybamm.settings.max_smoothing + mode = pybamm.settings.min_max_mode + k = pybamm.settings.min_max_smoothing # Return exact approximation if that is the setting or the outcome is a constant # (i.e. no need for smoothing) - if k == "exact" or (left.is_constant() and right.is_constant()): + if mode == "exact" or (left.is_constant() and right.is_constant()): out = Maximum(left, right) - elif k == "smooth": - out = pybamm.smooth_plus(left, right, pybamm.settings.min_max_smoothing) + elif mode == "smooth": + out = pybamm.smooth_plus(left, right, k) else: out = pybamm.softplus(left, right, k) return pybamm.simplify_if_constant(out) diff --git a/pybamm/settings.py b/pybamm/settings.py index 299c17074e..4dc8db8151 100644 --- a/pybamm/settings.py +++ b/pybamm/settings.py @@ -6,8 +6,7 @@ class Settings(object): _debug_mode = False _simplify = True - _min_smoothing = "exact" - _max_smoothing = "exact" + _min_max_mode = "exact" _min_max_smoothing = 1000 _heaviside_smoothing = "exact" _abs_smoothing = "exact" @@ -45,35 +44,32 @@ def simplify(self, value): def set_smoothing_parameters(self, k): """Helper function to set all smoothing parameters""" - self.min_smoothing = k - self.max_smoothing = k + if k == "exact": + self.min_max_mode = "exact" + else: + self.min_max_smoothing = k + self.min_max_mode = "soft" self.heaviside_smoothing = k self.abs_smoothing = k @staticmethod def check_k(k): - if k != "exact" and k != "smooth" and k <= 0: + if k != "exact" and k <= 0: raise ValueError( - "Smoothing parameter must be 'exact', 'smooth', or a positive number" + "Smoothing parameter must be 'exact' or a strictly positive number" ) @property - def min_smoothing(self): - return self._min_smoothing + def min_max_mode(self): + return self._min_max_mode - @min_smoothing.setter - def min_smoothing(self, k): - self.check_k(k) - self._min_smoothing = k - - @property - def max_smoothing(self): - return self._max_smoothing - - @max_smoothing.setter - def max_smoothing(self, k): - self.check_k(k) - self._max_smoothing = k + @min_max_mode.setter + def min_max_mode(self, mode): + if mode not in ["exact", "soft", "smooth"]: + raise ValueError( + "Smoothing mode must be 'exact', 'soft', or 'smooth'" + ) + self._min_max_mode = mode @property def min_max_smoothing(self): @@ -81,7 +77,11 @@ def min_max_smoothing(self): @min_max_smoothing.setter def min_max_smoothing(self, k): - if k < 1: + if self._min_max_mode == "soft" and k <= 0: + raise ValueError( + "Smoothing parameter must be a strictly positive number" + ) + if self._min_max_mode == "smooth" and k < 1: raise ValueError( "Smoothing parameter must be greater than 1" ) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 487cffe877..9ced98d6fe 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -404,8 +404,8 @@ def test_softminus_softplus(self): ) # Test that smooth min/max are used when the setting is changed - pybamm.settings.min_smoothing = 10 - pybamm.settings.max_smoothing = 10 + pybamm.settings.min_max_mode = "soft" + pybamm.settings.min_max_smoothing = 10 self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.softminus(a, b, 10))) self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.softplus(a, b, 10))) @@ -417,8 +417,7 @@ def test_softminus_softplus(self): self.assertEqual(str(pybamm.maximum(a, b)), str(b)) # Change setting back for other tests - pybamm.settings.min_smoothing = "exact" - pybamm.settings.max_smoothing = "exact" + pybamm.settings.set_smoothing_parameters("exact") def test_smooth_minus_plus(self): a = pybamm.Scalar(1) @@ -444,8 +443,7 @@ def test_smooth_minus_plus(self): ) # Test that smooth min/max are used when the setting is changed - pybamm.settings.min_smoothing = "smooth" - pybamm.settings.max_smoothing = "smooth" + pybamm.settings.min_max_mode = "smooth" pybamm.settings.min_max_smoothing = 1 self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.smooth_minus(a, b, 1))) @@ -458,8 +456,7 @@ def test_smooth_minus_plus(self): self.assertEqual(str(pybamm.maximum(a, b)), str(b)) # Change setting back for other tests - pybamm.settings.min_smoothing = "exact" - pybamm.settings.max_smoothing = "exact" + pybamm.settings.set_smoothing_parameters("exact") def test_binary_simplifications(self): a = pybamm.Scalar(0) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index c2b19a2954..99310a42c1 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -16,29 +16,28 @@ def test_simplify(self): pybamm.settings.simplify = True def test_smoothing_parameters(self): - self.assertEqual(pybamm.settings.min_smoothing, "exact") - self.assertEqual(pybamm.settings.max_smoothing, "exact") + self.assertEqual(pybamm.settings.min_max_mode, "exact") self.assertEqual(pybamm.settings.heaviside_smoothing, "exact") self.assertEqual(pybamm.settings.abs_smoothing, "exact") pybamm.settings.set_smoothing_parameters(10) - self.assertEqual(pybamm.settings.min_smoothing, 10) - self.assertEqual(pybamm.settings.max_smoothing, 10) + self.assertEqual(pybamm.settings.min_max_smoothing, 10) self.assertEqual(pybamm.settings.heaviside_smoothing, 10) self.assertEqual(pybamm.settings.abs_smoothing, 10) pybamm.settings.set_smoothing_parameters("exact") # Test errors + with self.assertRaisesRegex(ValueError, "greater than 1"): + pybamm.settings.min_max_mode = "smooth" + pybamm.settings.min_max_smoothing = 0.9 with self.assertRaisesRegex(ValueError, "positive number"): - pybamm.settings.min_smoothing = -10 - with self.assertRaisesRegex(ValueError, "positive number"): - pybamm.settings.max_smoothing = -10 + pybamm.settings.min_max_mode = "soft" + pybamm.settings.min_max_smoothing = -10 with self.assertRaisesRegex(ValueError, "positive number"): pybamm.settings.heaviside_smoothing = -10 with self.assertRaisesRegex(ValueError, "positive number"): pybamm.settings.abs_smoothing = -10 - with self.assertRaisesRegex(ValueError, "greater than 1"): - pybamm.settings.min_max_smoothing = 0.9 + pybamm.settings.set_smoothing_parameters("exact") if __name__ == "__main__":