Skip to content

Commit

Permalink
Swap vectorization order in bounce integrals (#1242)
Browse files Browse the repository at this point in the history
After the recent refactoring to the `Bounce1D` class that resulted from
#1214, the API is a little too strict for computations like effective
ripple etc. where we vectorize the computation over over some dimensions
and loop over others to save memory.

This PR changes the expected shape of the pitch angle input to
`Bounce1D` in #854 from `(P, M, L)` to `(M, L, P)`. With this change,
the two leading axes of all inputs to the methods in that class is `(M,
L)`.

These changes are tested and already included in downstream branches. I
am making new PR instead of directly committing to the `bounce` branch
for people who have already reviewed the `bounce` PR.

 This is better because
1. Easier usage for end users. (Previously, you'd have to manually add
trailing axes to pitch angle array).
2. Makes it much simpler to use with JAX's new batched map.
3. Previously we would loop over the pitch angles to save memory.
However, this means some computation is repeated because interpax would
interpolate multiple times. By looping over the field lines instead and
doing the interpolation for all the pitch angles at once, both
`_bounce_quadrature` and `interp_to_argmin` are faster. (I'm seeing 30%
faster speed just from computing effective ripple (no optimization), but
I don't plan to do any benchmarking to see whether that is from recent
changes like #1154 or #1043 , or others).
  • Loading branch information
unalmis authored Sep 3, 2024
2 parents c531a82 + 08e4257 commit 1436035
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 210 deletions.
113 changes: 59 additions & 54 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""Methods for computing bounce integrals (singular or otherwise)."""

import numpy as np
from interpax import CubicHermiteSpline, PPoly
from orthax.legendre import leggauss

from desc.backend import jnp
from desc.integrals.bounce_utils import (
_bounce_quadrature,
_check_bounce_points,
_set_default_plot_kwargs,
bounce_points,
bounce_quadrature,
get_pitch_inv,
interp_to_argmin,
plot_ppoly,
Expand All @@ -21,7 +20,7 @@
grad_automorphism_sin,
)
from desc.io import IOAble
from desc.utils import atleast_nd, errorif, setdefault, warnif
from desc.utils import errorif, setdefault, warnif


class Bounce1D(IOAble):
Expand Down Expand Up @@ -108,6 +107,8 @@ def __init__(
automorphism=(automorphism_sin, grad_automorphism_sin),
Bref=1.0,
Lref=1.0,
*,
is_reshaped=False,
check=False,
**kwargs,
):
Expand Down Expand Up @@ -137,6 +138,13 @@ def __init__(
Optional. Reference magnetic field strength for normalization.
Lref : float
Optional. Reference length scale for normalization.
is_reshaped : bool
Whether the arrays in ``data`` are already reshaped to the expected form of
shape (..., N) or (..., L, N) or (M, L, N). This option can be used to
iteratively compute bounce integrals one field line or one flux surface
at a time, respectively, potentially reducing memory usage. To do so,
set to true and provide only those axes of the reshaped data.
Default is false.
check : bool
Flag for debugging. Must be false for JAX transformations.
Expand All @@ -159,7 +167,11 @@ def __init__(
"|B|": data["|B|"] / Bref,
"|B|_z|r,a": data["|B|_z|r,a"] / Bref, # This is already the correct sign.
}
self._data = dict(zip(data.keys(), Bounce1D.reshape_data(grid, *data.values())))
self._data = (
data
if is_reshaped
else dict(zip(data.keys(), Bounce1D.reshape_data(grid, *data.values())))
)
self._x, self._w = get_quadrature(quad, automorphism)

# Compute local splines.
Expand All @@ -176,8 +188,10 @@ def __init__(
destination=(-1, -2),
)
self._dB_dz = polyder_vec(self.B)
assert self.B.shape == (grid.num_alpha, grid.num_rho, grid.num_zeta - 1, 4)
assert self._dB_dz.shape == (grid.num_alpha, grid.num_rho, grid.num_zeta - 1, 3)

# Add axis here instead of in ``_bounce_quadrature``.
for name in self._data:
self._data[name] = self._data[name][..., jnp.newaxis, :]

@staticmethod
def reshape_data(grid, *arys):
Expand All @@ -192,26 +206,23 @@ def reshape_data(grid, *arys):
Returns
-------
f : list[jnp.ndarray]
List of reshaped data which may be given to ``integrate``.
f : jnp.ndarray
Shape (M, L, N).
Reshaped data which may be given to ``integrate``.
"""
f = [grid.meshgrid_reshape(d, "arz") for d in arys]
return f
return f if len(f) > 1 else f[0]

def points(self, pitch_inv, num_well=None):
def points(self, pitch_inv, *, num_well=None):
"""Compute bounce points.
Notes
-----
Only the dimensions following L are required. The leading axes are batch axes.
Parameters
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
Shape (M, L, P).
1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
specified by ``pitch_inv[α,ρ]`` where in the latter the labels
are interpreted as the indices that correspond to that field line.
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
Expand All @@ -227,7 +238,7 @@ def points(self, pitch_inv, num_well=None):
Returns
-------
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape (P, M, L, num_well).
Shape (M, L, P, num_well).
ζ coordinates of bounce points. The points are ordered and grouped such
that the straight line path between ``z1`` and ``z2`` resides in the
epigraph of |B|.
Expand All @@ -239,20 +250,20 @@ def points(self, pitch_inv, num_well=None):
"""
return bounce_points(pitch_inv, self._zeta, self.B, self._dB_dz, num_well)

def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs):
def check_points(self, z1, z2, pitch_inv, *, plot=True, **kwargs):
"""Check that bounce points are computed correctly.
Parameters
----------
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape (P, M, L, num_well).
Shape (M, L, P, num_well).
ζ coordinates of bounce points. The points are ordered and grouped such
that the straight line path between ``z1`` and ``z2`` resides in the
epigraph of |B|.
pitch_inv : jnp.ndarray
Shape (P, M, L).
Shape (M, L, P).
1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
specified by ``pitch_inv[α,ρ]`` where in the latter the labels
are interpreted as the indices that correspond to that field line.
plot : bool
Whether to plot the field lines and bounce points of the given pitch angles.
Expand All @@ -268,7 +279,7 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs):
return _check_bounce_points(
z1=z1,
z2=z2,
pitch_inv=atleast_nd(3, pitch_inv),
pitch_inv=pitch_inv,
knots=self._zeta,
B=self.B,
plot=plot,
Expand All @@ -277,10 +288,11 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs):

def integrate(
self,
pitch_inv,
integrand,
pitch_inv,
f=None,
weight=None,
*,
num_well=None,
method="cubic",
batch=True,
Expand All @@ -291,24 +303,20 @@ def integrate(
Computes the bounce integral ∫ f(ℓ) dℓ for every field line and pitch.
Notes
-----
Only the dimensions following L are required. The leading axes are batch axes.
Parameters
----------
pitch_inv : jnp.ndarray
Shape (P, M, L).
1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by
``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
integrand : callable
The composition operator on the set of functions in ``f`` that maps the
functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the
arrays in ``f`` as arguments as well as the additional keyword arguments:
``B`` and ``pitch``. A quadrature will be performed to approximate the
bounce integral of ``integrand(*f,B=B,pitch=pitch)``.
f : list[jnp.ndarray]
pitch_inv : jnp.ndarray
Shape (M, L, P).
1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by
``pitch_inv[α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
f : list[jnp.ndarray] or jnp.ndarray
Shape (M, L, N).
Real scalar-valued functions evaluated on the ``grid`` supplied to
construct this object. These functions should be arguments to the callable
Expand Down Expand Up @@ -345,20 +353,19 @@ def integrate(
Returns
-------
result : jnp.ndarray
Shape (P, M, L, num_well).
Last axis enumerates the bounce integrals for a given pitch, field line,
and flux surface.
Shape (M, L, P, num_well).
Last axis enumerates the bounce integrals for a given field line,
flux surface, and pitch value.
"""
pitch_inv = atleast_nd(3, pitch_inv)
z1, z2 = self.points(pitch_inv, num_well)
result = bounce_quadrature(
z1, z2 = self.points(pitch_inv, num_well=num_well)
result = _bounce_quadrature(
x=self._x,
w=self._w,
z1=z1,
z2=z2,
pitch_inv=pitch_inv,
integrand=integrand,
pitch_inv=pitch_inv,
f=setdefault(f, []),
data=self._data,
knots=self._zeta,
Expand All @@ -377,11 +384,10 @@ def integrate(
self._dB_dz,
method,
)
assert result.shape[0] == pitch_inv.shape[0]
assert result.shape[-1] == setdefault(num_well, np.prod(self._dB_dz.shape[-2:]))
assert result.shape == z1.shape
return result

def plot(self, m, l, pitch_inv=None, **kwargs):
def plot(self, m, l, pitch_inv=None, /, **kwargs):
"""Plot the field line and bounce points of the given pitch angles.
Parameters
Expand All @@ -402,22 +408,21 @@ def plot(self, m, l, pitch_inv=None, **kwargs):
Matplotlib (fig, ax) tuple.
"""
B, dB_dz = self.B, self._dB_dz
if B.ndim == 4:
B = B[m, l]
dB_dz = dB_dz[m, l]
elif B.ndim == 3:
B = B[l]
dB_dz = dB_dz[l]

Check warning on line 417 in desc/integrals/bounce_integral.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_integral.py#L415-L417

Added lines #L415 - L417 were not covered by tests
if pitch_inv is not None:
pitch_inv = jnp.atleast_1d(jnp.squeeze(pitch_inv))
errorif(
pitch_inv.ndim != 1,
pitch_inv.ndim > 1,
msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.",
)
z1, z2 = bounce_points(
pitch_inv[:, jnp.newaxis, jnp.newaxis],
self._zeta,
self.B[m, l],
self._dB_dz[m, l],
)
z1, z2 = bounce_points(pitch_inv, self._zeta, B, dB_dz)
kwargs["z1"] = z1
kwargs["z2"] = z2
kwargs["k"] = pitch_inv
fig, ax = plot_ppoly(
PPoly(self.B[m, l].T, self._zeta), **_set_default_plot_kwargs(kwargs)
)
fig, ax = plot_ppoly(PPoly(B.T, self._zeta), **_set_default_plot_kwargs(kwargs))
return fig, ax
Loading

0 comments on commit 1436035

Please sign in to comment.