diff --git a/desc/backend.py b/desc/backend.py index 5aab199a2a..ecbc915cbb 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -75,14 +75,7 @@ from jax.numpy import bincount, flatnonzero, repeat, take from jax.numpy.fft import irfft, rfft, rfft2 from jax.scipy.fft import dct, idct - from jax.scipy.linalg import ( - block_diag, - cho_factor, - cho_solve, - eigh_tridiagonal, - qr, - solve_triangular, - ) + from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular from jax.scipy.special import gammaln, logsumexp from jax.tree_util import ( register_pytree_node, @@ -98,6 +91,31 @@ jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid ) + def execute_on_cpu(func): + """Decorator to set default device to CPU for a function. + + Parameters + ---------- + func : callable + Function to decorate + + Returns + ------- + wrapper : callable + Decorated function that will always run on CPU even if + there are available GPUs. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with jax.default_device(jax.devices("cpu")[0]): + return func(*args, **kwargs) + + return wrapper + + # JAX implementation is not differentiable on gpu. + eigh_tridiagonal = execute_on_cpu(jax.scipy.linalg.eigh_tridiagonal) + def put(arr, inds, vals): """Functional interface for array "fancy indexing". @@ -123,28 +141,6 @@ def put(arr, inds, vals): return arr return jnp.asarray(arr).at[inds].set(vals) - def execute_on_cpu(func): - """Decorator to set default device to CPU for a function. - - Parameters - ---------- - func : callable - Function to decorate - - Returns - ------- - wrapper : callable - Decorated function that will run always on CPU even if - there are available GPUs. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - with jax.default_device(jax.devices("cpu")[0]): - return func(*args, **kwargs) - - return wrapper - def sign(x): """Sign function, but returns 1 for x==0. diff --git a/desc/compute/_neoclassical.py b/desc/compute/_neoclassical.py index 3eb1c143ee..96a0c054f0 100644 --- a/desc/compute/_neoclassical.py +++ b/desc/compute/_neoclassical.py @@ -19,7 +19,7 @@ from ..integrals.quad_utils import get_quadrature, leggauss_lob from ..utils import map2, safediv from .data_index import register_compute_fun -from .utils import _get_pitch_inv, _poloidal_mean +from .utils import _get_pitch_inv_chebgauss, _poloidal_mean @register_compute_fun( @@ -79,10 +79,10 @@ def _G_ra_fsa(data, transforms, profiles, **kwargs): @register_compute_fun( name="effective ripple", # this is ε¹ᐧ⁵ label=( - # ε¹ᐧ⁵ = π/(8√2) (R₀/〈|∇ψ|〉)² ∫dλ λ⁻²B₀⁻¹ 〈 ∑ⱼ Hⱼ²/Iⱼ 〉 + # ε¹ᐧ⁵ = π/(8√2) (R₀/〈|∇ψ|〉)² B₀⁻¹ ∫dλ λ⁻² 〈 ∑ⱼ Hⱼ²/Iⱼ 〉 "\\epsilon^{3/2} = \\frac{\\pi}{8 \\sqrt{2}} " "(R_0 / \\langle \\vert\\nabla \\psi\\vert \\rangle)^2 " - "\\int d\\lambda \\lambda^{-2} B_0^{-1} " + "B_0^{-1} \\int d\\lambda \\lambda^{-2} " "\\langle \\sum_j H_j^2 / I_j \\rangle" ), units="~", @@ -106,12 +106,7 @@ def _G_ra_fsa(data, transforms, profiles, **kwargs): resolution_requirement="z", source_grid_requirement={"coordinates": "raz", "is_meshgrid": True}, quad="jnp.ndarray : Optional, quadrature points and weights for bounce integrals.", - num_pitch=( - "int : Resolution for quadrature over velocity coordinate, preferably odd. " - "Default is 75. Profile will look smoother at high values." - # If computed on many flux surfaces and small oscillations are seen - # between neighboring surfaces, increasing this will smooth the profile. - ), + num_pitch="int : Resolution for quadrature over velocity coordinate. Default 50.", num_well=( "int : Maximum number of wells to detect for each pitch and field line. " "Default is to detect all wells, but due to limitations in JAX this option " @@ -152,7 +147,7 @@ def _effective_ripple(params, transforms, profiles, data, **kwargs): if "quad" in kwargs else get_quadrature(leggauss_lob(32), Bounce1D._default_automorphism) ) - num_pitch = kwargs.get("num_pitch", 75) + num_pitch = kwargs.get("num_pitch", 50) num_well = kwargs.get("num_well", None) batch = kwargs.get("batch", True) grid = transforms["grid"].source_grid @@ -171,24 +166,29 @@ def dI(B, pitch): return jnp.sqrt(jnp.abs(1 - pitch * B)) / B def compute(data): - """(∂ψ/∂ρ)⁻² B₀⁻² ∫ dλ ∑ⱼ Hⱼ²/Iⱼ.""" + """Return (∂ψ/∂ρ)⁻² B₀⁻² ∫ dλ λ⁻² ∑ⱼ Hⱼ²/Iⱼ. + + Notes + ----- + B₀ has units of λ⁻¹. + Nemov's ∑ⱼ Hⱼ²/Iⱼ = (∂ψ/∂ρ)² (λB₀)³ ``(H**2 / I).sum(axis=-1)``. + (λB₀)³ d(λB₀)⁻¹ = B₀² λ³ d(λ⁻¹) = -B₀² λ dλ. + """ bounce = Bounce1D(grid, data, quad, automorphism=None, is_reshaped=True) - # Interpolate |∇ρ| κ_g since it is smoother than κ_g alone. H = bounce.integrate( dH, data["pitch_inv"], + # Interpolate |∇ρ| κ_g since it is smoother than κ_g alone. data["|grad(rho)|*kappa_g"], num_well=num_well, batch=batch, ) I = bounce.integrate(dI, data["pitch_inv"], num_well=num_well, batch=batch) - # Note B₀ has units of λ⁻¹. - # Nemov's ∑ⱼ Hⱼ²/Iⱼ = (∂ψ/∂ρ)² (λB₀)³ ``(H**2 / I).sum(axis=-1)``. - # (λB₀)³ db = λ³B₀² d(λ⁻¹) = λB₀² (-dλ). - y = data["pitch_inv"] ** (-3) * safediv(H**2, I).sum(axis=-1) - return simpson(y=y, x=data["pitch_inv"]) - # TODO: Try Gauss-Chebyshev quadrature after automorphism arcsin to - # make nodes more evenly spaced. + return ( + safediv(H**2, I).sum(axis=-1) + * data["pitch_inv"] ** (-3) + * data["pitch_inv weight"] + ).sum(axis=-1) _data = { # noqa: unused dependency name: Bounce1D.reshape_data(grid, data[name]) @@ -197,13 +197,13 @@ def compute(data): _data["|grad(rho)|*kappa_g"] = Bounce1D.reshape_data( grid, data["|grad(rho)|"] * data["kappa_g"] ) - _data["pitch_inv"] = _get_pitch_inv(grid, data, num_pitch) - out = _poloidal_mean(grid, map2(compute, _data)) + _data = _get_pitch_inv_chebgauss(grid, data, num_pitch, _data) + B0 = data["max_tz |B|"] data["effective ripple"] = ( jnp.pi / (8 * 2**0.5) - * (data["max_tz |B|"] * data["R0"] / data["<|grad(rho)|>"]) ** 2 - * grid.expand(out) + * (B0 * data["R0"] / data["<|grad(rho)|>"]) ** 2 + * grid.expand(_poloidal_mean(grid, map2(compute, _data))) / data[""] ) return data diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 26ed64c081..e7d496a045 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -8,7 +8,7 @@ from desc.backend import execute_on_cpu, jnp from desc.grid import Grid -from ..integrals.bounce_utils import get_pitch_inv +from ..integrals.bounce_utils import get_pitch_inv, get_pitch_inv_chebgauss from ..utils import errorif from .data_index import allowed_kwargs, data_index @@ -728,12 +728,28 @@ def _poloidal_mean(grid, f): return f.T.dot(dp) / jnp.sum(dp) -def _get_pitch_inv(grid, data, num_pitch): - return jnp.broadcast_to( +def _get_pitch_inv(grid, data, num_pitch, _data): + _data["pitch_inv"] = jnp.broadcast_to( get_pitch_inv( grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), num_pitch, )[jnp.newaxis], - (grid.num_alpha, grid.num_rho, num_pitch + 2), + (grid.num_alpha, grid.num_rho, num_pitch), ) + return _data + + +def _get_pitch_inv_chebgauss(grid, data, num_pitch, _data): + p, w = get_pitch_inv_chebgauss( + grid.compress(data["min_tz |B|"]), + grid.compress(data["max_tz |B|"]), + num_pitch, + ) + _data["pitch_inv"] = jnp.broadcast_to( + p[jnp.newaxis], (grid.num_alpha, grid.num_rho, num_pitch) + ) + _data["pitch_inv weight"] = jnp.broadcast_to( + w[jnp.newaxis], (grid.num_alpha, grid.num_rho, num_pitch) + ) + return _data diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index 4d4263e849..9b784d334c 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -14,6 +14,7 @@ ) from desc.integrals.quad_utils import ( bijection_from_disc, + chebgauss_uniform, composite_linspace, grad_bijection_from_disc, ) @@ -37,7 +38,7 @@ def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): max_B : jnp.ndarray Maximum |B| value. num : int - Number of values, not including endpoints. + Number of values. relative_shift : float Relative amount to shift maxima down and minima up to avoid floating point errors in downstream routines. @@ -45,20 +46,53 @@ def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): Returns ------- pitch_inv : jnp.ndarray - Shape (*min_B.shape, num + 2). + Shape (*min_B.shape, num). 1/λ values. """ # Floating point error impedes consistent detection of bounce points riding # extrema. Shift values slightly to resolve this issue. - min_B = (1 + relative_shift) * min_B - max_B = (1 - relative_shift) * max_B + min_B = (1.0 + relative_shift) * min_B + max_B = (1.0 - relative_shift) * max_B # Samples should be uniformly spaced in |B| and not λ (GitHub issue #1228). - pitch_inv = jnp.moveaxis(composite_linspace(jnp.stack([min_B, max_B]), num), 0, -1) - assert pitch_inv.shape == (*min_B.shape, num + 2) + pitch_inv = jnp.moveaxis( + composite_linspace(jnp.stack([min_B, max_B]), num - 2), 0, -1 + ) + assert pitch_inv.shape == (*min_B.shape, num) return pitch_inv +def get_pitch_inv_chebgauss(min_B, max_B, num, relative_shift=1e-6): + """Return Chebyshev quadrature with 1/λ uniform in ``min_B`` and ``max_B``. + + Parameters + ---------- + min_B : jnp.ndarray + Minimum |B| value. + max_B : jnp.ndarray + Maximum |B| value. + num : int + Number of values. + relative_shift : float + Relative amount to shift maxima down and minima up to avoid floating point + errors in downstream routines. + + Returns + ------- + pitch_inv, weight : (jnp.ndarray, jnp.ndarray) + Shape (*min_B.shape, num). + 1/λ values and weights. + + """ + min_B = (1.0 + relative_shift) * min_B + max_B = (1.0 - relative_shift) * max_B + # Samples should be uniformly spaced in |B| (GitHub issue #1228). + x, w = chebgauss_uniform(num) + pitch_inv = bijection_from_disc(x, min_B[..., jnp.newaxis], max_B[..., jnp.newaxis]) + w = w * grad_bijection_from_disc(min_B, max_B)[..., jnp.newaxis] + return pitch_inv, w + + def _check_spline_shape(knots, g, dg_dz, pitch_inv=None): """Ensure inputs have compatible shape. diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py index 5f7c5994bf..2f85f64058 100644 --- a/desc/integrals/quad_utils.py +++ b/desc/integrals/quad_utils.py @@ -1,8 +1,9 @@ """Utilities for quadratures.""" +from orthax.chebyshev import chebgauss from orthax.legendre import legder, legval -from desc.backend import eigh_tridiagonal, execute_on_cpu, jnp, put +from desc.backend import eigh_tridiagonal, jnp, put from desc.utils import errorif @@ -139,7 +140,6 @@ def tanh_sinh(deg, m=10): return x, w -@execute_on_cpu # JAX implementation of eigh_tridiagonal is not differentiable on gpu. def leggauss_lob(deg, interior_only=False): """Lobatto-Gauss-Legendre quadrature. @@ -191,6 +191,31 @@ def leggauss_lob(deg, interior_only=False): return x, w +def chebgauss_uniform(deg): + """Gauss-Chebyshev quadrature with uniformly spaced nodes. + + Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the + integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ). + + Parameters + ---------- + deg : int + Number of quadrature points. + + Returns + ------- + x, w : (jnp.ndarray, jnp.ndarray) + Shape (deg, ). + Quadrature points and weights. + """ + # Define x = 2/π arcsin y and g : y ↦ f(x(y)). + # ∫₋₁¹ f(x) dx = 2/π ∫₋₁¹ (1−y²)⁻⁰ᐧ⁵ g(y) dy + # ∑ₖ wₖ f(x(yₖ)) = 2/π ∑ₖ ωₖ g(yₖ) + # Given roots yₖ of Chebyshev polynomial, x(yₖ) is uniform in (-1, 1). + y, w = chebgauss(deg) + return automorphism_arcsin(y), 2 * w / jnp.pi + + def get_quadrature(quad, automorphism): """Apply automorphism to given quadrature. diff --git a/desc/objectives/_neoclassical.py b/desc/objectives/_neoclassical.py index 2936fe6ebb..302db58cfa 100644 --- a/desc/objectives/_neoclassical.py +++ b/desc/objectives/_neoclassical.py @@ -42,7 +42,7 @@ class EffectiveRipple(_Objective): locations. Defaults to 0. bounds : tuple of {float, ndarray, callable}, optional Lower and upper bounds on the objective. Overrides target. - Both bounds must be broadcastable to Objective.dim_f + Both bounds must be broadcastable to Objective.dim_f. If a callable, each should take a single argument ``rho`` and return the desired bound (lower or upper) of the profile at those locations. weight : {float, ndarray}, optional @@ -68,18 +68,17 @@ class EffectiveRipple(_Objective): Should have poloidal and toroidal resolution. alpha : ndarray Unique coordinate values for field line poloidal angle label alpha. + knots_per_transit : int + Number of points per toroidal transit at which to sample data along field + line. Default is 100. num_transit : int Number of toroidal transits to follow field line. For axisymmetric devices, one poloidal transit is sufficient. Otherwise, more transits will give more accurate result, with diminishing returns. - knots_per_transit : int - Number of points per toroidal transit at which to sample data along field - line. Default is 100. num_quad : int Resolution for quadrature of bounce integrals. Default is 32. num_pitch : int - Resolution for quadrature over velocity coordinate, preferably odd. - Default is 75. Profile will look smoother at high values. + Resolution for quadrature over velocity coordinate. Default 50. batch : bool Whether to vectorize part of the computation. Default is true. num_well : int @@ -99,7 +98,7 @@ class EffectiveRipple(_Objective): def __init__( self, eq, - target=0.0, + target=None, bounds=None, weight=1, normalize=True, @@ -108,16 +107,17 @@ def __init__( deriv_mode="auto", grid=None, alpha=np.array([0]), - num_transit=10, + *, knots_per_transit=100, + num_transit=10, num_quad=32, - num_pitch=75, + num_pitch=50, batch=True, num_well=None, name="Effective ripple", ): - if bounds is not None: - target = None + if target is None and bounds is None: + target = 0.0 self._keys_1dr = [ "iota", diff --git a/tests/baseline/test_effective_ripple.png b/tests/baseline/test_effective_ripple.png index 38bbd3846b..0a19da2c30 100644 Binary files a/tests/baseline/test_effective_ripple.png and b/tests/baseline/test_effective_ripple.png differ diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 6c669e2912..bcf246cd35 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -2521,8 +2521,17 @@ def test_compute_scalar_resolution_others(self, objective): M_grid=int(self.eq.M * res), N_grid=int(self.eq.N * res), ) - - obj = ObjectiveFunction(objective(eq=self.eq), use_jit=False) + if objective in {EffectiveRipple}: + obj = ObjectiveFunction( + [ + objective( + self.eq, knots_per_transit=50, num_transit=2, num_pitch=25 + ) + ], + use_jit=False, + ) + else: + obj = ObjectiveFunction(objective(eq=self.eq), use_jit=False) obj.build(verbose=0) f[i] = obj.compute_scalar(obj.x()) np.testing.assert_allclose(f, f[-1], rtol=5e-2) @@ -2803,7 +2812,9 @@ def test_objective_no_nangrad_effective_ripple(self): eq = get("ESTELL") with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(2, 2, 2, 4, 4, 4) - obj = ObjectiveFunction([EffectiveRipple(eq)]) + obj = ObjectiveFunction( + [EffectiveRipple(eq, knots_per_transit=50, num_transit=2, num_pitch=25)] + ) obj.build(verbose=0) g = obj.grad(obj.x()) assert not np.any(np.isnan(g)) diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py index 5a7c3d00e7..4a18c9335c 100644 --- a/tests/test_quad_utils.py +++ b/tests/test_quad_utils.py @@ -2,7 +2,9 @@ import numpy as np import pytest +import scipy from jax import grad +from orthax.chebyshev import chebgauss from desc.backend import jnp from desc.integrals.quad_utils import ( @@ -10,6 +12,7 @@ automorphism_sin, bijection_from_disc, bijection_to_disc, + chebgauss_uniform, composite_linspace, grad_automorphism_arcsin, grad_automorphism_sin, @@ -101,3 +104,20 @@ def fun(a): # make sure differentiable # https://github.com/PlasmaControl/DESC/pull/854#discussion_r1733323161 assert np.isfinite(grad(fun)(jnp.arange(10) * np.pi)).all() + + +@pytest.mark.unit +def test_chebgauss_uniform(): + """Test uniform Chebyshev quadrature.""" + + def f(y): + return 5.2 * y**7 - 3.6 * y**3 + y**4 + + truth = scipy.integrate.quad(lambda y: f(y) / np.sqrt(1 - y**2), -1, 1)[0] + deg = 4 + yk, _ = chebgauss(deg) + x, w = chebgauss_uniform(deg) + np.testing.assert_array_equal(np.sign(x), np.sign(yk)) + np.testing.assert_allclose(np.diff(x), x[1] - x[0]) + np.testing.assert_allclose(yk, automorphism_sin(x)) + np.testing.assert_allclose(f(yk).dot(w), 2 * truth / jnp.pi)