diff --git a/CHANGELOG.md b/CHANGELOG.md index a4616b0a0..58671a6bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ Bug Fixes - Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version - Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before. +- Fixes bug in ``softmin/softmax`` implementation. + v0.12.3 ------- diff --git a/desc/backend.py b/desc/backend.py index 9ecb95460..768a39323 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -76,7 +76,7 @@ from jax.numpy.fft import irfft, rfft, rfft2 from jax.scipy.fft import dct, idct from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular - from jax.scipy.special import gammaln, logsumexp + from jax.scipy.special import gammaln from jax.tree_util import ( register_pytree_node, tree_flatten, @@ -445,7 +445,7 @@ def tangent_solve(g, y): qr, solve_triangular, ) - from scipy.special import gammaln, logsumexp # noqa: F401 + from scipy.special import gammaln # noqa: F401 from scipy.special import softmax as softargmax # noqa: F401 trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz diff --git a/desc/objectives/_geometry.py b/desc/objectives/_geometry.py index 38950d885..b2d7abc2f 100644 --- a/desc/objectives/_geometry.py +++ b/desc/objectives/_geometry.py @@ -496,12 +496,9 @@ class PlasmaVesselDistance(_Objective): False by default, so that self.things = [eq, surface] Both cannot be True. softmin_alpha: float, optional - Parameter used for softmin. The larger softmin_alpha, the closer the softmin - approximates the hardmin. softmin -> hardmin as softmin_alpha -> infinity. - if softmin_alpha*array < 1, the underlying softmin will automatically multiply - the array by 2/min_val to ensure that softmin_alpha*array>1. Making - softmin_alpha larger than this minimum value will make the softmin a - more accurate approximation of the true min. + Parameter used for softmin. The larger ``softmin_alpha``, the closer the + softmin approximates the hardmin. softmin -> hardmin as + ``softmin_alpha`` -> infinity. """ diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 161c3f057..d02e70bc7 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -5,7 +5,7 @@ import numpy as np -from desc.backend import cond, jit, jnp, logsumexp, put +from desc.backend import jit, jnp, put, softargmax from desc.io import IOAble from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif @@ -285,14 +285,6 @@ def __call__(self, x_reduced): def softmax(arr, alpha): """JAX softmax implementation. - Inspired by https://www.johndcook.com/blog/2010/01/13/soft-maximum/ - and https://www.johndcook.com/blog/2010/01/20/how-to-compute-the-soft-maximum/ - - Will automatically multiply array values by 2 / min_val if the min_val of - the array is <1. This is to avoid inaccuracies that arise when values <1 - are present in the softmax, which can cause inaccurate maxes or even incorrect - signs of the softmax versus the actual max. - Parameters ---------- arr : ndarray @@ -309,18 +301,7 @@ def softmax(arr, alpha): """ arr_times_alpha = alpha * arr - min_val = jnp.min(jnp.abs(arr_times_alpha)) + 1e-4 # buffer value in case min is 0 - return cond( - jnp.any(min_val < 1), - lambda arr_times_alpha: logsumexp( - arr_times_alpha / min_val * 2 - ) # adjust to make vals>1 - / alpha - * min_val - / 2, - lambda arr_times_alpha: logsumexp(arr_times_alpha) / alpha, - arr_times_alpha, - ) + return softargmax(arr_times_alpha).dot(arr) def softmin(arr, alpha): diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 52d43960d..2c2abdf4a 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -718,7 +718,7 @@ def test_plasma_vessel_distance(self): warnings.simplefilter("error") obj.build() - # test softmin, should give value less than true minimum + # test softmin, should give approximate value surf_grid = LinearGrid(M=5, N=6) plas_grid = LinearGrid(M=5, N=6) obj = PlasmaVesselDistance( @@ -727,11 +727,12 @@ def test_plasma_vessel_distance(self): surface_grid=surf_grid, surface=surface, use_softmin=True, + softmin_alpha=5, ) obj.build() d = obj.compute_unscaled(*obj.xs(eq, surface)) assert d.size == obj.dim_f - assert np.all(np.abs(d) < a_s - a_p) + np.testing.assert_allclose(np.abs(d).min(), a_s - a_p, rtol=1.5e-1) # for large enough alpha, should be same as actual min obj = PlasmaVesselDistance( @@ -2308,20 +2309,18 @@ def test_objective_target_bounds(): def test_softmax_and_softmin(): """Test softmax and softmin function.""" arr = np.arange(-17, 17, 5) - # expect this to not be equal to the max but rather be more - # since softmax is a conservative estimate of the max + # expect this to not be equal to the max but approximately so sftmax = softmax(arr, alpha=1) - assert sftmax >= np.max(arr) + np.testing.assert_allclose(sftmax, np.max(arr), rtol=1e-2) # expect this to be equal to the max # as alpha -> infinity, softmax -> max sftmax = softmax(arr, alpha=100) np.testing.assert_almost_equal(sftmax, np.max(arr)) - # expect this to not be equal to the min but rather be less - # since softmin is a conservative estimate of the min + # expect this to not be equal to the min but approximately so sftmin = softmin(arr, alpha=1) - assert sftmin <= np.min(arr) + np.testing.assert_allclose(sftmin, np.min(arr), rtol=1e-2) # expect this to be equal to the min # as alpha -> infinity, softmin -> min