Skip to content

Commit

Permalink
Merge branch 'master' into dp/update-stage-two-weights
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Nov 22, 2024
2 parents 5ec0a48 + 2741269 commit c8b1558
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 37 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Bug Fixes

- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version
- Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before.
- Fixes bug in ``softmin/softmax`` implementation.


v0.12.3
-------
Expand Down
4 changes: 2 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
from jax.scipy.special import gammaln
from jax.tree_util import (
register_pytree_node,
tree_flatten,
Expand Down Expand Up @@ -445,7 +445,7 @@ def tangent_solve(g, y):
qr,
solve_triangular,
)
from scipy.special import gammaln, logsumexp # noqa: F401
from scipy.special import gammaln # noqa: F401
from scipy.special import softmax as softargmax # noqa: F401

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
Expand Down
9 changes: 3 additions & 6 deletions desc/objectives/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,9 @@ class PlasmaVesselDistance(_Objective):
False by default, so that self.things = [eq, surface]
Both cannot be True.
softmin_alpha: float, optional
Parameter used for softmin. The larger softmin_alpha, the closer the softmin
approximates the hardmin. softmin -> hardmin as softmin_alpha -> infinity.
if softmin_alpha*array < 1, the underlying softmin will automatically multiply
the array by 2/min_val to ensure that softmin_alpha*array>1. Making
softmin_alpha larger than this minimum value will make the softmin a
more accurate approximation of the true min.
Parameter used for softmin. The larger ``softmin_alpha``, the closer the
softmin approximates the hardmin. softmin -> hardmin as
``softmin_alpha`` -> infinity.
"""

Expand Down
23 changes: 2 additions & 21 deletions desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from desc.backend import cond, jit, jnp, logsumexp, put
from desc.backend import jit, jnp, put, softargmax
from desc.io import IOAble
from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif

Expand Down Expand Up @@ -285,14 +285,6 @@ def __call__(self, x_reduced):
def softmax(arr, alpha):
"""JAX softmax implementation.
Inspired by https://www.johndcook.com/blog/2010/01/13/soft-maximum/
and https://www.johndcook.com/blog/2010/01/20/how-to-compute-the-soft-maximum/
Will automatically multiply array values by 2 / min_val if the min_val of
the array is <1. This is to avoid inaccuracies that arise when values <1
are present in the softmax, which can cause inaccurate maxes or even incorrect
signs of the softmax versus the actual max.
Parameters
----------
arr : ndarray
Expand All @@ -309,18 +301,7 @@ def softmax(arr, alpha):
"""
arr_times_alpha = alpha * arr
min_val = jnp.min(jnp.abs(arr_times_alpha)) + 1e-4 # buffer value in case min is 0
return cond(
jnp.any(min_val < 1),
lambda arr_times_alpha: logsumexp(
arr_times_alpha / min_val * 2
) # adjust to make vals>1
/ alpha
* min_val
/ 2,
lambda arr_times_alpha: logsumexp(arr_times_alpha) / alpha,
arr_times_alpha,
)
return softargmax(arr_times_alpha).dot(arr)


def softmin(arr, alpha):
Expand Down
15 changes: 7 additions & 8 deletions tests/test_objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def test_plasma_vessel_distance(self):
warnings.simplefilter("error")
obj.build()

# test softmin, should give value less than true minimum
# test softmin, should give approximate value
surf_grid = LinearGrid(M=5, N=6)
plas_grid = LinearGrid(M=5, N=6)
obj = PlasmaVesselDistance(
Expand All @@ -727,11 +727,12 @@ def test_plasma_vessel_distance(self):
surface_grid=surf_grid,
surface=surface,
use_softmin=True,
softmin_alpha=5,
)
obj.build()
d = obj.compute_unscaled(*obj.xs(eq, surface))
assert d.size == obj.dim_f
assert np.all(np.abs(d) < a_s - a_p)
np.testing.assert_allclose(np.abs(d).min(), a_s - a_p, rtol=1.5e-1)

# for large enough alpha, should be same as actual min
obj = PlasmaVesselDistance(
Expand Down Expand Up @@ -2308,20 +2309,18 @@ def test_objective_target_bounds():
def test_softmax_and_softmin():
"""Test softmax and softmin function."""
arr = np.arange(-17, 17, 5)
# expect this to not be equal to the max but rather be more
# since softmax is a conservative estimate of the max
# expect this to not be equal to the max but approximately so
sftmax = softmax(arr, alpha=1)
assert sftmax >= np.max(arr)
np.testing.assert_allclose(sftmax, np.max(arr), rtol=1e-2)

# expect this to be equal to the max
# as alpha -> infinity, softmax -> max
sftmax = softmax(arr, alpha=100)
np.testing.assert_almost_equal(sftmax, np.max(arr))

# expect this to not be equal to the min but rather be less
# since softmin is a conservative estimate of the min
# expect this to not be equal to the min but approximately so
sftmin = softmin(arr, alpha=1)
assert sftmin <= np.min(arr)
np.testing.assert_allclose(sftmin, np.min(arr), rtol=1e-2)

# expect this to be equal to the min
# as alpha -> infinity, softmin -> min
Expand Down

0 comments on commit c8b1558

Please sign in to comment.