From 540d0628429d45492d81bb3df60cfd3c28b728cf Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 25 Aug 2024 23:58:38 -0400 Subject: [PATCH] Review algorithm. Fix documentation of integrals and use better names for functions --- desc/integrals/basis.py | 106 ++++++++++++++++------------- desc/integrals/bounce_integral.py | 41 +++++------ desc/integrals/bounce_utils.py | 58 ++++++++-------- desc/integrals/interp_utils.py | 19 +++--- desc/integrals/quad_utils.py | 38 +++++------ desc/integrals/surface_integral.py | 2 +- tests/test_integrals.py | 69 +++++++++---------- tests/test_quad_utils.py | 5 ++ 8 files changed, 178 insertions(+), 160 deletions(-) diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py index 72aecaac66..c93eb75e45 100644 --- a/desc/integrals/basis.py +++ b/desc/integrals/basis.py @@ -47,16 +47,16 @@ def _subtract(c, k): @partial(jnp.vectorize, signature="(m),(m)->(m)") -def epigraph_and(is_intersect, df_dy_sign): - """Set and epigraph of f with ``is_intersect``. +def _in_epigraph_and(is_intersect, df_dy_sign): + """Set and epigraph of function f with the given set of points. - Remove intersects for which there does not exist a connected path between + Return only intersects where there is a connected path between adjacent intersects in the epigraph of a continuous map ``f``. Parameters ---------- is_intersect : jnp.ndarray - Boolean array indicating whether element is an intersect. + Boolean array indicating whether index corresponds to an intersect. df_dy_sign : jnp.ndarray Shape ``is_intersect.shape``. Sign of ∂f/∂y (yᵢ) for f(yᵢ) = 0. @@ -88,11 +88,23 @@ def epigraph_and(is_intersect, df_dy_sign): # ). # At each step, the likelihood that an intersection has already been lost # due to floating point errors grows, so the real solution is to pick a less - # degenerate pitch value - one that does not ride the global extrema of |B|. + # degenerate pitch value - one that does not ride the global extrema of f. ) return put(is_intersect, idx[0], edge_case) +def _chebcast(cheb, arr): + # Input should not have rightmost dimension of cheb that iterates coefficients, + # but may have additional leftmost dimension for batch operation. + errorif( + jnp.ndim(arr) > cheb.ndim, + NotImplementedError, + msg=f"Only one additional axis for batch dimension is allowed. " + f"Got {jnp.ndim(arr) - cheb.ndim + 1} additional axes.", + ) + return cheb if jnp.ndim(arr) < cheb.ndim else cheb[jnp.newaxis] + + class FourierChebyshevBasis: """Fourier-Chebyshev series. @@ -138,15 +150,19 @@ def __init__(self, f, domain=(-1, 1), lobatto=False): self.N = f.shape[-1] errorif(domain[0] > domain[-1], msg="Got inverted domain.") self.domain = tuple(domain) - errorif(lobatto, NotImplementedError, "JAX has not implemented type 1 DCT.") + errorif(lobatto, NotImplementedError, "JAX hasn't implemented type 1 DCT.") self.lobatto = bool(lobatto) self._c = FourierChebyshevBasis._fast_transform(f, self.lobatto) @staticmethod def _fast_transform(f, lobatto): - M = f.shape[-2] N = f.shape[-1] - return rfft(dct(f, type=2 - lobatto, axis=-1), axis=-2) / (M * (N - lobatto)) + c = rfft( + dct(f, type=2 - lobatto, axis=-1) / (N - lobatto), + axis=-2, + norm="forward", + ) + return c @staticmethod def nodes(M, N, L=None, domain=(-1, 1), lobatto=False): @@ -201,12 +217,16 @@ def evaluate(self, M, N): ------- fq : jnp.ndarray Shape (..., M, N) - Fourier-Chebyshev series evaluated at ``FourierChebyshevBasis.nodes(M, N)``. + Fourier-Chebyshev series evaluated at + ``FourierChebyshevBasis.nodes(M,N,L,self.domain,self.lobatto)``. """ - fq = idct(irfft(self._c, n=M, axis=-2), type=2 - self.lobatto, n=N, axis=-1) * ( - M * (N - self.lobatto) - ) + fq = idct( + irfft(self._c, n=M, axis=-2, norm="forward"), + type=2 - self.lobatto, + n=N, + axis=-1, + ) * (N - self.lobatto) return fq def harmonics(self): @@ -259,7 +279,7 @@ class ChebyshevBasisSet: Shape (..., M, N). Chebyshev coefficients αₙ(x) for fₓ(y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y). M : int - Number of function in this basis set. + Number of functions in this basis set. N : int Chebyshev spectral resolution. domain : (float, float) @@ -287,7 +307,7 @@ def __init__(self, cheb, domain=(-1, 1)): @property def M(self): - """Number of function in this basis set.""" + """Number of functions in this basis set.""" return self.cheb.shape[-2] @property @@ -295,18 +315,6 @@ def N(self): """Chebyshev spectral resolution.""" return self.cheb.shape[-1] - @staticmethod - def _chebcast(cheb, arr): - # Input should not have rightmost dimension of cheb that iterates coefficients, - # but may have additional leftmost dimension for batch operation. - errorif( - jnp.ndim(arr) > cheb.ndim, - NotImplementedError, - msg=f"Only one additional axis for batch dimension is allowed. " - f"Got {jnp.ndim(arr) - cheb.ndim + 1} additional axes.", - ) - return cheb if jnp.ndim(arr) < cheb.ndim else cheb[jnp.newaxis] - def intersect2d(self, k=0.0, eps=_eps): """Coordinates yᵢ such that f(x, yᵢ) = k(x). @@ -331,7 +339,7 @@ def intersect2d(self, k=0.0, eps=_eps): Sign of ∂f/∂y (x, yᵢ). """ - c = _subtract(ChebyshevBasisSet._chebcast(self.cheb, k), k) + c = _subtract(_chebcast(self.cheb, k), k) # roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x) y = chebroots_vec(c) assert y.shape == (*c.shape[:-1], self.N - 1) @@ -340,7 +348,8 @@ def intersect2d(self, k=0.0, eps=_eps): # Pick sentinel such that only distinct roots are considered intersects. y = filter_distinct(y, sentinel=-2.0, eps=eps) is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1.0) - y = jnp.where(is_intersect, y.real, 1.0) # ensure y is in domain of arcos + # Ensure y is in domain of arcos; choose 1 because kernel probably cheaper. + y = jnp.where(is_intersect, y.real, 1.0) # TODO: Multipoint evaluation with FFT. # Chapter 10, https://doi.org/10.1017/CBO9781139856065. @@ -379,7 +388,8 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0): z1, z2 : (jnp.ndarray, jnp.ndarray) Shape broadcasts with (..., *self.cheb.shape[:-2], num_intersect). ``z1``, ``z2`` holds intersects satisfying ∂f/∂y <= 0, ∂f/∂y >= 0, - respectively. + respectively. The points are ordered such that the path between + ``z1`` and ``z2`` lies in the epigraph of f. """ errorif( @@ -400,12 +410,14 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0): # Note for bounce point applications: # We ignore the degenerate edge case where the boundary shared by adjacent - # polynomials is a left intersect point i.e. ``is_z1`` because the subset of - # pitch values that generate this edge case has zero measure. Note that - # the technique to account for this would be to disqualify intersects - # within ``_eps`` from ``domain[-1]``. + # polynomials is a left intersection i.e. ``is_z1`` because the subset of + # pitch values that generate this edge case has zero measure. By ignoring + # this, for those subset of pitch values the integrations will be done in + # the hypograph of |B| rather than the epigraph, which will be integrated + # to zero. If we decide later to not ignore this, the technique to solve + # this is to disqualify intersects within ``_eps`` from ``domain[-1]``. is_z1 = (df_dy_sign <= 0) & is_intersect - is_z2 = (df_dy_sign >= 0) & epigraph_and(is_intersect, df_dy_sign) + is_z2 = (df_dy_sign >= 0) & _in_epigraph_and(is_intersect, df_dy_sign) sentinel = self.domain[0] - 1.0 z1 = take_mask(y, is_z1, size=num_intersect, fill_value=sentinel) @@ -418,7 +430,7 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0): return z1, z2 def eval1d(self, z, cheb=None): - """Evaluate piecewise Chebyshev spline at coordinates z. + """Evaluate piecewise Chebyshev series at coordinates z. Parameters ---------- @@ -440,7 +452,7 @@ def eval1d(self, z, cheb=None): Chebyshev basis evaluated at z. """ - cheb = self._chebcast(setdefault(cheb, self.cheb), z) + cheb = _chebcast(setdefault(cheb, self.cheb), z) N = cheb.shape[-1] x_idx, y = self.isomorphism_to_C2(z) y = bijection_to_disc(y, self.domain[0], self.domain[1]) @@ -477,7 +489,8 @@ def isomorphism_to_C1(self, y): def isomorphism_to_C2(self, z): """Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ. - Returns index x and value y such that z = f(x) + y where f(x) = x * |domain|. + Returns index x and minimum value y such that + z = f(x) + y where f(x) = x * |domain|. Parameters ---------- @@ -513,11 +526,11 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): Parameters ---------- z1, z2 : jnp.ndarray - Shape must broadcast with (k, *self.cheb.shape[:-2], W). + Shape must broadcast with (*self.cheb.shape[:-2], W). ``z1``, ``z2`` holds intersects satisfying ∂f/∂y <= 0, ∂f/∂y >= 0, respectively. k : jnp.ndarray - Shape must broadcast with (k.shape[0], *self.cheb.shape[:-2]). + Shape must broadcast with *self.cheb.shape[:-2]. k such that fₓ(yᵢ) = k. plot : bool Whether to plot stuff. Default is true. @@ -533,15 +546,15 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): err_1 = jnp.any(z1 > z2, axis=-1) err_2 = jnp.any(z1[..., 1:] < z2[..., :-1], axis=-1) - f_m = self.eval1d((z1 + z2) / 2) - assert f_m.shape == z1.shape - err_3 = jnp.any(f_m > k + self._eps, axis=-1) + f_midpoint = self.eval1d((z1 + z2) / 2) + assert f_midpoint.shape == z1.shape + err_3 = jnp.any(f_midpoint > k + self._eps, axis=-1) if not (plot or jnp.any(err_1 | err_2 | err_3)): return # Ensure l axis exists for iteration in below loop. cheb = atleast_nd(3, self.cheb) - mask, z1, z2, f_m = atleast_3d_mid(mask, z1, z2, f_m) + mask, z1, z2, f_midpoint = atleast_3d_mid(mask, z1, z2, f_midpoint) err_1, err_2, err_3 = atleast_2d_end(err_1, err_2, err_3) for l in np.ndindex(cheb.shape[:-2]): @@ -564,8 +577,9 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): assert not err_1[idx], "Intersects have an inversion.\n" assert not err_2[idx], "Detected discontinuity.\n" assert not err_3[idx], ( - "Detected f > k in well. Increase Chebyshev resolution.\n" - f"{f_m[idx][mask[idx]]} > {k[idx] + self._eps}" + "Detected f > k in well, implying a path between z1 and z2 " + "is in hypograph(f). Increase Chebyshev resolution.\n" + f"{f_midpoint[idx][mask[idx]]} > {k[idx] + self._eps}" ) idx = (slice(None), *l) if plot: @@ -586,7 +600,7 @@ def plot1d( k=None, k_transparency=0.5, klabel=r"$k$", - title=r"Intersects $z$ in epigraph of $f(z) = k$", + title=r"Intersects $z$ in epigraph($f$) s.t. $f(z) = k$", hlabel=r"$z$", vlabel=r"$f(z)$", show=True, diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index 89dc4f530c..298e11a7af 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -78,14 +78,15 @@ def _transform_to_clebsch(grid, desc_from_clebsch, M, N, B): # TODO: -# After GitHub issue #1034 is resolved, we can also pass in the previous +# After GitHub issue #1034 is resolved, we should pass in the previous # θ(α) coordinates as an initial guess for the next coordinate mapping. # Perhaps tell the optimizer to perturb the coefficients of the # |B|(α, ζ) directly? Maybe auto diff to see change on |B|(θ, ζ) -# and hence stream functions. just guessing. not sure if feasible / useful. +# and hence stream functions. Not sure how feasible... # TODO: Allow multiple starting labels for near-rational surfaces. -# can just concatenate along second to last axis of cheb. +# can just concatenate along second to last axis of cheb, but will +# do in later pull request since it's not urgent. class Bounce2D: @@ -171,11 +172,11 @@ class of basis functions to low order (e.g. N = 2ᵏ where k is small) an alternate strategy that should work is to interpolate |B| to a double Fourier series in (ϑ, ϕ), then apply bisection methods to find roots of f with mesh size inversely proportional to the max frequency along the field - line: M ι + N. ``Bounce2D`` does not use this approach because the + line: M ι + N. ``Bounce2D`` does not use that approach because that root-finding scheme is inferior. After obtaining the bounce points, the supplied quadrature is performed. - By default, Gauss quadrature is performed after removing the singularity. + By default, this is a Gauss quadrature after removing the singularity. Fast fourier transforms interpolate functions in the integrand to the quadrature nodes. @@ -194,7 +195,7 @@ class of basis functions to low order (e.g. N = 2ᵏ where k is small) Uses one-dimensional local spline methods for the same task. An advantage of ``Bounce2D`` over ``Bounce1D`` is that the coordinates on which the root-finding must be done to map from DESC to Clebsch coords is - fixed to ``M*N``, independent of the number of toroidal transits. + fixed to ``L*M*N``, independent of the number of toroidal transits. Warnings -------- @@ -223,11 +224,11 @@ def __init__( M, N, alpha_0=0.0, - num_transit=50, + num_transit=32, quad=leggauss(32), automorphism=(automorphism_sin, grad_automorphism_sin), - B_ref=1.0, - L_ref=1.0, + Bref=1.0, + Lref=1.0, check=False, **kwargs, ): @@ -250,7 +251,7 @@ def __init__( desc_from_clebsch : jnp.ndarray Shape (L * M * N, 3). DESC coordinates (ρ, θ, ζ) sourced from the Clebsch coordinates - ``FourierChebyshevBasis.nodes(M,N,domain=FourierBounce.domain)``. + ``FourierChebyshevBasis.nodes(M,N,L,domain=FourierBounce.domain)``. M : int Grid resolution in poloidal direction for Clebsch coordinate grid. Preferably power of 2. A good choice is ``m``. If the poloidal stream @@ -271,9 +272,9 @@ def __init__( The second callable should be the derivative of the first. This map defines a change of variable for the bounce integral. The choice made for the automorphism will affect the performance of the quadrature method. - B_ref : float + Bref : float Optional. Reference magnetic field strength for normalization. - L_ref : float + Lref : float Optional. Reference length scale for normalization. check : bool Flag for debugging. Must be false for JAX transformations. @@ -292,13 +293,13 @@ def __init__( self._m = grid.num_theta self._n = grid.num_zeta self._b_sup_z = jnp.expand_dims( - transform_to_desc(grid, jnp.abs(data["B^zeta"]) / data["|B|"] * L_ref), + transform_to_desc(grid, jnp.abs(data["B^zeta"]) / data["|B|"] * Lref), axis=1, ) self._x, self._w = get_quadrature(quad, automorphism) # Compute global splines. - T, B = _transform_to_clebsch(grid, desc_from_clebsch, M, N, data["|B|"] / B_ref) + T, B = _transform_to_clebsch(grid, desc_from_clebsch, M, N, data["|B|"] / Bref) # peel off field lines alphas = get_alpha( alpha_0, @@ -337,6 +338,7 @@ def desc_from_clebsch(eq, L, M, N, clebsch=None, **kwargs): Preferably power of 2. clebsch : jnp.ndarray Optional, Clebsch coordinate tensor-product grid (ρ, α, ζ). + ``FourierChebyshevBasis.nodes(M,N,L,domain=FourierBounce.domain)``. If given, ``L``, ``M``, and ``N`` are ignored. kwargs : dict Additional parameters to supply to the coordinate mapping function. @@ -426,7 +428,8 @@ def check_bounce_points(self, bp1, bp2, pitch, plot=True, **kwargs): """Check that bounce points are computed correctly and plot them.""" kwargs.setdefault( "title", - r"Intersects $\zeta$ in epigraph of $\vert B \vert(\zeta) = 1/\lambda$", + r"Intersects $\zeta$ in epigraph($\vert B \vert$) s.t. " + r"$\vert B \vert(\zeta) = 1/\lambda$", ) kwargs.setdefault("klabel", r"$1/\lambda$") kwargs.setdefault("hlabel", r"$\zeta$") @@ -565,7 +568,7 @@ class Bounce1D: This is useful if one can efficiently obtain data along field lines. After obtaining the bounce points, the supplied quadrature is performed. - By default, Gauss quadrature is performed after removing the singularity. + By default, this is a Gauss quadrature after removing the singularity. Local splines interpolate functions in the integrand to the quadrature nodes. See Also @@ -575,9 +578,9 @@ class Bounce1D: Warnings -------- The supplied data must be from a Clebsch coordinate (ρ, α, ζ) tensor-product grid. - ζ coordinates must be strictly increasing and preferably uniformly spaced. - These are used as knots to construct splines; a reference knot density is 100 - knots per toroidal transit. + The ζ coordinates (the unique values prior to taking the tensor-product) must be + strictly increasing and preferably uniformly spaced. These are used as knots to + construct splines; a reference knot density is 100 knots per toroidal transit. Examples -------- diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index fe0798cd2e..dc3e371c57 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -4,10 +4,10 @@ from matplotlib import pyplot as plt from desc.backend import imap, jnp, softmax -from desc.integrals.basis import _add2legend, _plot_intersect, epigraph_and +from desc.integrals.basis import _add2legend, _in_epigraph_and, _plot_intersect from desc.integrals.interp_utils import ( interp1d_vec, - interp1d_vec_with_df, + interp1d_vec_Hermite, poly_root, polyval_vec, ) @@ -185,7 +185,7 @@ def bounce_points( respectively, for the bounce integrals. If there were less than ``num_wells`` wells detected along a field line, - then the last axis, which enumerates bounce points for a particular field + then the last axis, which enumerates bounce points for a particular field line and pitch, is padded with zero. """ @@ -213,7 +213,7 @@ def bounce_points( # we ignore the bounce points of particles only assigned to a class that are # trapped outside this snapshot of the field line. is_bp1 = (dB_dz_sign <= 0) & is_intersect - is_bp2 = (dB_dz_sign >= 0) & epigraph_and(is_intersect, dB_dz_sign) + is_bp2 = (dB_dz_sign >= 0) & _in_epigraph_and(is_intersect, dB_dz_sign) # Transform out of local power basis expansion. intersect = (intersect + knots[:-1, jnp.newaxis]).reshape(P, S, -1) @@ -238,7 +238,8 @@ def _check_bounce_points(bp1, bp2, pitch, knots, B, plot=True, **kwargs): eps = kwargs.pop("eps", jnp.finfo(jnp.array(1.0).dtype).eps * 10) kwargs.setdefault( "title", - r"Intersects $\zeta$ in epigraph of $\vert B \vert(\zeta) = 1/\lambda$", + r"Intersects $\zeta$ in epigraph($\vert B \vert$) s.t. " + r"$\vert B \vert(\zeta) = 1/\lambda$", ) kwargs.setdefault("klabel", r"$1/\lambda$") kwargs.setdefault("hlabel", r"$\zeta$") @@ -277,7 +278,8 @@ def _check_bounce_points(bp1, bp2, pitch, knots, B, plot=True, **kwargs): assert not err_2[p, s], "Detected discontinuity.\n" assert not err_3, ( f"Detected |B| = {Bs_midpoint[mask[p, s]]} > {1 / pitch[p, s] + eps} " - f"= 1/λ in well. Use more knots.\n" + "= 1/λ in well, implying a path between bounce points is in " + "hypograph(|B|). Use more knots.\n" ) if plot: plot_ppoly( @@ -435,7 +437,7 @@ def _interpolate_and_integrate( Quadrature weights. Q : jnp.ndarray Shape (P, S, Q.shape[2], w.size). - Quadrature points at ζ coordinates. + Quadrature points in ζ coordinates. data : dict[str, jnp.ndarray] Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. Must include names in ``Bounce1D.required_names()``. @@ -462,14 +464,14 @@ def _interpolate_and_integrate( pitch = jnp.expand_dims(pitch, axis=(2, 3) if (Q.ndim == 4) else 2) shape = Q.shape Q = Q.reshape(Q.shape[0], Q.shape[1], -1) - b_sup_z = interp1d_vec_with_df( + b_sup_z = interp1d_vec_Hermite( Q, knots, data["B^zeta"] / data["|B|"], data["B^zeta_z|r,a"] / data["|B|"] - data["B^zeta"] * data["|B|_z|r,a"] / data["|B|"] ** 2, ).reshape(shape) - B = interp1d_vec_with_df(Q, knots, data["|B|"], data["|B|_z|r,a"]).reshape(shape) + B = interp1d_vec_Hermite(Q, knots, data["|B|"], data["|B|_z|r,a"]).reshape(shape) # Spline the integrand so that we can evaluate it at quadrature points without # expensive coordinate mappings and root finding. Spline each function separately so # that the singularity near the bounce points can be captured more accurately than @@ -483,13 +485,13 @@ def _interpolate_and_integrate( return result -def _check_interp(Z, f, b_sup_z, B, B_z_ra, result, plot): +def _check_interp(Q, f, b_sup_z, B, B_z_ra, result, plot): """Check for floating point errors. Parameters ---------- - Z : jnp.ndarray - Quadrature points at ζ coordinates. + Q : jnp.ndarray + Quadrature points in ζ coordinates. f : list of jnp.ndarray Arguments to the integrand interpolated to Z. b_sup_z : jnp.ndarray @@ -504,9 +506,9 @@ def _check_interp(Z, f, b_sup_z, B, B_z_ra, result, plot): Whether to plot stuff. """ - assert jnp.isfinite(Z).all(), "NaN interpolation point." + assert jnp.isfinite(Q).all(), "NaN interpolation point." # Integrals that we should be computing. - marked = jnp.any(Z != 0, axis=-1) + marked = jnp.any(Q != 0.0, axis=-1) goal = jnp.sum(marked) msg = "Interpolation failed." @@ -527,15 +529,15 @@ def _check_interp(Z, f, b_sup_z, B, B_z_ra, result, plot): "can be caused by floating point error or a poor choice of quadrature nodes." ) if plot: - _plot_check_interp(Z, B, name=r"$\vert B \vert$") - _plot_check_interp(Z, b_sup_z, name=r"$ (B / \vert B \vert) \cdot e^{\zeta}$") + _plot_check_interp(Q, B, name=r"$\vert B \vert$") + _plot_check_interp(Q, b_sup_z, name=r"$ (B / \vert B \vert) \cdot e^{\zeta}$") -def _plot_check_interp(Z, V, name=""): - """Plot V[λ, (ρ, α), (ζ₁, ζ₂)](Z).""" - for p in range(Z.shape[0]): - for s in range(Z.shape[1]): - marked = jnp.nonzero(jnp.any(Z != 0, axis=-1))[0] +def _plot_check_interp(Q, V, name=""): + """Plot V[λ, (ρ, α), (ζ₁, ζ₂)](Q).""" + for p in range(Q.shape[0]): + for s in range(Q.shape[1]): + marked = jnp.nonzero(jnp.any(Q != 0.0, axis=-1))[0] if marked.size == 0: continue fig, ax = plt.subplots() @@ -545,7 +547,7 @@ def _plot_check_interp(Z, V, name=""): f"Interpolation of {name} to quadrature points. Index {p},{s}." ) for i in marked: - ax.plot(Z[p, s, i], V[p, s, i], marker="o") + ax.plot(Q[p, s, i], V[p, s, i], marker="o") fig.text( 0.01, 0.01, @@ -673,7 +675,7 @@ def plot_ppoly( k=None, k_transparency=0.5, klabel=r"$k$", - title=r"Intersects $z$ in epigraph of $f(z) = k$", + title=r"Intersects $z$ in epigraph($f$) s.t. $f(z) = k$", hlabel=r"$z$", vlabel=r"$f(z)$", show=True, @@ -692,13 +694,13 @@ def plot_ppoly( Number of points to evaluate for plot. z1 : jnp.ndarray Shape (k.shape[0], W). - Optional, intersects with ∂f/∂ζ <= 0. + Optional, intersects with ∂f/∂z <= 0. z2 : jnp.ndarray Shape (k.shape[0], W). - Optional, intersects with ∂f/∂ζ >= 0. + Optional, intersects with ∂f/∂z >= 0. k : jnp.ndarray Shape (k.shape[0], ). - Optional, k such that f(ζ) = k. + Optional, k such that f(z) = k. k_transparency : float Transparency of intersect lines. klabel : str @@ -712,9 +714,9 @@ def plot_ppoly( show : bool Whether to show the plot. Default is true. start : float - Minimum ζ on plot. + Minimum z on plot. stop : float - Maximum ζ on plot. + Maximum z on plot. include_knots : bool Whether to plot vertical lines at the knots. knot_transparency : float diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py index b1c36c3aa8..883da3ef21 100644 --- a/desc/integrals/interp_utils.py +++ b/desc/integrals/interp_utils.py @@ -1,4 +1,4 @@ -"""Interpolation utilities.""" +"""Fast interpolation utilities.""" from functools import partial @@ -19,14 +19,15 @@ # TODO: Transformation to make nodes more uniform Boyd eq. 16.46 pg. 336. # Have a hunch it won't change locations of complex poles much, so using -# more uniformly spaced nodes could speed up convergence. +# more uniformly spaced nodes could speed up convergence (wrt early +# series truncation, not the infinite limit). def cheb_pts(N, lobatto=False, domain=(-1, 1)): """Get ``N`` Chebyshev points mapped to given domain. - Notes - ----- + Warnings + -------- This is a common definition of the Chebyshev points (see Boyd, Chebyshev and Fourier Spectral Methods p. 498). These are the points demanded by discrete cosine transformations to interpolate Chebyshev series because the cosine @@ -307,7 +308,7 @@ def transform_to_desc(grid, f): ------- a : jnp.ndarray Shape (grid.num_rho, grid.num_theta // 2 + 1, grid.num_zeta) - Coefficients of 2D real FFT. + Complex coefficients of 2D real FFT. """ f = grid.meshgrid_reshape(f, order="rtz") @@ -325,8 +326,8 @@ def cheb_from_dct(a, axis=-1): a : jnp.ndarray Discrete cosine transform coefficients, e.g. ``a=dct(f,type=2,axis=axis,norm="forward")``. - The discrete cosine transformation used by scipy is defined here. - docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html#scipy.fft.dct + The discrete cosine transformation used by scipy is defined here: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html. axis : int Axis along which to transform. @@ -472,8 +473,8 @@ def polyval_vec(x, c): @partial(jnp.vectorize, signature="(m),(n),(n),(n)->(m)") -def interp1d_vec_with_df(xq, x, f, fx): - """Vectorized interp1d.""" +def interp1d_vec_Hermite(xq, x, f, fx): + """Vectorized cubic Hermite spline.""" return interp1d(xq, x, f, method="cubic", fx=fx) diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py index d9950ad07b..b14f691b02 100644 --- a/desc/integrals/quad_utils.py +++ b/desc/integrals/quad_utils.py @@ -8,19 +8,19 @@ def bijection_to_disc(x, a, b): """[a, b] ∋ x ↦ y ∈ [−1, 1].""" - y = 2 * (x - a) / (b - a) - 1 + y = 2.0 * (x - a) / (b - a) - 1.0 return y def bijection_from_disc(x, a, b): """[−1, 1] ∋ x ↦ y ∈ [a, b].""" - y = (x + 1) / 2 * (b - a) + a + y = 0.5 * (b - a) * (x + 1.0) + a return y def grad_bijection_from_disc(a, b): - """Gradient of affine bijection.""" - dy_dx = (b - a) / 2 + """Gradient of affine bijection from disc.""" + dy_dx = 0.5 * (b - a) return dy_dx @@ -42,13 +42,13 @@ def automorphism_arcsin(x): Transformed points. """ - y = 2 * jnp.arcsin(x) / jnp.pi + y = 2.0 * jnp.arcsin(x) / jnp.pi return y def grad_automorphism_arcsin(x): """Gradient of arcsin automorphism.""" - dy_dx = 2 / (jnp.sqrt(1 - x**2) * jnp.pi) + dy_dx = 2.0 / (jnp.sqrt(1.0 - x**2) * jnp.pi) return dy_dx @@ -85,7 +85,7 @@ def automorphism_sin(x, s=0, m=10): errorif(not (0 <= s <= 1)) # s = 0 -> derivative vanishes like cosine. # s = 1 -> derivative vanishes like cosine^k. - y0 = jnp.sin(jnp.pi * x / 2) + y0 = jnp.sin(0.5 * jnp.pi * x) y1 = x + jnp.sin(jnp.pi * x) / jnp.pi # k = 2 y = (1 - s) * y0 + s * y1 # y is an expansion, so y(x) > x near x ∈ {−1, 1} and there is a tendency @@ -96,8 +96,8 @@ def automorphism_sin(x, s=0, m=10): def grad_automorphism_sin(x, s=0): """Gradient of sin automorphism.""" - dy0_dx = jnp.pi * jnp.cos(jnp.pi * x / 2) / 2 - dy1_dx = 1 + jnp.cos(jnp.pi * x) + dy0_dx = 0.5 * jnp.pi * jnp.cos(0.5 * jnp.pi * x) + dy1_dx = 1.0 + jnp.cos(jnp.pi * x) dy_dx = (1 - s) * dy0_dx + s * dy1_dx return dy_dx @@ -138,7 +138,7 @@ def tanh_sinh(deg, m=10): return x, w -def leggausslob(deg): +def leggauss_lobatto(deg): """Lobatto-Gauss-Legendre quadrature. Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the @@ -210,7 +210,7 @@ def get_quadrature(quad, automorphism): def composite_linspace(x, num): - """Returns linearly spaced points between every pair of points ``x``. + """Returns linearly spaced values between every pair of values in ``x``. Parameters ---------- @@ -218,18 +218,18 @@ def composite_linspace(x, num): First axis has values to return linearly spaced values between. The remaining axes are batch axes. Assumes input is sorted along first axis. num : int - Number of points between every pair of points in ``x``. + Number of values between every pair of values in ``x``. Returns ------- - pts : jnp.ndarray + vals : jnp.ndarray Shape ((x.shape[0] - 1) * num + x.shape[0], *x.shape[1:]). - Linearly spaced points between ``x``. + Linearly spaced values between ``x``. """ x = jnp.atleast_1d(x) - pts = jnp.linspace(x[:-1], x[1:], num + 1, endpoint=False) - pts = jnp.swapaxes(pts, 0, 1).reshape(-1, *x.shape[1:]) - pts = jnp.append(pts, x[jnp.newaxis, -1], axis=0) - assert pts.shape == ((x.shape[0] - 1) * num + x.shape[0], *x.shape[1:]) - return pts + vals = jnp.linspace(x[:-1], x[1:], num + 1, endpoint=False) + vals = jnp.swapaxes(vals, 0, 1).reshape(-1, *x.shape[1:]) + vals = jnp.append(vals, x[jnp.newaxis, -1], axis=0) + assert vals.shape == ((x.shape[0] - 1) * num + x.shape[0], *x.shape[1:]) + return vals diff --git a/desc/integrals/surface_integral.py b/desc/integrals/surface_integral.py index acc1e6c1b9..944a711904 100644 --- a/desc/integrals/surface_integral.py +++ b/desc/integrals/surface_integral.py @@ -100,7 +100,7 @@ def line_integrals( The coordinate curve to compute the integration over. To clarify, a theta (poloidal) curve is the intersection of a rho surface (flux surface) and zeta (toroidal) surface. - fix_surface : str, float + fix_surface : (str, float) A tuple of the form: label, value. ``fix_surface`` label should differ from ``line_label``. By default, ``fix_surface`` is chosen to be the flux surface at rho=1. diff --git a/tests/test_integrals.py b/tests/test_integrals.py index 896b412b69..ad9726b310 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -47,7 +47,7 @@ bijection_from_disc, grad_automorphism_sin, grad_bijection_from_disc, - leggausslob, + leggauss_lobatto, tanh_sinh, ) from desc.integrals.singularities import _get_quadrature_nodes @@ -729,9 +729,8 @@ def filter(bp1, bp2): mask = (bp1 - bp2) != 0.0 return bp1[mask], bp2[mask] - @staticmethod @pytest.mark.unit - def test_bp1_first(): + def test_bp1_first(self): """Test that bounce points are computed correctly.""" start = np.pi / 3 end = 6 * np.pi @@ -745,9 +744,8 @@ def test_bp1_first(): np.testing.assert_allclose(bp1, intersect[0::2]) np.testing.assert_allclose(bp2, intersect[1::2]) - @staticmethod @pytest.mark.unit - def test_bp2_first(): + def test_bp2_first(self): """Test that bounce points are computed correctly.""" start = -3 * np.pi end = -start @@ -761,9 +759,8 @@ def test_bp2_first(): np.testing.assert_allclose(bp1, intersect[1:-1:2]) np.testing.assert_allclose(bp2, intersect[0::2][1:]) - @staticmethod @pytest.mark.unit - def test_bp1_before_extrema(): + def test_bp1_before_extrema(self): """Test that bounce points are computed correctly.""" start = -np.pi end = -2 * start @@ -771,9 +768,9 @@ def test_bp1_before_extrema(): B = CubicHermiteSpline( k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k) ) - B_z_ra = B.derivative() - pitch = 1 / B(B_z_ra.roots(extrapolate=False))[3] + 1e-13 - bp1, bp2 = bounce_points(pitch, k, B.c, B_z_ra.c, check=True) + dB_dz = B.derivative() + pitch = 1 / B(dB_dz.roots(extrapolate=False))[3] + 1e-13 + bp1, bp2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) bp1, bp2 = TestBouncePoints.filter(bp1, bp2) assert bp1.size and bp2.size intersect = B.solve(1 / pitch, extrapolate=False) @@ -783,9 +780,8 @@ def test_bp1_before_extrema(): np.testing.assert_allclose(intersect[2], intersect[3], rtol=1e-6) np.testing.assert_allclose(bp2, intersect[[3, 4]], rtol=1e-6) - @staticmethod @pytest.mark.unit - def test_bp2_before_extrema(): + def test_bp2_before_extrema(self): """Test that bounce points are computed correctly.""" start = -1.2 * np.pi end = -2 * start @@ -795,18 +791,17 @@ def test_bp2_before_extrema(): np.cos(k) + 2 * np.sin(-2 * k) + k / 4, -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 4, ) - B_z_ra = B.derivative() - pitch = 1 / B(B_z_ra.roots(extrapolate=False))[2] - bp1, bp2 = bounce_points(pitch, k, B.c, B_z_ra.c, check=True) + dB_dz = B.derivative() + pitch = 1 / B(dB_dz.roots(extrapolate=False))[2] + bp1, bp2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) bp1, bp2 = TestBouncePoints.filter(bp1, bp2) assert bp1.size and bp2.size intersect = B.solve(1 / pitch, extrapolate=False) np.testing.assert_allclose(bp1, intersect[[0, -2]]) np.testing.assert_allclose(bp2, intersect[[1, -1]]) - @staticmethod @pytest.mark.unit - def test_extrema_first_and_before_bp1(): + def test_extrema_first_and_before_bp1(self): """Test that bounce points are computed correctly.""" start = -1.2 * np.pi end = -2 * start @@ -816,10 +811,10 @@ def test_extrema_first_and_before_bp1(): np.cos(k) + 2 * np.sin(-2 * k) + k / 20, -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 20, ) - B_z_ra = B.derivative() - pitch = 1 / B(B_z_ra.roots(extrapolate=False))[2] - 1e-13 + dB_dz = B.derivative() + pitch = 1 / B(dB_dz.roots(extrapolate=False))[2] - 1e-13 bp1, bp2 = bounce_points( - pitch, k[2:], B.c[:, 2:], B_z_ra.c[:, 2:], check=True, plot=False + pitch, k[2:], B.c[:, 2:], dB_dz.c[:, 2:], check=True, plot=False ) plot_ppoly(B, z1=bp1, z2=bp2, k=1 / pitch, start=k[2]) bp1, bp2 = TestBouncePoints.filter(bp1, bp2) @@ -830,9 +825,8 @@ def test_extrema_first_and_before_bp1(): np.testing.assert_allclose(bp1, intersect[[0, 2, 4]], rtol=1e-6) np.testing.assert_allclose(bp2, intersect[[0, 3, 5]], rtol=1e-6) - @staticmethod @pytest.mark.unit - def test_extrema_first_and_before_bp2(): + def test_extrema_first_and_before_bp2(self): """Test that bounce points are computed correctly.""" start = -1.2 * np.pi end = -2 * start + 1 @@ -842,9 +836,9 @@ def test_extrema_first_and_before_bp2(): np.cos(k) + 2 * np.sin(-2 * k) + k / 10, -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 10, ) - B_z_ra = B.derivative() - pitch = 1 / B(B_z_ra.roots(extrapolate=False))[1] + 1e-13 - bp1, bp2 = bounce_points(pitch, k, B.c, B_z_ra.c, check=True) + dB_dz = B.derivative() + pitch = 1 / B(dB_dz.roots(extrapolate=False))[1] + 1e-13 + bp1, bp2 = bounce_points(pitch, k, B.c, dB_dz.c, check=True) bp1, bp2 = TestBouncePoints.filter(bp1, bp2) assert bp1.size and bp2.size # Our routine correctly detects intersection, while scipy, jnp.root fails. @@ -864,17 +858,17 @@ def test_get_extrema(self): B = CubicHermiteSpline( k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k) ) - B_z_ra = B.derivative() - extrema, B_extrema = _get_extrema(k, B.c, B_z_ra.c) - mask = ~np.isnan(extrema) - extrema, B_extrema = extrema[mask], B_extrema[mask] - idx = np.argsort(extrema) + dB_dz = B.derivative() + ext, B_ext = _get_extrema(k, B.c, dB_dz.c) + mask = ~np.isnan(ext) + ext, B_ext = ext[mask], B_ext[mask] + idx = np.argsort(ext) - extrema_scipy = np.sort(B_z_ra.roots(extrapolate=False)) - B_extrema_scipy = B(extrema_scipy) - assert extrema.size == extrema_scipy.size - np.testing.assert_allclose(extrema[idx], extrema_scipy) - np.testing.assert_allclose(B_extrema[idx], B_extrema_scipy) + ext_scipy = np.sort(dB_dz.roots(extrapolate=False)) + B_ext_scipy = B(ext_scipy) + assert ext.size == ext_scipy.size + np.testing.assert_allclose(ext[idx], ext_scipy) + np.testing.assert_allclose(B_ext[idx], B_ext_scipy) class TestBounceQuadrature: @@ -899,7 +893,7 @@ def _mod_chebu_gauss(deg): (True, tanh_sinh(40), None), (True, leggauss(25), "default"), (False, tanh_sinh(20), None), - (False, leggausslob(10), "default"), + (False, leggauss_lobatto(10), "default"), # sin automorphism still helps out chebyshev quadrature (True, _mod_cheb_gauss(30), "default"), (False, _mod_chebu_gauss(10), "default"), @@ -1136,7 +1130,6 @@ def dB_dz(z): knots=zeta, B=bounce.B, dB_dz=bounce._dB_dz, - method="cubic", ), rtol=1e-3, ) @@ -1355,4 +1348,4 @@ def integrand_grad(*args, **kwargs2): assert np.isclose(grad(fun1)(pitch), truth, rtol=1e-3) # Make sure bounce points get differentiated too. result = fun2(pitch) - assert np.isfinite(result) and not np.isclose(result, truth, rtol=1e-3) + assert np.isfinite(result) and not np.isclose(result, truth, rtol=1e-1) diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py index 130b2732b8..662e9fcef7 100644 --- a/tests/test_quad_utils.py +++ b/tests/test_quad_utils.py @@ -62,3 +62,8 @@ def test_automorphism(): assert np.isfinite(y).all() y = 1 / np.sqrt(1 - np.abs(automorphism_arcsin(x))) assert np.isfinite(y).all() + + +@pytest.mark.unit +def test_leggauss_lobatto(): + """Test that quadrature points and weights are correct."""