From 744540a5028cac37eb4fb2d454756a5cb714c48a Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 16 Aug 2024 01:00:23 -0400 Subject: [PATCH] Adding tests part 2 --- desc/compute/_interp_utils.py | 42 ++++---- desc/compute/bounce_integral.py | 2 +- desc/compute/fourier_bounce_integral.py | 99 ++++++++++-------- desc/equilibrium/coords.py | 2 + tests/test_bounce_integral.py | 127 ++++++++++++++---------- tests/test_fourier_bounce.py | 111 ++++++++++++++++++++- tests/test_interp_utils.py | 4 +- 7 files changed, 264 insertions(+), 123 deletions(-) diff --git a/desc/compute/_interp_utils.py b/desc/compute/_interp_utils.py index 4fa1ec327a..8284c1a02d 100644 --- a/desc/compute/_interp_utils.py +++ b/desc/compute/_interp_utils.py @@ -85,13 +85,13 @@ def harmonic(a, M, axis=-1): return h -def harmonic_basis(x, M): +def harmonic_vander(x, M): """Nyquist trigonometric interpolant basis evaluated at ``x``. Parameters ---------- x : jnp.ndarray - Points to evaluate. + Points at which to evaluate pseudo-Vandermonde matrix. M : int Spectral resolution. @@ -99,7 +99,7 @@ def harmonic_basis(x, M): ------- basis : jnp.ndarray Shape (*x.shape, M). - Basis evaluated at points ``x``. + Pseudo-Vandermonde matrix of degree ``M-1`` and sample points ``x``. Last axis ordered as [1, cos(x), ..., cos(mx), sin(x), sin(2x), ..., sin(mx)]. """ @@ -128,7 +128,7 @@ def interp_rfft(xq, f, axis=-1): ---------- xq : jnp.ndarray Real query points where interpolation is desired. - Shape of ``xq`` must broadcast with ``f`` except along ``axis``. + Shape of ``xq`` must broadcast with arrays of shape ``np.delete(f.shape,axis)``. f : jnp.ndarray Real function values on uniform 2π periodic grid to interpolate. axis : int @@ -153,7 +153,7 @@ def irfft_non_uniform(xq, a, n, axis=-1): ---------- xq : jnp.ndarray Real query points where interpolation is desired. - Shape of ``xq`` must broadcast with ``a`` except along ``axis``. + Shape of ``xq`` must broadcast with arrays of shape ``np.delete(a.shape,axis)``. a : jnp.ndarray Fourier coefficients ``a=rfft(f,axis=axis,norm="forward")``. n : int @@ -175,7 +175,7 @@ def irfft_non_uniform(xq, a, n, axis=-1): .at[Index.get(-1, axis, a.ndim)] .divide(1.0 + ((n % 2) == 0)) ) - a = jnp.swapaxes(a[..., jnp.newaxis], axis % a.ndim, -1) + a = jnp.moveaxis(a, axis, -1) m = jnp.fft.rfftfreq(n, d=1 / n) basis = jnp.exp(-1j * m * xq[..., jnp.newaxis]) fq = jnp.linalg.vecdot(basis, a).real @@ -193,7 +193,7 @@ def interp_rfft2(xq, f, axes=(-2, -1)): Shape (..., 2). Real query points where interpolation is desired. Last axis must hold coordinates for a given point. - Shape of ``xq`` must broadcast ``f`` except along ``axes``. + Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(f.shape,axes)``. f : jnp.ndarray Shape (..., f.shape[-2], f.shape[-1]). Real function values on uniform (2π × 2π) periodic tensor-product grid to @@ -223,7 +223,7 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)): Shape (..., 2). Real query points where interpolation is desired. Last axis must hold coordinates for a given point. - Shape of ``xq`` must broadcast ``a`` except along ``axes``. + Shape ``xq.shape[:-1]`` must broadcast with shape ``np.delete(a.shape,axes)``. a : jnp.ndarray Shape (..., a.shape[-2], a.shape[-1]). Fourier coefficients ``a=rfft2(f,axes=axes,norm="forward")``. @@ -240,7 +240,6 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)): Real function value at query points. """ - errorif(axes != (-2, -1), NotImplementedError) # need to swap axes before reshape assert xq.shape[-1] == 2 assert a.ndim >= 2 a = ( @@ -249,7 +248,9 @@ def irfft2_non_uniform(xq, a, M, N, axes=(-2, -1)): .divide(2.0) .at[Index.get(-1, axes[-1], a.ndim)] .divide(1.0 + ((N % 2) == 0)) - ).reshape(*a.shape[:-2], 1, -1) + ) + a = jnp.moveaxis(a, source=axes, destination=(-2, -1)) + a = a.reshape(*a.shape[:-2], -1) m = jnp.fft.fftfreq(M, d=1 / M) n = jnp.fft.rfftfreq(N, d=1 / N) @@ -295,7 +296,7 @@ def interp_dct(xq, f, lobatto=False, axis=-1): ---------- xq : jnp.ndarray Real query points where interpolation is desired. - Shape of ``xq`` must broadcast with ``f`` except along ``axis``. + Shape of ``xq`` must broadcast with shape ``np.delete(f.shape,axis)``. f : jnp.ndarray Real function values on Chebyshev points to interpolate. lobatto : bool @@ -310,27 +311,26 @@ def interp_dct(xq, f, lobatto=False, axis=-1): Real function value at query points. """ + lobatto = bool(lobatto) errorif(lobatto, NotImplementedError) assert f.ndim >= 1 - lobatto = bool(lobatto) - a = dct(f, type=2 - lobatto, axis=axis) / (f.shape[axis] - lobatto) + a = cheb_from_dct( + dct(f, type=2 - lobatto, axis=axis) / (f.shape[axis] - lobatto), axis + ) fq = idct_non_uniform(xq, a, f.shape[axis], axis) return fq def idct_non_uniform(xq, a, n, axis=-1): - """Evaluate Discrete Cosine Transform coefficients ``a`` at ``xq`` ∈ [-1, 1]. + """Evaluate Discrete Chebyshev Transform coefficients ``a`` at ``xq`` ∈ [-1, 1]. Parameters ---------- xq : jnp.ndarray Real query points where interpolation is desired. - Shape of ``xq`` must broadcast with ``a`` except along ``axis``. + Shape of ``xq`` must broadcast with shape ``np.delete(a.shape,axis)``. a : jnp.ndarray - Discrete Cosine Transform coefficients, e.g. - ``a=dct(f,type=2,axis=axis,norm="forward")``. - The discrete cosine transformation used by scipy is defined here. - docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html#scipy.fft.dct + Discrete Chebyshev Transform coefficients. n : int Spectral resolution of ``a``. axis : int @@ -343,9 +343,9 @@ def idct_non_uniform(xq, a, n, axis=-1): """ assert a.ndim >= 1 - a = cheb_from_dct(a, axis) - a = jnp.swapaxes(a[..., jnp.newaxis], axis % a.ndim, -1) + a = jnp.moveaxis(a, axis, -1) basis = chebvander(xq, n - 1) + # Could instead use Clenshaw recursion with ``fq=chebval(xq,a,tensor=False)``. fq = jnp.linalg.vecdot(basis, a) return fq diff --git a/desc/compute/bounce_integral.py b/desc/compute/bounce_integral.py index 94036dae70..bff1b9cdf6 100644 --- a/desc/compute/bounce_integral.py +++ b/desc/compute/bounce_integral.py @@ -6,7 +6,7 @@ from interpax import CubicHermiteSpline, PPoly, interp1d from jax.nn import softmax from matplotlib import pyplot as plt -from orthax.legendre import leggauss +from numpy.polynomial.legendre import leggauss from desc.backend import flatnonzero, imap, jnp, put from desc.compute._interp_utils import poly_root, polyder_vec, polyval_vec diff --git a/desc/compute/fourier_bounce_integral.py b/desc/compute/fourier_bounce_integral.py index b46a7f5bac..0459996601 100644 --- a/desc/compute/fourier_bounce_integral.py +++ b/desc/compute/fourier_bounce_integral.py @@ -2,7 +2,7 @@ import numpy as np from matplotlib import pyplot as plt -from orthax.chebyshev import chebroots, chebvander +from orthax.chebyshev import chebroots from orthax.legendre import leggauss from desc.backend import dct, idct, irfft, jnp, rfft, rfft2 @@ -12,6 +12,7 @@ cheb_pts, fourier_pts, harmonic, + idct_non_uniform, interp_rfft2, irfft2_non_uniform, irfft_non_uniform, @@ -35,7 +36,7 @@ def _flatten_matrix(y): return y.reshape(*y.shape[:-2], -1) -def _alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi): +def alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi): """Get sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) of field line. Parameters @@ -62,6 +63,19 @@ def _alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi): return alphas +def _subtract(c, k): + # subtract k from last axis of c, obeying numpy broadcasting + c_0 = c[..., 0] - k + c = jnp.concatenate( + [ + jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)), + c_0[..., jnp.newaxis], + ], + axis=-1, + ) + return c + + class FourierChebyshevBasis: """Fourier-Chebyshev series. @@ -206,10 +220,7 @@ def compute_cheb(self, x): # Always add new axis to broadcast against Chebyshev coefficients. x = jnp.atleast_1d(x)[..., jnp.newaxis] cheb = cheb_from_dct(irfft_non_uniform(x, self._c, self.M, axis=-2), axis=-1) - assert cheb.shape[-2:] == ( - x.shape[-2], - self.N, - ), f"{cheb.shape}; {x.shape}; {self.N}" + assert cheb.shape[-2:] == (x.shape[-2], self.N) return _PiecewiseChebyshevBasis(cheb, self.domain) @@ -247,6 +258,16 @@ def __init__(self, cheb, domain): self.N = cheb.shape[-1] self.domain = domain + def _chebcast(self, arr): + # Input should not have rightmost dimension of cheb that iterates coefficients, + # but may have additional leftmost dimensions for batch operations. + errorif( + arr.ndim > self.cheb.ndim, + NotImplementedError, + msg=f"Got ndim {arr.ndim} > cheb.ndim {self.cheb.ndim}.", + ) + return self.cheb if arr.ndim < self.cheb.ndim else self.cheb[jnp.newaxis] + def intersect(self, k=0, eps=_eps): """Coordinates yᵢ such that f(x, yᵢ) = k(x). @@ -274,13 +295,7 @@ def intersect(self, k=0, eps=_eps): Boolean array into ``y`` indicating whether element is an intersect. """ - errorif( - k.ndim > self.cheb.ndim, - NotImplementedError, - msg=f"Got k.ndim {k.ndim} > cheb.ndim {self.cheb.ndim}.", - ) - c = self.cheb if k.ndim < self.cheb.ndim else self.cheb[jnp.newaxis] - c = c.copy().at[..., 0].add(-k) + c = _subtract(self._chebcast(k), k) # roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x) y = _chebroots_vec(c) assert y.shape == (*c.shape[:-1], self.N - 1) @@ -289,13 +304,15 @@ def intersect(self, k=0, eps=_eps): # Pick sentinel above such that only distinct roots are considered intersects. is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1) y = jnp.where(is_intersect, y.real, 0) # ensure y is in domain of arcos + + # TODO: Multipoint evaluation with FFT. + # Chapter 10, https://doi.org/10.1017/CBO9781139856065. + n = jnp.arange(self.N) # ∂f/∂y = ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n Uₙ₋₁(y) - # sign ∂f/∂y = sign ∑ₙ₌₁ᴺ⁻¹ aₙ(x) sin(n arcos y) + # sign ∂f/∂y = sign ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n sin(n arcos y) s = jnp.linalg.vecdot( - # TODO: Multipoint evaluation with FFT. - # Chapter 10, https://doi.org/10.1017/CBO9781139856065. + n * jnp.sin(n * jnp.arccos(y)[..., jnp.newaxis]), self.cheb[..., jnp.newaxis, :], - jnp.sin(jnp.arange(self.N) * jnp.arccos(y)[..., jnp.newaxis]), ) is_decreasing = s <= 0 is_increasing = s >= 0 @@ -547,11 +564,9 @@ def eval1d(self, z): y = bijection_to_disc(y, self.domain[0], self.domain[1]) # Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z]) # are held in self.cheb with shape (..., num cheb series, N). - cheb = jnp.moveaxis(self.cheb, source=-1, destination=0) - cheb = jnp.take_along_axis(cheb, x_idx, axis=-1) - # TODO: Multipoint evaluation with FFT. - # Chapter 10, https://doi.org/10.1017/CBO9781139856065. - f = jnp.linalg.vecdot(chebvander(y, self.N - 1), cheb) + cheb = jnp.take_along_axis(self._chebcast(z), x_idx[..., jnp.newaxis], axis=-2) + f = idct_non_uniform(y, cheb, self.N) + assert f.shape == z.shape return f def _isomorphism_1d(self, y): @@ -595,10 +610,8 @@ def _isomorphism_2d(self, z): Isomorphic coordinates. """ - period = self.domain[-1] - self.domain[0] - x_index = z // period - y_value = z % period - return x_index, y_value + x_index, y_value = jnp.divmod(z, self.domain[-1] - self.domain[0]) + return x_index.astype(int), y_value def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch): @@ -655,7 +668,7 @@ def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch) Returns ------- result : jnp.ndarray - Shape (P, S, num_well). + Shape (P, L, num_well). First axis enumerates pitch values. Second axis enumerates the field lines. Last axis enumerates the bounce integrals. @@ -684,6 +697,7 @@ def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch) w, ) assert result.shape == (P, L, num_well) + return result def required_names(): @@ -691,7 +705,7 @@ def required_names(): return ["B^zeta", "|B|"] -# TODO: Assumes zeta = phi +# TODO: Assumes zeta = phi (alpha sequence) def bounce_integral( grid, data, @@ -765,14 +779,16 @@ def bounce_integral( bounce_integrate : callable This callable method computes the bounce integral ∫ f(ℓ) dℓ for every specified field line for every λ value in ``pitch``. - alphas : jnp.ndarray - Sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line. - B : _PiecewiseChebyshevBasis - Set of 1D Chebyshev spectral coefficients of |B| along field line. - {|B|_α : ζ |B|(α, ζ) | α ∈ A } . - T : _PiecewiseChebyshevBasis - Set of 1D Chebyshev spectral coefficients of θ along field line. - {θ_α : ζ θ(α, ζ) | α ∈ A }. + spline : tuple(ndarray, _PiecewiseChebyshevBasis, _PiecewiseChebyshevBasis) + alphas : jnp.ndarray + Poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line. + B : _PiecewiseChebyshevBasis + Set of 1D Chebyshev spectral coefficients of |B| along field line. + {|B|_α : ζ |B|(α, ζ) | α ∈ A } . + T : _PiecewiseChebyshevBasis + Set of 1D Chebyshev spectral coefficients of θ along field line. + {θ_α : ζ θ(α, ζ) | α ∈ A }. + """ # Resolution of periodic DESC coordinate tensor-product grid. L, m, n = grid.num_rho, grid.num_theta, grid.num_zeta @@ -798,7 +814,7 @@ def bounce_integral( ).reshape(L, M, N), ) # Peel off field lines. - alphas = _alpha_sequence(alpha_0, grid.compress(data["iota"]), num_transit) + alphas = alpha_sequence(alpha_0, grid.compress(data["iota"]), num_transit) T = T.compute_cheb(alphas) B = B.compute_cheb(alphas) assert T.cheb.shape == B.cheb.shape == (L, num_transit, N) @@ -863,12 +879,15 @@ def bounce_integrate(integrand, f, pitch, weight=None, num_well=None): errorif(weight is not None, NotImplementedError) # Compute bounce points. pitch = jnp.atleast_3d(pitch) - P = pitch.shape[0] - assert pitch.shape[1:] == B.cheb.shape[:-1], f"{pitch.shape}; {B.cheb.shape}" + assert ( + pitch.shape[1] == B.cheb.shape[0] + or pitch.shape[1] == 1 + or B.cheb.shape[0] == 1 + ) bp1, bp2 = B.bounce_points(*B.intersect(1 / pitch), num_well) + P = pitch.shape[0] num_well = bp1.shape[-1] assert bp1.shape == bp2.shape == (P, L, num_well) - result = _bounce_quadrature( bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch ) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 8291cdd423..c052be4040 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -699,6 +699,8 @@ def get_rtz_grid( "z": "zeta", "p": "phi", } + if "iota" in kwargs: + kwargs["iota"] = grid.expand(kwargs["iota"], surface_label="rho") rtz_nodes = map_coordinates( eq, grid.nodes, diff --git a/tests/test_bounce_integral.py b/tests/test_bounce_integral.py index 3f9409e889..e6e2719010 100644 --- a/tests/test_bounce_integral.py +++ b/tests/test_bounce_integral.py @@ -541,63 +541,15 @@ def _elliptic_incomplete(k2): return I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 -@pytest.mark.unit -@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) -def test_drift(): - """Test bounce-averaged drift with analytical expressions.""" - eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5") - psi_boundary = eq.Psi / (2 * np.pi) - psi = 0.25 * psi_boundary - rho = np.sqrt(psi / psi_boundary) - np.testing.assert_allclose(rho, 0.5) - - # Make a set of nodes along a single fieldline. - grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) - data = eq.compute(["iota"], grid=grid_fsa) - iota = grid_fsa.compress(data["iota"]).item() - alpha = 0 - zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) - grid = get_rtz_grid( - eq, rho, alpha, zeta, coordinates="raz", period=(np.inf, 2 * np.pi, np.inf) - ) - - data = eq.compute( - required_names() - + [ - "cvdrift", - "gbdrift", - "grad(psi)", - "grad(alpha)", - "shear", - "iota", - "psi", - "a", - ], - grid=grid, - ) - np.testing.assert_allclose(data["psi"], psi) - np.testing.assert_allclose(data["iota"], iota) - assert np.all(data["B^zeta"] > 0) - data["iota"] = grid.compress(data["iota"]).item() - data["shear"] = grid.compress(data["shear"]).item() - - B_ref = 2 * np.abs(psi_boundary) / data["a"] ** 2 - bounce_integrate, _ = bounce_integral( - data, - knots=zeta, - B_ref=B_ref, - L_ref=data["a"], - quad=leggauss(28), # converges to absolute and relative tolerance of 1e-7 - check=True, - ) - - B = data["|B|"] / B_ref +def _drift_analytic(data): + """Compute analytic approximation for bounce-averaged binormal drift.""" + B = data["|B|"] / data["B ref"] B0 = np.mean(B) # epsilon should be changed to dimensionless, and computed in a way that # is independent of normalization length scales, like "effective r/R0". - epsilon = data["a"] * rho # Aspect ratio of the flux surface. + epsilon = data["a"] * data["rho"] # Aspect ratio of the flux surface. np.testing.assert_allclose(epsilon, 0.05) - theta_PEST = alpha + data["iota"] * zeta + theta_PEST = data["alpha"] + data["iota"] * data["zeta"] # same as 1 / (1 + epsilon cos(theta)) assuming epsilon << 1 B_analytic = B0 * (1 - epsilon * np.cos(theta_PEST)) np.testing.assert_allclose(B, B_analytic, atol=3e-3) @@ -611,7 +563,7 @@ def test_drift(): np.testing.assert_allclose(gradpar, gradpar_analytic, atol=5e-3) # Comparing coefficient calculation here with coefficients from compute/_metric - normalization = -np.sign(psi) * B_ref * data["a"] ** 2 + normalization = -np.sign(data["psi"]) * data["B ref"] * data["a"] ** 2 cvdrift = data["cvdrift"] * normalization gbdrift = data["gbdrift"] * normalization dPdrho = np.mean(-0.5 * (cvdrift - gbdrift) * data["|B|"] ** 2) @@ -620,7 +572,7 @@ def test_drift(): -np.sign(data["iota"]) * data["shear"] * dot(data["grad(psi)"], data["grad(alpha)"]) - / B_ref + / data["B ref"] ) gds21_analytic = -data["shear"] * ( data["shear"] * theta_PEST - alpha_MHD / B**4 * np.sin(theta_PEST) @@ -671,6 +623,71 @@ def test_drift(): ) / G0 drift_analytic_den = I_0 / G0 drift_analytic = drift_analytic_num / drift_analytic_den + return drift_analytic, cvdrift, gbdrift, pitch + + +@pytest.mark.unit +@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) +def test_drift(): + """Test bounce-averaged drift with analytical expressions.""" + eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5") + psi_boundary = eq.Psi / (2 * np.pi) + psi = 0.25 * psi_boundary + rho = np.sqrt(psi / psi_boundary) + np.testing.assert_allclose(rho, 0.5) + + # Make a set of nodes along a single fieldline. + grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) + data = eq.compute(["iota"], grid=grid_fsa) + iota = grid_fsa.compress(data["iota"]).item() + alpha = 0 + zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) + grid = get_rtz_grid( + eq, + rho, + alpha, + zeta, + coordinates="raz", + period=(np.inf, 2 * np.pi, np.inf), + iota=np.array([iota]), + ) + data = eq.compute( + required_names() + + [ + "cvdrift", + "gbdrift", + "grad(psi)", + "grad(alpha)", + "shear", + "iota", + "psi", + "a", + ], + grid=grid, + ) + np.testing.assert_allclose(data["psi"], psi) + np.testing.assert_allclose(data["iota"], iota) + assert np.all(data["B^zeta"] > 0) + B_ref = 2 * np.abs(psi_boundary) / data["a"] ** 2 + data["B ref"] = B_ref + data["rho"] = rho + data["alpha"] = alpha + data["zeta"] = zeta + data["psi"] = grid.compress(data["psi"]) + data["iota"] = grid.compress(data["iota"]) + data["shear"] = grid.compress(data["shear"]) + + # Compute analytic approximation. + drift_analytic, cvdrift, gbdrift, pitch = _drift_analytic(data) + # Compute numerical result. + bounce_integrate, _ = bounce_integral( + data, + knots=zeta, + B_ref=B_ref, + L_ref=data["a"], + quad=leggauss(28), # converges to absolute and relative tolerance of 1e-7 + check=True, + ) def integrand_num(cvdrift, gbdrift, B, pitch): g = jnp.sqrt(1 - pitch * B) diff --git a/tests/test_fourier_bounce.py b/tests/test_fourier_bounce.py index 1a01ea970b..8718695766 100644 --- a/tests/test_fourier_bounce.py +++ b/tests/test_fourier_bounce.py @@ -2,15 +2,21 @@ import numpy as np import pytest +from matplotlib import pyplot as plt +from numpy.polynomial.legendre import leggauss +from tests.test_bounce_integral import _drift_analytic +from tests.test_plotting import tol_1d +from desc.backend import jnp from desc.compute.bounce_integral import get_pitch from desc.compute.fourier_bounce_integral import ( FourierChebyshevBasis, - _alpha_sequence, + alpha_sequence, bounce_integral, required_names, ) -from desc.equilibrium.coords import map_coordinates +from desc.equilibrium import Equilibrium +from desc.equilibrium.coords import get_rtz_grid, map_coordinates from desc.examples import get from desc.grid import LinearGrid @@ -23,7 +29,7 @@ def test_alpha_sequence(alpha_0, iota, num_period, period): """Test field line poloidal label tracking utility.""" iota = np.atleast_1d(iota) - alphas = _alpha_sequence(alpha_0, iota, num_period, period) + alphas = alpha_sequence(alpha_0, iota, num_period, period) assert alphas.shape == (iota.size, num_period) for i in range(iota.size): assert np.unique(alphas[i]).size == num_period, "Is iota irrational?" @@ -31,7 +37,7 @@ def test_alpha_sequence(alpha_0, iota, num_period, period): @pytest.mark.unit -def test_fourier_chebyshev(rho=1, M=8, N=32, f=lambda x: x): +def test_fourier_chebyshev(rho=1, M=8, N=32, f=lambda B, pitch: B * pitch): """Test bounce points...""" eq = get("W7-X") clebsch = FourierChebyshevBasis.nodes(M, N, rho=rho) @@ -52,3 +58,100 @@ def test_fourier_chebyshev(rho=1, M=8, N=32, f=lambda x: x): grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 ) result = bounce_integrate(f, [], pitch) # noqa: F841 + + +@pytest.mark.unit +@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) +def test_drift(): + """Test bounce-averaged drift with analytical expressions.""" + eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5") + psi_boundary = eq.Psi / (2 * np.pi) + psi = 0.25 * psi_boundary + rho = np.sqrt(psi / psi_boundary) + np.testing.assert_allclose(rho, 0.5) + + # Make a set of nodes along a single fieldline. + grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) + data = eq.compute(["iota"], grid=grid_fsa) + iota = grid_fsa.compress(data["iota"]).item() + alpha = 0 + zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) + grid = get_rtz_grid( + eq, + rho, + alpha, + zeta, + coordinates="raz", + period=(np.inf, 2 * np.pi, np.inf), + iota=np.array([iota]), + ) + data = eq.compute( + required_names() + + [ + "cvdrift", + "gbdrift", + "grad(psi)", + "grad(alpha)", + "shear", + "iota", + "psi", + "a", + ], + grid=grid, + ) + np.testing.assert_allclose(data["psi"], psi) + np.testing.assert_allclose(data["iota"], iota) + assert np.all(data["B^zeta"] > 0) + B_ref = 2 * np.abs(psi_boundary) / data["a"] ** 2 + data["B ref"] = B_ref + data["rho"] = rho + data["alpha"] = alpha + data["zeta"] = zeta + data["psi"] = grid.compress(data["psi"]) + data["iota"] = grid.compress(data["iota"]) + data["shear"] = grid.compress(data["shear"]) + + # Compute analytic approximation. + drift_analytic, cvdrift, gbdrift, pitch = _drift_analytic(data) + # Compute numerical result. + bounce_integrate, _ = bounce_integral( + data, + knots=zeta, + B_ref=B_ref, + L_ref=data["a"], + quad=leggauss(28), # converges to absolute and relative tolerance of 1e-7 + check=True, + ) + + def integrand_num(cvdrift, gbdrift, B, pitch): + g = jnp.sqrt(1 - pitch * B) + return (cvdrift * g) - (0.5 * g * gbdrift) + (0.5 * gbdrift / g) + + def integrand_den(B, pitch): + return 1 / jnp.sqrt(1 - pitch * B) + + drift_numerical_num = bounce_integrate( + integrand=integrand_num, + f=[cvdrift, gbdrift], + pitch=pitch[:, np.newaxis], + num_well=1, + ) + drift_numerical_den = bounce_integrate( + integrand=integrand_den, + f=[], + pitch=pitch[:, np.newaxis], + num_well=1, + weight=np.ones(zeta.size), + ) + + drift_numerical_num = np.squeeze(drift_numerical_num) + drift_numerical_den = np.squeeze(drift_numerical_den) + drift_numerical = drift_numerical_num / drift_numerical_den + msg = "There should be one bounce integral per pitch in this example." + assert drift_numerical.size == drift_analytic.size, msg + np.testing.assert_allclose(drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2) + + fig, ax = plt.subplots() + ax.plot(1 / pitch, drift_analytic) + ax.plot(1 / pitch, drift_numerical) + return fig diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py index 2dffd69b5c..1f47e74418 100644 --- a/tests/test_interp_utils.py +++ b/tests/test_interp_utils.py @@ -18,7 +18,7 @@ cheb_from_dct, cheb_pts, harmonic, - harmonic_basis, + harmonic_vander, interp_dct, interp_rfft, interp_rfft2, @@ -143,7 +143,7 @@ def test_rfftfreq(self, M): def _interp_rfft_harmonic(xq, f): M = f.shape[-1] fq = jnp.linalg.vecdot( - harmonic_basis(xq, M), harmonic(rfft(f, norm="forward"), M) + harmonic_vander(xq, M), harmonic(rfft(f, norm="forward"), M) ) return fq