Skip to content

Commit

Permalink
Switching to smoothing mode
Browse files Browse the repository at this point in the history
  • Loading branch information
kratman committed Oct 2, 2023
1 parent e3db154 commit bf9e0b6
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 47 deletions.
18 changes: 10 additions & 8 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,13 +1211,14 @@ def minimum(left, right):
if out is not None:
return out

k = pybamm.settings.min_smoothing
mode = pybamm.settings.min_max_mode
k = pybamm.settings.min_max_smoothing
# Return exact approximation if that is the setting or the outcome is a constant
# (i.e. no need for smoothing)
if k == "exact" or (left.is_constant() and right.is_constant()):
if mode == "exact" or (left.is_constant() and right.is_constant()):
out = Minimum(left, right)
elif k == "smooth":
out = pybamm.smooth_minus(left, right, pybamm.settings.min_max_smoothing)
elif mode == "smooth":
out = pybamm.smooth_minus(left, right, k)
else:
out = pybamm.softminus(left, right, k)
return pybamm.simplify_if_constant(out)
Expand All @@ -1234,13 +1235,14 @@ def maximum(left, right):
if out is not None:
return out

k = pybamm.settings.max_smoothing
mode = pybamm.settings.min_max_mode
k = pybamm.settings.min_max_smoothing
# Return exact approximation if that is the setting or the outcome is a constant
# (i.e. no need for smoothing)
if k == "exact" or (left.is_constant() and right.is_constant()):
if mode == "exact" or (left.is_constant() and right.is_constant()):
out = Maximum(left, right)
elif k == "smooth":
out = pybamm.smooth_plus(left, right, pybamm.settings.min_max_smoothing)
elif mode == "smooth":
out = pybamm.smooth_plus(left, right, k)
else:
out = pybamm.softplus(left, right, k)
return pybamm.simplify_if_constant(out)
Expand Down
44 changes: 22 additions & 22 deletions pybamm/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
class Settings(object):
_debug_mode = False
_simplify = True
_min_smoothing = "exact"
_max_smoothing = "exact"
_min_max_mode = "exact"
_min_max_smoothing = 1000
_heaviside_smoothing = "exact"
_abs_smoothing = "exact"
Expand Down Expand Up @@ -45,43 +44,44 @@ def simplify(self, value):

def set_smoothing_parameters(self, k):
"""Helper function to set all smoothing parameters"""
self.min_smoothing = k
self.max_smoothing = k
if k == "exact":
self.min_max_mode = "exact"
else:
self.min_max_smoothing = k
self.min_max_mode = "soft"
self.heaviside_smoothing = k
self.abs_smoothing = k

@staticmethod
def check_k(k):
if k != "exact" and k != "smooth" and k <= 0:
if k != "exact" and k <= 0:
raise ValueError(
"Smoothing parameter must be 'exact', 'smooth', or a positive number"
"Smoothing parameter must be 'exact' or a strictly positive number"
)

@property
def min_smoothing(self):
return self._min_smoothing
def min_max_mode(self):
return self._min_max_mode

@min_smoothing.setter
def min_smoothing(self, k):
self.check_k(k)
self._min_smoothing = k

@property
def max_smoothing(self):
return self._max_smoothing

@max_smoothing.setter
def max_smoothing(self, k):
self.check_k(k)
self._max_smoothing = k
@min_max_mode.setter
def min_max_mode(self, mode):
if mode not in ["exact", "soft", "smooth"]:
raise ValueError(
"Smoothing mode must be 'exact', 'soft', or 'smooth'"
)
self._min_max_mode = mode

@property
def min_max_smoothing(self):
return self._min_max_smoothing

@min_max_smoothing.setter
def min_max_smoothing(self, k):
if k < 1:
if self._min_max_mode == "soft" and k <= 0:
raise ValueError(
"Smoothing parameter must be a strictly positive number"
)
if self._min_max_mode == "smooth" and k < 1:
raise ValueError(
"Smoothing parameter must be greater than 1"
)
Expand Down
13 changes: 5 additions & 8 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ def test_softminus_softplus(self):
)

# Test that smooth min/max are used when the setting is changed
pybamm.settings.min_smoothing = 10
pybamm.settings.max_smoothing = 10
pybamm.settings.min_max_mode = "soft"
pybamm.settings.min_max_smoothing = 10

self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.softminus(a, b, 10)))
self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.softplus(a, b, 10)))
Expand All @@ -417,8 +417,7 @@ def test_softminus_softplus(self):
self.assertEqual(str(pybamm.maximum(a, b)), str(b))

# Change setting back for other tests
pybamm.settings.min_smoothing = "exact"
pybamm.settings.max_smoothing = "exact"
pybamm.settings.set_smoothing_parameters("exact")

def test_smooth_minus_plus(self):
a = pybamm.Scalar(1)
Expand All @@ -444,8 +443,7 @@ def test_smooth_minus_plus(self):
)

# Test that smooth min/max are used when the setting is changed
pybamm.settings.min_smoothing = "smooth"
pybamm.settings.max_smoothing = "smooth"
pybamm.settings.min_max_mode = "smooth"

pybamm.settings.min_max_smoothing = 1
self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.smooth_minus(a, b, 1)))
Expand All @@ -458,8 +456,7 @@ def test_smooth_minus_plus(self):
self.assertEqual(str(pybamm.maximum(a, b)), str(b))

# Change setting back for other tests
pybamm.settings.min_smoothing = "exact"
pybamm.settings.max_smoothing = "exact"
pybamm.settings.set_smoothing_parameters("exact")

def test_binary_simplifications(self):
a = pybamm.Scalar(0)
Expand Down
17 changes: 8 additions & 9 deletions tests/unit/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,28 @@ def test_simplify(self):
pybamm.settings.simplify = True

def test_smoothing_parameters(self):
self.assertEqual(pybamm.settings.min_smoothing, "exact")
self.assertEqual(pybamm.settings.max_smoothing, "exact")
self.assertEqual(pybamm.settings.min_max_mode, "exact")
self.assertEqual(pybamm.settings.heaviside_smoothing, "exact")
self.assertEqual(pybamm.settings.abs_smoothing, "exact")

pybamm.settings.set_smoothing_parameters(10)
self.assertEqual(pybamm.settings.min_smoothing, 10)
self.assertEqual(pybamm.settings.max_smoothing, 10)
self.assertEqual(pybamm.settings.min_max_smoothing, 10)
self.assertEqual(pybamm.settings.heaviside_smoothing, 10)
self.assertEqual(pybamm.settings.abs_smoothing, 10)
pybamm.settings.set_smoothing_parameters("exact")

# Test errors
with self.assertRaisesRegex(ValueError, "greater than 1"):
pybamm.settings.min_max_mode = "smooth"
pybamm.settings.min_max_smoothing = 0.9
with self.assertRaisesRegex(ValueError, "positive number"):
pybamm.settings.min_smoothing = -10
with self.assertRaisesRegex(ValueError, "positive number"):
pybamm.settings.max_smoothing = -10
pybamm.settings.min_max_mode = "soft"
pybamm.settings.min_max_smoothing = -10
with self.assertRaisesRegex(ValueError, "positive number"):
pybamm.settings.heaviside_smoothing = -10
with self.assertRaisesRegex(ValueError, "positive number"):
pybamm.settings.abs_smoothing = -10
with self.assertRaisesRegex(ValueError, "greater than 1"):
pybamm.settings.min_max_smoothing = 0.9
pybamm.settings.set_smoothing_parameters("exact")


if __name__ == "__main__":
Expand Down

0 comments on commit bf9e0b6

Please sign in to comment.