Skip to content

Commit

Permalink
Reviving fourier bounce pull request
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Sep 17, 2024
1 parent 4d469e8 commit dbd80a5
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 74 deletions.
55 changes: 23 additions & 32 deletions desc/integrals/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from desc.backend import dct, flatnonzero, idct, irfft, jnp, put, rfft
from desc.integrals.interp_utils import (
_filter_distinct,
_subtract_first,
cheb_from_dct,
cheb_pts,
chebroots_vec,
Expand All @@ -29,23 +30,6 @@
)


def _subtract(c, k):
"""Subtract ``k`` from first index of last axis of ``c``.
Semantically same as ``return c.copy().at[...,0].add(-k)``,
but allows dimension to increase.
"""
c_0 = c[..., 0] - k
c = jnp.concatenate(
[
c_0[..., jnp.newaxis],
jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)),
],
axis=-1,
)
return c


@partial(jnp.vectorize, signature="(m),(m)->(m)")
def _in_epigraph_and(is_intersect, df_dy_sign, /):
"""Set and epigraph of function f with the given set of points.
Expand Down Expand Up @@ -162,12 +146,11 @@ def __init__(self, f, domain=(-1, 1), lobatto=False):
@staticmethod
def _fast_transform(f, lobatto):
N = f.shape[-1]
c = rfft(
return rfft(
dct(f, type=2 - lobatto, axis=-1) / (N - lobatto),
axis=-2,
norm="forward",
)
return c

@staticmethod
def nodes(M, N, L=None, domain=(-1, 1), lobatto=False):
Expand Down Expand Up @@ -204,8 +187,9 @@ def nodes(M, N, L=None, domain=(-1, 1), lobatto=False):
coords = (jnp.atleast_1d(L), x, y)
else:
coords = (x, y)
coords = list(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij")))
coords = jnp.column_stack(coords)
coords = jnp.column_stack(
list(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij")))
)
return coords

def evaluate(self, M, N):
Expand Down Expand Up @@ -424,17 +408,17 @@ def intersect2d(self, k=0.0, eps=_eps):
Sign of ∂f/∂y (x, yᵢ).
"""
c = _subtract(_chebcast(self.cheb, k), k)
c = _subtract_first(_chebcast(self.cheb, k), k)
# roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x)
y = chebroots_vec(c)
assert y.shape == (*c.shape[:-1], self.N - 1)

# Intersects must satisfy y ∈ [-1, 1].
# Pick sentinel such that only distinct roots are considered intersects.
y = _filter_distinct(y, sentinel=-2.0, eps=eps)
is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1.0)
# Ensure y is in domain of arcos; choose 1 because kernel probably cheaper.
y = jnp.where(is_intersect, y.real, 1.0)
is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) < 1.0)
# Ensure y is in differentiable domain of arcos: (-1, 1).
y = jnp.where(is_intersect, y.real, 0)

# TODO: Multipoint evaluation with FFT.
# Chapter 10, https://doi.org/10.1017/CBO9781139856065.
Expand Down Expand Up @@ -473,7 +457,7 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0):
z1, z2 : (jnp.ndarray, jnp.ndarray)
Shape broadcasts with (..., *self.cheb.shape[:-2], num_intersect).
``z1`` and ``z2`` are intersects satisfying ∂f/∂y <= 0 and ∂f/∂y >= 0,
respectively. The points are grouped and ordered such that the straight
respectively. The points are ordered and grouped such that the straight
line path between ``z1`` and ``z2`` resides in the epigraph of f.
"""
Expand All @@ -500,7 +484,9 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0):
# this, for those subset of pitch values the integrations will be done in
# the hypograph of |B|, which will yield zero. If in far future decide to
# not ignore this, note the solution is to disqualify intersects within
# ``_eps`` from ``domain[-1]``.
# ``_eps`` from ``domain[-1]``. Edit: For differentiability, we cannot
# consider intersects at boundary of Chebyshev polynomial. Again, cases
# where this would be incorrect have measure zero.
is_z1 = (df_dy_sign <= 0) & is_intersect
is_z2 = (df_dy_sign >= 0) & _in_epigraph_and(is_intersect, df_dy_sign)

Expand All @@ -519,7 +505,8 @@ def _check_shape(self, z1, z2, k):
# Ensure pitch batch dim exists and add back dim to broadcast with wells.
k = atleast_nd(self.cheb.ndim - 1, k)[..., jnp.newaxis]
# Same but back dim already exists.
z1, z2 = atleast_nd(self.cheb.ndim, z1, z2)
z1 = atleast_nd(self.cheb.ndim, z1)
z2 = atleast_nd(self.cheb.ndim, z2)
# Cheb has shape (..., M, N) and others
# have shape (K, ..., W)
errorif(not (z1.ndim == z2.ndim == k.ndim == self.cheb.ndim))
Expand All @@ -533,7 +520,7 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs):
z1, z2 : jnp.ndarray
Shape must broadcast with (*self.cheb.shape[:-2], W).
``z1`` and ``z2`` are intersects satisfying ∂f/∂y <= 0 and ∂f/∂y >= 0,
respectively. The points are grouped and ordered such that the straight
respectively. The points are ordered and grouped such that the straight
line path between ``z1`` and ``z2`` resides in the epigraph of f.
k : jnp.ndarray
Shape must broadcast with *self.cheb.shape[:-2].
Expand All @@ -560,8 +547,8 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs):

# Ensure l axis exists for iteration in below loop.
cheb = atleast_nd(3, self.cheb)
mask, z1, z2, f_midpoint = atleast_3d_mid(mask, z1, z2, f_midpoint)
err_1, err_2, err_3 = atleast_2d_end(err_1, err_2, err_3)
mask, z1, z2, f_midpoint = map(atleast_3d_mid, (mask, z1, z2, f_midpoint))
err_1, err_2, err_3 = map(atleast_2d_end, (err_1, err_2, err_3))

for l in np.ndindex(cheb.shape[:-2]):
for p in range(k.shape[0]):
Expand Down Expand Up @@ -610,6 +597,7 @@ def plot1d(
hlabel=r"$z$",
vlabel=r"$f$",
show=True,
include_legend=True,
):
"""Plot the piecewise Chebyshev series.
Expand Down Expand Up @@ -641,6 +629,8 @@ def plot1d(
Vertical axis label.
show : bool
Whether to show the plot. Default is true.
include_legend : bool
Whether to include the legend in the plot. Default is true.
Returns
-------
Expand All @@ -666,7 +656,8 @@ def plot1d(
)
ax.set_xlabel(hlabel)
ax.set_ylabel(vlabel)
ax.legend(legend.values(), legend.keys(), loc="lower right")
if include_legend:
ax.legend(legend.values(), legend.keys(), loc="lower right")
ax.set_title(title)
plt.tight_layout()
if show:
Expand Down
75 changes: 39 additions & 36 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def _transform_to_desc(grid, f):
# After GitHub issue #1034 is resolved, we should pass in the previous
# θ(α) coordinates as an initial guess for the next coordinate mapping.
# Perhaps tell the optimizer to perturb the coefficients of the
# |B|(α, ζ) directly? Maybe auto diff to see change on |B|(θ, ζ)
# and hence stream functions. Not sure how feasible...
# |B|(α, ζ) directly? think perturbing alpha is equivalent to perturbing
# lambda. Not sure if possible..

# TODO: Allow multiple starting labels for near-rational surfaces.
# can just concatenate along second to last axis of cheb, but will
Expand All @@ -115,12 +115,12 @@ def _transform_to_desc(grid, f):
class Bounce2D:
"""Computes bounce integrals using two-dimensional pseudo-spectral methods.
The bounce integral is defined as ∫ f(ℓ) dℓ, where
The bounce integral is defined as ∫ f(λ, ℓ) dℓ, where
dℓ parameterizes the distance along the field line in meters,
f(ℓ) is the quantity to integrate along the field line,
and the boundaries of the integral are bounce points ζ₁, ζ₂ s.t. λ|B|(ζᵢ) = 1,
where λ is a constant proportional to the magnetic moment over energy
and |B| is the norm of the magnetic field.
f(λ, ℓ) is the quantity to integrate along the field line,
and the boundaries of the integral are bounce points ₁, ₂ s.t. λ|B|(ℓᵢ) = 1,
where λ is a constant defining the integral proportional to the magnetic moment
over energy and |B| is the norm of the magnetic field.
For a particle with fixed λ, bounce points are defined to be the location on the
field line such that the particle's velocity parallel to the magnetic field is zero.
Expand Down Expand Up @@ -290,11 +290,17 @@ def __init__(
quad : (jnp.ndarray, jnp.ndarray)
Quadrature points xₖ and weights wₖ for the approximate evaluation of an
integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points.
For weak singular integrals, use ``chebgauss2`` from
``desc.integrals.quad_utils``.
For strong singular integrals, use ``leggauss``.
automorphism : (Callable, Callable) or None
The first callable should be an automorphism of the real interval [-1, 1].
The second callable should be the derivative of the first. This map defines
a change of variable for the bounce integral. The choice made for the
automorphism will affect the performance of the quadrature method.
For weak singular integrals, use ``None``.
For strong singular integrals, use ``automorphism_sin`` from
``desc.integrals.quad_utils``.
Bref : float
Optional. Reference magnetic field strength for normalization.
Lref : float
Expand Down Expand Up @@ -421,17 +427,16 @@ def _L(self):
"""int: Number of flux surfaces to compute on."""
return self._B.cheb.shape[0]

def bounce_points(self, pitch, num_well=None):
def bounce_points(self, pitch_inv, num_well=None):
"""Compute bounce points.
Parameters
----------
pitch : jnp.ndarray
Shape (P, L).
λ values to evaluate the bounce integral at each field line. λ(ρ) is
specified by ``pitch[...,ρ]`` where in the latter the labels ρ are
interpreted as the index into the last axis that corresponds to that field
line. If two-dimensional, the first axis is the batch axis.
pitch_inv : jnp.ndarray
Shape (M, L, P). # TODO: right now set up is (P, L).
1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by
``pitch_inv[α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
pitch along each field line. This is useful if ``num_well`` tightly
Expand All @@ -451,9 +456,9 @@ def bounce_points(self, pitch, num_well=None):
epigraph of |B|.
"""
return self._B.intersect1d(1 / jnp.atleast_2d(pitch), num_well)
return self._B.intersect1d(jnp.atleast_2d(pitch_inv), num_well)

def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs):
def check_bounce_points(self, z1, z2, pitch_inv, plot=True, **kwargs):
"""Check that bounce points are computed correctly.
Parameters
Expand All @@ -463,12 +468,11 @@ def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs):
ζ coordinates of bounce points. The points are grouped and ordered such
that the straight line path between ``z1`` and ``z2`` resides in the
epigraph of |B|.
pitch : jnp.ndarray
Shape (P, L).
λ values to evaluate the bounce integral at each field line. λ(ρ) is
specified by ``pitch[...,ρ]`` where in the latter the labels ρ are
interpreted as the index into the last axis that corresponds to that field
line. If two-dimensional, the first axis is the batch axis.
pitch_inv : jnp.ndarray
Shape (M, L, P). # TODO: right now set up is (P, L).
1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by
``pitch_inv[α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
plot : bool
Whether to plot stuff.
kwargs : dict
Expand All @@ -483,22 +487,21 @@ def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs):
kwargs.setdefault("klabel", r"$1/\lambda$")
kwargs.setdefault("hlabel", r"$\zeta$")
kwargs.setdefault("vlabel", r"$\vert B \vert$")
self._B.check_intersect1d(z1, z2, 1 / pitch, plot, **kwargs)
self._B.check_intersect1d(z1, z2, pitch_inv, plot, **kwargs)

def integrate(self, pitch, integrand, f, weight=None, num_well=None):
"""Bounce integrate ∫ f(ℓ) dℓ.
def integrate(self, pitch_inv, integrand, f, weight=None, num_well=None):
"""Bounce integrate ∫ f(λ, ℓ) dℓ.
Computes the bounce integral ∫ f(ℓ) dℓ for every specified field line
Computes the bounce integral ∫ f(λ, ℓ) dℓ for every specified field line
for every λ value in ``pitch``.
Parameters
----------
pitch : jnp.ndarray
Shape (P, L).
λ values to evaluate the bounce integral at each field line. λ(ρ) is
specified by ``pitch[...,ρ]`` where in the latter the labels ρ are
interpreted as the index into the last axis that corresponds to that field
line. If two-dimensional, the first axis is the batch axis.
pitch_inv : jnp.ndarray
Shape (M, L, P). # TODO: right now set up is (P, L).
1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by
``pitch_inv[α,ρ]`` where in the latter the labels are interpreted
as the indices that correspond to that field line.
integrand : callable
The composition operator on the set of functions in ``f`` that maps the
functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the
Expand All @@ -514,7 +517,7 @@ def integrate(self, pitch, integrand, f, weight=None, num_well=None):
weight : jnp.ndarray
Shape (L, 1, m, n).
If supplied, the bounce integral labeled by well j is weighted such that
the returned value is w(j) ∫ f(ℓ) dℓ, where w(j) is ``weight``
the returned value is w(j) ∫ f(λ, ℓ) dℓ, where w(j) is ``weight``
interpolated to the deepest point in the magnetic well. Use the method
``self.reshape_data`` to reshape the data into the expected shape.
num_well : int or None
Expand All @@ -535,9 +538,9 @@ def integrate(self, pitch, integrand, f, weight=None, num_well=None):
Last axis enumerates the bounce integrals.
"""
pitch = jnp.atleast_2d(pitch)
z1, z2 = self.bounce_points(pitch, num_well)
result = self._integrate(z1, z2, pitch, integrand, f)
pitch_inv = jnp.atleast_2d(pitch_inv)
z1, z2 = self.bounce_points(pitch_inv, num_well)
result = self._integrate(z1, z2, pitch_inv, integrand, f)
errorif(weight is not None, NotImplementedError)
return result

Expand Down
27 changes: 23 additions & 4 deletions desc/integrals/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# TODO: Boyd's method 𝒪(N²) instead of Chebyshev companion matrix 𝒪(N³).
# John P. Boyd, Computing real roots of a polynomial in Chebyshev series
# form through subdivision. https://doi.org/10.1016/j.apnum.2005.09.007.
# This is likely the bottleneck.
chebroots_vec = jnp.vectorize(chebroots, signature="(m)->(n)")


Expand Down Expand Up @@ -143,10 +144,11 @@ def harmonic_vander(x, M):
# TODO: For inverse transforms, do multipoint evaluation with FFT.
# FFT cost is 𝒪(M N log[M N]) while direct evaluation is 𝒪(M² N²).
# Chapter 10, https://doi.org/10.1017/CBO9781139856065.
# Right now we just do an MMT with the Vandermode matrix.
# Multipoint is likely better than using NFFT to evaluate f(xq) given fourier
# coefficients because evaluation points are quadratically packed near edges as
# required by quadrature to avoid runge. NFFT is only approximation anyway.
# Right now we do an MMT with the Vandermode matrix.
# Multipoint is likely better than using NFFT (for strong singular bounce
# integrals) to evaluate f(xq) given fourier coefficients because evaluation
# points are quadratically packed near edges for efficient quadrature. For
# weak singularities (e.g. effective ripple) NFFT should work well.
# https://github.com/flatironinstitute/jax-finufft.


Expand Down Expand Up @@ -451,6 +453,23 @@ def polyval_vec(*, x, c):
# TODO: Eventually do a PR to move this stuff into interpax.


def _subtract_first(c, k):
"""Subtract ``k`` from first index of last axis of ``c``.
Semantically same as ``return c.copy().at[...,0].add(-k)``,
but allows dimension to increase.
"""
c_0 = c[..., 0] - k
c = jnp.concatenate(
[
c_0[..., jnp.newaxis],
jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)),
],
axis=-1,
)
return c


def _subtract_last(c, k):
"""Subtract ``k`` from last index of last axis of ``c``.
Expand Down
12 changes: 12 additions & 0 deletions desc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,18 @@ def atleast_nd(ndmin, ary):
return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary


def atleast_3d_mid(ary):
"""Like np.atleast_3d but if adds dim at axis 1 for 2d arrays."""
ary = jnp.atleast_2d(ary)
return ary[:, jnp.newaxis] if ary.ndim == 2 else ary


def atleast_2d_end(ary):
"""Like np.atleast_2d but if adds dim at axis 1 for 1d arrays."""
ary = jnp.atleast_1d(ary)
return ary[:, jnp.newaxis] if ary.ndim == 1 else ary


PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text


Expand Down
4 changes: 2 additions & 2 deletions tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,7 @@ def integrand_den(B, pitch):

normalization = -np.sign(data["psi"]) * data["Bref"] * data["a"] ** 2
drift_numerical_num = bounce.integrate(
pitch=pitch[:, np.newaxis],
pitch_inv=pitch[:, np.newaxis],
integrand=integrand_num,
f=Bounce2D.reshape_data(
grid,
Expand All @@ -1645,7 +1645,7 @@ def integrand_den(B, pitch):
num_well=1,
)
drift_numerical_den = bounce.integrate(
pitch=pitch[:, np.newaxis],
pitch_inv=pitch[:, np.newaxis],
integrand=integrand_den,
f=[],
num_well=1,
Expand Down

0 comments on commit dbd80a5

Please sign in to comment.