Skip to content

Commit

Permalink
Improve quadrature over velocity coordiante for effective ripple
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Sep 16, 2024
1 parent 3a93117 commit 52adba9
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 79 deletions.
56 changes: 26 additions & 30 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,7 @@
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import (
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import (
register_pytree_node,
Expand All @@ -98,6 +91,31 @@
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Parameters
----------
func : callable
Function to decorate
Returns
-------
wrapper : callable
Decorated function that will always run on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

# JAX implementation is not differentiable on gpu.
eigh_tridiagonal = execute_on_cpu(jax.scipy.linalg.eigh_tridiagonal)

def put(arr, inds, vals):
"""Functional interface for array "fancy indexing".
Expand All @@ -123,28 +141,6 @@ def put(arr, inds, vals):
return arr
return jnp.asarray(arr).at[inds].set(vals)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Parameters
----------
func : callable
Function to decorate
Returns
-------
wrapper : callable
Decorated function that will run always on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

def sign(x):
"""Sign function, but returns 1 for x==0.
Expand Down
46 changes: 23 additions & 23 deletions desc/compute/_neoclassical.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..integrals.quad_utils import get_quadrature, leggauss_lob
from ..utils import map2, safediv
from .data_index import register_compute_fun
from .utils import _get_pitch_inv, _poloidal_mean
from .utils import _get_pitch_inv_chebgauss, _poloidal_mean


@register_compute_fun(
Expand Down Expand Up @@ -79,10 +79,10 @@ def _G_ra_fsa(data, transforms, profiles, **kwargs):
@register_compute_fun(
name="effective ripple", # this is ε¹ᐧ⁵
label=(
# ε¹ᐧ⁵ = π/(8√2) (R₀/〈|∇ψ|〉)² ∫dλ λ⁻²B₀⁻¹ 〈 ∑ⱼ Hⱼ²/Iⱼ 〉
# ε¹ᐧ⁵ = π/(8√2) (R₀/〈|∇ψ|〉)² B₀⁻¹ ∫dλ λ⁻² 〈 ∑ⱼ Hⱼ²/Iⱼ 〉
"\\epsilon^{3/2} = \\frac{\\pi}{8 \\sqrt{2}} "
"(R_0 / \\langle \\vert\\nabla \\psi\\vert \\rangle)^2 "
"\\int d\\lambda \\lambda^{-2} B_0^{-1} "
"B_0^{-1} \\int d\\lambda \\lambda^{-2} "
"\\langle \\sum_j H_j^2 / I_j \\rangle"
),
units="~",
Expand All @@ -106,12 +106,7 @@ def _G_ra_fsa(data, transforms, profiles, **kwargs):
resolution_requirement="z",
source_grid_requirement={"coordinates": "raz", "is_meshgrid": True},
quad="jnp.ndarray : Optional, quadrature points and weights for bounce integrals.",
num_pitch=(
"int : Resolution for quadrature over velocity coordinate, preferably odd. "
"Default is 75. Profile will look smoother at high values."
# If computed on many flux surfaces and small oscillations are seen
# between neighboring surfaces, increasing this will smooth the profile.
),
num_pitch="int : Resolution for quadrature over velocity coordinate. Default 50.",
num_well=(
"int : Maximum number of wells to detect for each pitch and field line. "
"Default is to detect all wells, but due to limitations in JAX this option "
Expand Down Expand Up @@ -152,7 +147,7 @@ def _effective_ripple(params, transforms, profiles, data, **kwargs):
if "quad" in kwargs
else get_quadrature(leggauss_lob(32), Bounce1D._default_automorphism)
)
num_pitch = kwargs.get("num_pitch", 75)
num_pitch = kwargs.get("num_pitch", 50)
num_well = kwargs.get("num_well", None)
batch = kwargs.get("batch", True)
grid = transforms["grid"].source_grid
Expand All @@ -171,24 +166,29 @@ def dI(B, pitch):
return jnp.sqrt(jnp.abs(1 - pitch * B)) / B

def compute(data):
"""(∂ψ/∂ρ)⁻² B₀⁻² ∫ dλ ∑ⱼ Hⱼ²/Iⱼ."""
"""Return (∂ψ/∂ρ)⁻² B₀⁻² ∫ dλ λ⁻² ∑ⱼ Hⱼ²/Iⱼ.
Notes
-----
B₀ has units of λ⁻¹.
Nemov's ∑ⱼ Hⱼ²/Iⱼ = (∂ψ/∂ρ)² (λB₀)³ ``(H**2 / I).sum(axis=-1)``.
(λB₀)³ d(λB₀)⁻¹ = B₀² λ³ d(λ⁻¹) = -B₀² λ dλ.
"""
bounce = Bounce1D(grid, data, quad, automorphism=None, is_reshaped=True)
# Interpolate |∇ρ| κ_g since it is smoother than κ_g alone.
H = bounce.integrate(
dH,
data["pitch_inv"],
# Interpolate |∇ρ| κ_g since it is smoother than κ_g alone.
data["|grad(rho)|*kappa_g"],
num_well=num_well,
batch=batch,
)
I = bounce.integrate(dI, data["pitch_inv"], num_well=num_well, batch=batch)
# Note B₀ has units of λ⁻¹.
# Nemov's ∑ⱼ Hⱼ²/Iⱼ = (∂ψ/∂ρ)² (λB₀)³ ``(H**2 / I).sum(axis=-1)``.
# (λB₀)³ db = λ³B₀² d(λ⁻¹) = λB₀² (-dλ).
y = data["pitch_inv"] ** (-3) * safediv(H**2, I).sum(axis=-1)
return simpson(y=y, x=data["pitch_inv"])
# TODO: Try Gauss-Chebyshev quadrature after automorphism arcsin to
# make nodes more evenly spaced.
return (
safediv(H**2, I).sum(axis=-1)
* data["pitch_inv"] ** (-3)
* data["pitch_inv weight"]
).sum(axis=-1)

_data = { # noqa: unused dependency
name: Bounce1D.reshape_data(grid, data[name])
Expand All @@ -197,13 +197,13 @@ def compute(data):
_data["|grad(rho)|*kappa_g"] = Bounce1D.reshape_data(
grid, data["|grad(rho)|"] * data["kappa_g"]
)
_data["pitch_inv"] = _get_pitch_inv(grid, data, num_pitch)
out = _poloidal_mean(grid, map2(compute, _data))
_data = _get_pitch_inv_chebgauss(grid, data, num_pitch, _data)
B0 = data["max_tz |B|"]
data["effective ripple"] = (
jnp.pi
/ (8 * 2**0.5)
* (data["max_tz |B|"] * data["R0"] / data["<|grad(rho)|>"]) ** 2
* grid.expand(out)
* (B0 * data["R0"] / data["<|grad(rho)|>"]) ** 2
* grid.expand(_poloidal_mean(grid, map2(compute, _data)))
/ data["<L|r,a>"]
)
return data
24 changes: 20 additions & 4 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from desc.backend import execute_on_cpu, jnp
from desc.grid import Grid

from ..integrals.bounce_utils import get_pitch_inv
from ..integrals.bounce_utils import get_pitch_inv, get_pitch_inv_chebgauss
from ..utils import errorif
from .data_index import allowed_kwargs, data_index

Expand Down Expand Up @@ -728,12 +728,28 @@ def _poloidal_mean(grid, f):
return f.T.dot(dp) / jnp.sum(dp)

Check warning on line 728 in desc/compute/utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/utils.py#L724-L728

Added lines #L724 - L728 were not covered by tests


def _get_pitch_inv(grid, data, num_pitch):
return jnp.broadcast_to(
def _get_pitch_inv(grid, data, num_pitch, _data):
_data["pitch_inv"] = jnp.broadcast_to(

Check warning on line 732 in desc/compute/utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/utils.py#L732

Added line #L732 was not covered by tests
get_pitch_inv(
grid.compress(data["min_tz |B|"]),
grid.compress(data["max_tz |B|"]),
num_pitch,
)[jnp.newaxis],
(grid.num_alpha, grid.num_rho, num_pitch + 2),
(grid.num_alpha, grid.num_rho, num_pitch),
)
return _data

Check warning on line 740 in desc/compute/utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/utils.py#L740

Added line #L740 was not covered by tests


def _get_pitch_inv_chebgauss(grid, data, num_pitch, _data):
p, w = get_pitch_inv_chebgauss(
grid.compress(data["min_tz |B|"]),
grid.compress(data["max_tz |B|"]),
num_pitch,
)
_data["pitch_inv"] = jnp.broadcast_to(
p[jnp.newaxis], (grid.num_alpha, grid.num_rho, num_pitch)
)
_data["pitch_inv weight"] = jnp.broadcast_to(
w[jnp.newaxis], (grid.num_alpha, grid.num_rho, num_pitch)
)
return _data
46 changes: 40 additions & 6 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from desc.integrals.quad_utils import (
bijection_from_disc,
chebgauss_uniform,
composite_linspace,
grad_bijection_from_disc,
)
Expand All @@ -37,28 +38,61 @@ def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6):
max_B : jnp.ndarray
Maximum |B| value.
num : int
Number of values, not including endpoints.
Number of values.
relative_shift : float
Relative amount to shift maxima down and minima up to avoid floating point
errors in downstream routines.
Returns
-------
pitch_inv : jnp.ndarray
Shape (*min_B.shape, num + 2).
Shape (*min_B.shape, num).
1/λ values.
"""
# Floating point error impedes consistent detection of bounce points riding
# extrema. Shift values slightly to resolve this issue.
min_B = (1 + relative_shift) * min_B
max_B = (1 - relative_shift) * max_B
min_B = (1.0 + relative_shift) * min_B
max_B = (1.0 - relative_shift) * max_B

Check warning on line 56 in desc/integrals/bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_utils.py#L55-L56

Added lines #L55 - L56 were not covered by tests
# Samples should be uniformly spaced in |B| and not λ (GitHub issue #1228).
pitch_inv = jnp.moveaxis(composite_linspace(jnp.stack([min_B, max_B]), num), 0, -1)
assert pitch_inv.shape == (*min_B.shape, num + 2)
pitch_inv = jnp.moveaxis(

Check warning on line 58 in desc/integrals/bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_utils.py#L58

Added line #L58 was not covered by tests
composite_linspace(jnp.stack([min_B, max_B]), num - 2), 0, -1
)
assert pitch_inv.shape == (*min_B.shape, num)

Check warning on line 61 in desc/integrals/bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_utils.py#L61

Added line #L61 was not covered by tests
return pitch_inv


def get_pitch_inv_chebgauss(min_B, max_B, num, relative_shift=1e-6):
"""Return Chebyshev quadrature with 1/λ uniform in ``min_B`` and ``max_B``.
Parameters
----------
min_B : jnp.ndarray
Minimum |B| value.
max_B : jnp.ndarray
Maximum |B| value.
num : int
Number of values.
relative_shift : float
Relative amount to shift maxima down and minima up to avoid floating point
errors in downstream routines.
Returns
-------
pitch_inv, weight : (jnp.ndarray, jnp.ndarray)
Shape (*min_B.shape, num).
1/λ values and weights.
"""
min_B = (1.0 + relative_shift) * min_B
max_B = (1.0 - relative_shift) * max_B
# Samples should be uniformly spaced in |B| (GitHub issue #1228).
x, w = chebgauss_uniform(num)
pitch_inv = bijection_from_disc(x, min_B[..., jnp.newaxis], max_B[..., jnp.newaxis])
w = w * grad_bijection_from_disc(min_B, max_B)[..., jnp.newaxis]
return pitch_inv, w


def _check_spline_shape(knots, g, dg_dz, pitch_inv=None):
"""Ensure inputs have compatible shape.
Expand Down
29 changes: 27 additions & 2 deletions desc/integrals/quad_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Utilities for quadratures."""

from orthax.chebyshev import chebgauss
from orthax.legendre import legder, legval

from desc.backend import eigh_tridiagonal, execute_on_cpu, jnp, put
from desc.backend import eigh_tridiagonal, jnp, put
from desc.utils import errorif


Expand Down Expand Up @@ -139,7 +140,6 @@ def tanh_sinh(deg, m=10):
return x, w


@execute_on_cpu # JAX implementation of eigh_tridiagonal is not differentiable on gpu.
def leggauss_lob(deg, interior_only=False):
"""Lobatto-Gauss-Legendre quadrature.
Expand Down Expand Up @@ -191,6 +191,31 @@ def leggauss_lob(deg, interior_only=False):
return x, w


def chebgauss_uniform(deg):
"""Gauss-Chebyshev quadrature with uniformly spaced nodes.
Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the
integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ).
Parameters
----------
deg : int
Number of quadrature points.
Returns
-------
x, w : (jnp.ndarray, jnp.ndarray)
Shape (deg, ).
Quadrature points and weights.
"""
# Define x = 2/π arcsin y and g : y ↦ f(x(y)).
# ∫₋₁¹ f(x) dx = 2/π ∫₋₁¹ (1−y²)⁻⁰ᐧ⁵ g(y) dy
# ∑ₖ wₖ f(x(yₖ)) = 2/π ∑ₖ ωₖ g(yₖ)
# Given roots yₖ of Chebyshev polynomial, x(yₖ) is uniform in (-1, 1).
y, w = chebgauss(deg)
return automorphism_arcsin(y), 2 * w / jnp.pi


def get_quadrature(quad, automorphism):
"""Apply automorphism to given quadrature.
Expand Down
Loading

0 comments on commit 52adba9

Please sign in to comment.