Skip to content

Commit

Permalink
Adding tests part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Aug 16, 2024
1 parent ff991cc commit 744540a
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 123 deletions.
42 changes: 21 additions & 21 deletions desc/compute/_interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ def harmonic(a, M, axis=-1):
return h


def harmonic_basis(x, M):
def harmonic_vander(x, M):
"""Nyquist trigonometric interpolant basis evaluated at ``x``.
Parameters
----------
x : jnp.ndarray
Points to evaluate.
Points at which to evaluate pseudo-Vandermonde matrix.
M : int
Spectral resolution.
Returns
-------
basis : jnp.ndarray
Shape (*x.shape, M).
Basis evaluated at points ``x``.
Pseudo-Vandermonde matrix of degree ``M-1`` and sample points ``x``.
Last axis ordered as [1, cos(x), ..., cos(mx), sin(x), sin(2x), ..., sin(mx)].
"""
Expand Down Expand Up @@ -128,7 +128,7 @@ def interp_rfft(xq, f, axis=-1):
----------
xq : jnp.ndarray
Real query points where interpolation is desired.
Shape of ``xq`` must broadcast with ``f`` except along ``axis``.
Shape of ``xq`` must broadcast with arrays of shape ``np.delete(f.shape,axis)``.
f : jnp.ndarray
Real function values on uniform 2π periodic grid to interpolate.
axis : int
Expand All @@ -153,7 +153,7 @@ def irfft_non_uniform(xq, a, n, axis=-1):
----------
xq : jnp.ndarray
Real query points where interpolation is desired.
Shape of ``xq`` must broadcast with ``a`` except along ``axis``.
Shape of ``xq`` must broadcast with arrays of shape ``np.delete(a.shape,axis)``.
a : jnp.ndarray
Fourier coefficients ``a=rfft(f,axis=axis,norm="forward")``.
n : int
Expand All @@ -175,7 +175,7 @@ def irfft_non_uniform(xq, a, n, axis=-1):
.at[Index.get(-1, axis, a.ndim)]
.divide(1.0 + ((n % 2) == 0))
)
a = jnp.swapaxes(a[..., jnp.newaxis], axis % a.ndim, -1)
a = jnp.moveaxis(a, axis, -1)
m = jnp.fft.rfftfreq(n, d=1 / n)
basis = jnp.exp(-1j * m * xq[..., jnp.newaxis])
fq = jnp.linalg.vecdot(basis, a).real
Expand All @@ -193,7 +193,7 @@ def interp_rfft2(xq, f, axes=(-2, -1)):
Shape (..., 2).
Real query points where interpolation is desired.
Last axis must hold coordinates for a given point.
Shape of ``xq`` must broadcast ``f`` except along ``axes``.
Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(f.shape,axes)``.
f : jnp.ndarray
Shape (..., f.shape[-2], f.shape[-1]).
Real function values on uniform (2π × 2π) periodic tensor-product grid to
Expand Down Expand Up @@ -223,7 +223,7 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)):
Shape (..., 2).
Real query points where interpolation is desired.
Last axis must hold coordinates for a given point.
Shape of ``xq`` must broadcast ``a`` except along ``axes``.
Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(a.shape,axes)``.
a : jnp.ndarray
Shape (..., a.shape[-2], a.shape[-1]).
Fourier coefficients ``a=rfft2(f,axes=axes,norm="forward")``.
Expand All @@ -240,7 +240,6 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)):
Real function value at query points.
"""
errorif(axes != (-2, -1), NotImplementedError) # need to swap axes before reshape
assert xq.shape[-1] == 2
assert a.ndim >= 2
a = (
Expand All @@ -249,7 +248,9 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)):
.divide(2.0)
.at[Index.get(-1, axes[-1], a.ndim)]
.divide(1.0 + ((N % 2) == 0))
).reshape(*a.shape[:-2], 1, -1)
)
a = jnp.moveaxis(a, source=axes, destination=(-2, -1))
a = a.reshape(*a.shape[:-2], -1)

m = jnp.fft.fftfreq(M, d=1 / M)
n = jnp.fft.rfftfreq(N, d=1 / N)
Expand Down Expand Up @@ -295,7 +296,7 @@ def interp_dct(xq, f, lobatto=False, axis=-1):
----------
xq : jnp.ndarray
Real query points where interpolation is desired.
Shape of ``xq`` must broadcast with ``f`` except along ``axis``.
Shape of ``xq`` must broadcast with shape ``np.delete(f.shape,axis)``.
f : jnp.ndarray
Real function values on Chebyshev points to interpolate.
lobatto : bool
Expand All @@ -310,27 +311,26 @@ def interp_dct(xq, f, lobatto=False, axis=-1):
Real function value at query points.
"""
lobatto = bool(lobatto)
errorif(lobatto, NotImplementedError)
assert f.ndim >= 1
lobatto = bool(lobatto)
a = dct(f, type=2 - lobatto, axis=axis) / (f.shape[axis] - lobatto)
a = cheb_from_dct(
dct(f, type=2 - lobatto, axis=axis) / (f.shape[axis] - lobatto), axis
)
fq = idct_non_uniform(xq, a, f.shape[axis], axis)
return fq


def idct_non_uniform(xq, a, n, axis=-1):
"""Evaluate Discrete Cosine Transform coefficients ``a`` at ``xq`` ∈ [-1, 1].
"""Evaluate Discrete Chebyshev Transform coefficients ``a`` at ``xq`` ∈ [-1, 1].
Parameters
----------
xq : jnp.ndarray
Real query points where interpolation is desired.
Shape of ``xq`` must broadcast with ``a`` except along ``axis``.
Shape of ``xq`` must broadcast with shape ``np.delete(a.shape,axis)``.
a : jnp.ndarray
Discrete Cosine Transform coefficients, e.g.
``a=dct(f,type=2,axis=axis,norm="forward")``.
The discrete cosine transformation used by scipy is defined here.
docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html#scipy.fft.dct
Discrete Chebyshev Transform coefficients.
n : int
Spectral resolution of ``a``.
axis : int
Expand All @@ -343,9 +343,9 @@ def idct_non_uniform(xq, a, n, axis=-1):
"""
assert a.ndim >= 1
a = cheb_from_dct(a, axis)
a = jnp.swapaxes(a[..., jnp.newaxis], axis % a.ndim, -1)
a = jnp.moveaxis(a, axis, -1)
basis = chebvander(xq, n - 1)
# Could instead use Clenshaw recursion with ``fq=chebval(xq,a,tensor=False)``.
fq = jnp.linalg.vecdot(basis, a)
return fq

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from interpax import CubicHermiteSpline, PPoly, interp1d
from jax.nn import softmax
from matplotlib import pyplot as plt
from orthax.legendre import leggauss
from numpy.polynomial.legendre import leggauss

from desc.backend import flatnonzero, imap, jnp, put
from desc.compute._interp_utils import poly_root, polyder_vec, polyval_vec
Expand Down
99 changes: 59 additions & 40 deletions desc/compute/fourier_bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
from matplotlib import pyplot as plt
from orthax.chebyshev import chebroots, chebvander
from orthax.chebyshev import chebroots
from orthax.legendre import leggauss

from desc.backend import dct, idct, irfft, jnp, rfft, rfft2
Expand All @@ -12,6 +12,7 @@
cheb_pts,
fourier_pts,
harmonic,
idct_non_uniform,
interp_rfft2,
irfft2_non_uniform,
irfft_non_uniform,
Expand All @@ -35,7 +36,7 @@ def _flatten_matrix(y):
return y.reshape(*y.shape[:-2], -1)


def _alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi):
def alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi):
"""Get sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) of field line.
Parameters
Expand All @@ -62,6 +63,19 @@ def _alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi):
return alphas


def _subtract(c, k):
# subtract k from last axis of c, obeying numpy broadcasting
c_0 = c[..., 0] - k
c = jnp.concatenate(
[
jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)),
c_0[..., jnp.newaxis],
],
axis=-1,
)
return c


class FourierChebyshevBasis:
"""Fourier-Chebyshev series.
Expand Down Expand Up @@ -206,10 +220,7 @@ def compute_cheb(self, x):
# Always add new axis to broadcast against Chebyshev coefficients.
x = jnp.atleast_1d(x)[..., jnp.newaxis]
cheb = cheb_from_dct(irfft_non_uniform(x, self._c, self.M, axis=-2), axis=-1)
assert cheb.shape[-2:] == (
x.shape[-2],
self.N,
), f"{cheb.shape}; {x.shape}; {self.N}"
assert cheb.shape[-2:] == (x.shape[-2], self.N)
return _PiecewiseChebyshevBasis(cheb, self.domain)


Expand Down Expand Up @@ -247,6 +258,16 @@ def __init__(self, cheb, domain):
self.N = cheb.shape[-1]
self.domain = domain

def _chebcast(self, arr):
# Input should not have rightmost dimension of cheb that iterates coefficients,
# but may have additional leftmost dimensions for batch operations.
errorif(
arr.ndim > self.cheb.ndim,
NotImplementedError,
msg=f"Got ndim {arr.ndim} > cheb.ndim {self.cheb.ndim}.",
)
return self.cheb if arr.ndim < self.cheb.ndim else self.cheb[jnp.newaxis]

def intersect(self, k=0, eps=_eps):
"""Coordinates yᵢ such that f(x, yᵢ) = k(x).
Expand Down Expand Up @@ -274,13 +295,7 @@ def intersect(self, k=0, eps=_eps):
Boolean array into ``y`` indicating whether element is an intersect.
"""
errorif(
k.ndim > self.cheb.ndim,
NotImplementedError,
msg=f"Got k.ndim {k.ndim} > cheb.ndim {self.cheb.ndim}.",
)
c = self.cheb if k.ndim < self.cheb.ndim else self.cheb[jnp.newaxis]
c = c.copy().at[..., 0].add(-k)
c = _subtract(self._chebcast(k), k)
# roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x)
y = _chebroots_vec(c)
assert y.shape == (*c.shape[:-1], self.N - 1)
Expand All @@ -289,13 +304,15 @@ def intersect(self, k=0, eps=_eps):
# Pick sentinel above such that only distinct roots are considered intersects.
is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1)
y = jnp.where(is_intersect, y.real, 0) # ensure y is in domain of arcos

# TODO: Multipoint evaluation with FFT.
# Chapter 10, https://doi.org/10.1017/CBO9781139856065.
n = jnp.arange(self.N)
# ∂f/∂y = ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n Uₙ₋₁(y)
# sign ∂f/∂y = sign ∑ₙ₌ᴺ⁻¹ aₙ(x) sin(n arcos y)
# sign ∂f/∂y = sign ∑ₙ₌ᴺ⁻¹ aₙ(x) n sin(n arcos y)
s = jnp.linalg.vecdot(
# TODO: Multipoint evaluation with FFT.
# Chapter 10, https://doi.org/10.1017/CBO9781139856065.
n * jnp.sin(n * jnp.arccos(y)[..., jnp.newaxis]),
self.cheb[..., jnp.newaxis, :],
jnp.sin(jnp.arange(self.N) * jnp.arccos(y)[..., jnp.newaxis]),
)
is_decreasing = s <= 0
is_increasing = s >= 0
Expand Down Expand Up @@ -547,11 +564,9 @@ def eval1d(self, z):
y = bijection_to_disc(y, self.domain[0], self.domain[1])
# Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z])
# are held in self.cheb with shape (..., num cheb series, N).
cheb = jnp.moveaxis(self.cheb, source=-1, destination=0)
cheb = jnp.take_along_axis(cheb, x_idx, axis=-1)
# TODO: Multipoint evaluation with FFT.
# Chapter 10, https://doi.org/10.1017/CBO9781139856065.
f = jnp.linalg.vecdot(chebvander(y, self.N - 1), cheb)
cheb = jnp.take_along_axis(self._chebcast(z), x_idx[..., jnp.newaxis], axis=-2)
f = idct_non_uniform(y, cheb, self.N)
assert f.shape == z.shape
return f

def _isomorphism_1d(self, y):
Expand Down Expand Up @@ -595,10 +610,8 @@ def _isomorphism_2d(self, z):
Isomorphic coordinates.
"""
period = self.domain[-1] - self.domain[0]
x_index = z // period
y_value = z % period
return x_index, y_value
x_index, y_value = jnp.divmod(z, self.domain[-1] - self.domain[0])
return x_index.astype(int), y_value


def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch):
Expand Down Expand Up @@ -655,7 +668,7 @@ def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch)
Returns
-------
result : jnp.ndarray
Shape (P, S, num_well).
Shape (P, L, num_well).
First axis enumerates pitch values. Second axis enumerates the field lines.
Last axis enumerates the bounce integrals.
Expand Down Expand Up @@ -684,14 +697,15 @@ def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch)
w,
)
assert result.shape == (P, L, num_well)
return result


def required_names():
"""Return names in ``data_index`` required to compute bounce integrals."""
return ["B^zeta", "|B|"]


# TODO: Assumes zeta = phi
# TODO: Assumes zeta = phi (alpha sequence)
def bounce_integral(
grid,
data,
Expand Down Expand Up @@ -765,14 +779,16 @@ def bounce_integral(
bounce_integrate : callable
This callable method computes the bounce integral ∫ f(ℓ) dℓ for every
specified field line for every λ value in ``pitch``.
alphas : jnp.ndarray
Sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line.
B : _PiecewiseChebyshevBasis
Set of 1D Chebyshev spectral coefficients of |B| along field line.
{|B|_α : ζ |B|(α, ζ) | α ∈ A } .
T : _PiecewiseChebyshevBasis
Set of 1D Chebyshev spectral coefficients of θ along field line.
{θ_α : ζ θ(α, ζ) | α ∈ A }.
spline : tuple(ndarray, _PiecewiseChebyshevBasis, _PiecewiseChebyshevBasis)
alphas : jnp.ndarray
Poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line.
B : _PiecewiseChebyshevBasis
Set of 1D Chebyshev spectral coefficients of |B| along field line.
{|B|_α : ζ |B|(α, ζ) | α ∈ A } .
T : _PiecewiseChebyshevBasis
Set of 1D Chebyshev spectral coefficients of θ along field line.
{θ_α : ζ θ(α, ζ) | α ∈ A }.
"""
# Resolution of periodic DESC coordinate tensor-product grid.
L, m, n = grid.num_rho, grid.num_theta, grid.num_zeta
Expand All @@ -798,7 +814,7 @@ def bounce_integral(
).reshape(L, M, N),
)
# Peel off field lines.
alphas = _alpha_sequence(alpha_0, grid.compress(data["iota"]), num_transit)
alphas = alpha_sequence(alpha_0, grid.compress(data["iota"]), num_transit)
T = T.compute_cheb(alphas)
B = B.compute_cheb(alphas)
assert T.cheb.shape == B.cheb.shape == (L, num_transit, N)
Expand Down Expand Up @@ -863,12 +879,15 @@ def bounce_integrate(integrand, f, pitch, weight=None, num_well=None):
errorif(weight is not None, NotImplementedError)
# Compute bounce points.
pitch = jnp.atleast_3d(pitch)
P = pitch.shape[0]
assert pitch.shape[1:] == B.cheb.shape[:-1], f"{pitch.shape}; {B.cheb.shape}"
assert (
pitch.shape[1] == B.cheb.shape[0]
or pitch.shape[1] == 1
or B.cheb.shape[0] == 1
)
bp1, bp2 = B.bounce_points(*B.intersect(1 / pitch), num_well)
P = pitch.shape[0]
num_well = bp1.shape[-1]
assert bp1.shape == bp2.shape == (P, L, num_well)

result = _bounce_quadrature(
bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch
)
Expand Down
2 changes: 2 additions & 0 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ def get_rtz_grid(
"z": "zeta",
"p": "phi",
}
if "iota" in kwargs:
kwargs["iota"] = grid.expand(kwargs["iota"], surface_label="rho")
rtz_nodes = map_coordinates(
eq,
grid.nodes,
Expand Down
Loading

0 comments on commit 744540a

Please sign in to comment.