diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py index b9ec3c78c4..0ccad22c5b 100644 --- a/desc/integrals/basis.py +++ b/desc/integrals/basis.py @@ -8,6 +8,7 @@ from desc.backend import dct, flatnonzero, idct, irfft, jnp, put, rfft from desc.integrals.interp_utils import ( _filter_distinct, + _subtract_first, cheb_from_dct, cheb_pts, chebroots_vec, @@ -29,23 +30,6 @@ ) -def _subtract(c, k): - """Subtract ``k`` from first index of last axis of ``c``. - - Semantically same as ``return c.copy().at[...,0].add(-k)``, - but allows dimension to increase. - """ - c_0 = c[..., 0] - k - c = jnp.concatenate( - [ - c_0[..., jnp.newaxis], - jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)), - ], - axis=-1, - ) - return c - - @partial(jnp.vectorize, signature="(m),(m)->(m)") def _in_epigraph_and(is_intersect, df_dy_sign, /): """Set and epigraph of function f with the given set of points. @@ -162,12 +146,11 @@ def __init__(self, f, domain=(-1, 1), lobatto=False): @staticmethod def _fast_transform(f, lobatto): N = f.shape[-1] - c = rfft( + return rfft( dct(f, type=2 - lobatto, axis=-1) / (N - lobatto), axis=-2, norm="forward", ) - return c @staticmethod def nodes(M, N, L=None, domain=(-1, 1), lobatto=False): @@ -204,8 +187,9 @@ def nodes(M, N, L=None, domain=(-1, 1), lobatto=False): coords = (jnp.atleast_1d(L), x, y) else: coords = (x, y) - coords = list(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij"))) - coords = jnp.column_stack(coords) + coords = jnp.column_stack( + list(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij"))) + ) return coords def evaluate(self, M, N): @@ -424,7 +408,7 @@ def intersect2d(self, k=0.0, eps=_eps): Sign of ∂f/∂y (x, yᵢ). """ - c = _subtract(_chebcast(self.cheb, k), k) + c = _subtract_first(_chebcast(self.cheb, k), k) # roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x) y = chebroots_vec(c) assert y.shape == (*c.shape[:-1], self.N - 1) @@ -432,9 +416,9 @@ def intersect2d(self, k=0.0, eps=_eps): # Intersects must satisfy y ∈ [-1, 1]. # Pick sentinel such that only distinct roots are considered intersects. y = _filter_distinct(y, sentinel=-2.0, eps=eps) - is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1.0) - # Ensure y is in domain of arcos; choose 1 because kernel probably cheaper. - y = jnp.where(is_intersect, y.real, 1.0) + is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) < 1.0) + # Ensure y is in differentiable domain of arcos: (-1, 1). + y = jnp.where(is_intersect, y.real, 0) # TODO: Multipoint evaluation with FFT. # Chapter 10, https://doi.org/10.1017/CBO9781139856065. @@ -473,7 +457,7 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0): z1, z2 : (jnp.ndarray, jnp.ndarray) Shape broadcasts with (..., *self.cheb.shape[:-2], num_intersect). ``z1`` and ``z2`` are intersects satisfying ∂f/∂y <= 0 and ∂f/∂y >= 0, - respectively. The points are grouped and ordered such that the straight + respectively. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of f. """ @@ -500,7 +484,9 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0): # this, for those subset of pitch values the integrations will be done in # the hypograph of |B|, which will yield zero. If in far future decide to # not ignore this, note the solution is to disqualify intersects within - # ``_eps`` from ``domain[-1]``. + # ``_eps`` from ``domain[-1]``. Edit: For differentiability, we cannot + # consider intersects at boundary of Chebyshev polynomial. Again, cases + # where this would be incorrect have measure zero. is_z1 = (df_dy_sign <= 0) & is_intersect is_z2 = (df_dy_sign >= 0) & _in_epigraph_and(is_intersect, df_dy_sign) @@ -519,7 +505,8 @@ def _check_shape(self, z1, z2, k): # Ensure pitch batch dim exists and add back dim to broadcast with wells. k = atleast_nd(self.cheb.ndim - 1, k)[..., jnp.newaxis] # Same but back dim already exists. - z1, z2 = atleast_nd(self.cheb.ndim, z1, z2) + z1 = atleast_nd(self.cheb.ndim, z1) + z2 = atleast_nd(self.cheb.ndim, z2) # Cheb has shape (..., M, N) and others # have shape (K, ..., W) errorif(not (z1.ndim == z2.ndim == k.ndim == self.cheb.ndim)) @@ -533,7 +520,7 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): z1, z2 : jnp.ndarray Shape must broadcast with (*self.cheb.shape[:-2], W). ``z1`` and ``z2`` are intersects satisfying ∂f/∂y <= 0 and ∂f/∂y >= 0, - respectively. The points are grouped and ordered such that the straight + respectively. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of f. k : jnp.ndarray Shape must broadcast with *self.cheb.shape[:-2]. @@ -560,8 +547,8 @@ def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): # Ensure l axis exists for iteration in below loop. cheb = atleast_nd(3, self.cheb) - mask, z1, z2, f_midpoint = atleast_3d_mid(mask, z1, z2, f_midpoint) - err_1, err_2, err_3 = atleast_2d_end(err_1, err_2, err_3) + mask, z1, z2, f_midpoint = map(atleast_3d_mid, (mask, z1, z2, f_midpoint)) + err_1, err_2, err_3 = map(atleast_2d_end, (err_1, err_2, err_3)) for l in np.ndindex(cheb.shape[:-2]): for p in range(k.shape[0]): @@ -610,6 +597,7 @@ def plot1d( hlabel=r"$z$", vlabel=r"$f$", show=True, + include_legend=True, ): """Plot the piecewise Chebyshev series. @@ -641,6 +629,8 @@ def plot1d( Vertical axis label. show : bool Whether to show the plot. Default is true. + include_legend : bool + Whether to include the legend in the plot. Default is true. Returns ------- @@ -666,7 +656,8 @@ def plot1d( ) ax.set_xlabel(hlabel) ax.set_ylabel(vlabel) - ax.legend(legend.values(), legend.keys(), loc="lower right") + if include_legend: + ax.legend(legend.values(), legend.keys(), loc="lower right") ax.set_title(title) plt.tight_layout() if show: diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index 1826791a9d..b65d2a7ebe 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -104,8 +104,8 @@ def _transform_to_desc(grid, f): # After GitHub issue #1034 is resolved, we should pass in the previous # θ(α) coordinates as an initial guess for the next coordinate mapping. # Perhaps tell the optimizer to perturb the coefficients of the -# |B|(α, ζ) directly? Maybe auto diff to see change on |B|(θ, ζ) -# and hence stream functions. Not sure how feasible... +# |B|(α, ζ) directly? think perturbing alpha is equivalent to perturbing +# lambda. Not sure if possible.. # TODO: Allow multiple starting labels for near-rational surfaces. # can just concatenate along second to last axis of cheb, but will @@ -115,12 +115,12 @@ def _transform_to_desc(grid, f): class Bounce2D: """Computes bounce integrals using two-dimensional pseudo-spectral methods. - The bounce integral is defined as ∫ f(ℓ) dℓ, where + The bounce integral is defined as ∫ f(λ, ℓ) dℓ, where dℓ parameterizes the distance along the field line in meters, - f(ℓ) is the quantity to integrate along the field line, - and the boundaries of the integral are bounce points ζ₁, ζ₂ s.t. λ|B|(ζᵢ) = 1, - where λ is a constant proportional to the magnetic moment over energy - and |B| is the norm of the magnetic field. + f(λ, ℓ) is the quantity to integrate along the field line, + and the boundaries of the integral are bounce points ℓ₁, ℓ₂ s.t. λ|B|(ℓᵢ) = 1, + where λ is a constant defining the integral proportional to the magnetic moment + over energy and |B| is the norm of the magnetic field. For a particle with fixed λ, bounce points are defined to be the location on the field line such that the particle's velocity parallel to the magnetic field is zero. @@ -290,11 +290,17 @@ def __init__( quad : (jnp.ndarray, jnp.ndarray) Quadrature points xₖ and weights wₖ for the approximate evaluation of an integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points. + For weak singular integrals, use ``chebgauss2`` from + ``desc.integrals.quad_utils``. + For strong singular integrals, use ``leggauss``. automorphism : (Callable, Callable) or None The first callable should be an automorphism of the real interval [-1, 1]. The second callable should be the derivative of the first. This map defines a change of variable for the bounce integral. The choice made for the automorphism will affect the performance of the quadrature method. + For weak singular integrals, use ``None``. + For strong singular integrals, use ``automorphism_sin`` from + ``desc.integrals.quad_utils``. Bref : float Optional. Reference magnetic field strength for normalization. Lref : float @@ -421,17 +427,16 @@ def _L(self): """int: Number of flux surfaces to compute on.""" return self._B.cheb.shape[0] - def bounce_points(self, pitch, num_well=None): + def bounce_points(self, pitch_inv, num_well=None): """Compute bounce points. Parameters ---------- - pitch : jnp.ndarray - Shape (P, L). - λ values to evaluate the bounce integral at each field line. λ(ρ) is - specified by ``pitch[...,ρ]`` where in the latter the labels ρ are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. + pitch_inv : jnp.ndarray + Shape (M, L, P). # TODO: right now set up is (P, L). + 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by + ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. num_well : int or None Specify to return the first ``num_well`` pairs of bounce points for each pitch along each field line. This is useful if ``num_well`` tightly @@ -451,9 +456,9 @@ def bounce_points(self, pitch, num_well=None): epigraph of |B|. """ - return self._B.intersect1d(1 / jnp.atleast_2d(pitch), num_well) + return self._B.intersect1d(jnp.atleast_2d(pitch_inv), num_well) - def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs): + def check_bounce_points(self, z1, z2, pitch_inv, plot=True, **kwargs): """Check that bounce points are computed correctly. Parameters @@ -463,12 +468,11 @@ def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs): ζ coordinates of bounce points. The points are grouped and ordered such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. - pitch : jnp.ndarray - Shape (P, L). - λ values to evaluate the bounce integral at each field line. λ(ρ) is - specified by ``pitch[...,ρ]`` where in the latter the labels ρ are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. + pitch_inv : jnp.ndarray + Shape (M, L, P). # TODO: right now set up is (P, L). + 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by + ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. plot : bool Whether to plot stuff. kwargs : dict @@ -483,22 +487,21 @@ def check_bounce_points(self, z1, z2, pitch, plot=True, **kwargs): kwargs.setdefault("klabel", r"$1/\lambda$") kwargs.setdefault("hlabel", r"$\zeta$") kwargs.setdefault("vlabel", r"$\vert B \vert$") - self._B.check_intersect1d(z1, z2, 1 / pitch, plot, **kwargs) + self._B.check_intersect1d(z1, z2, pitch_inv, plot, **kwargs) - def integrate(self, pitch, integrand, f, weight=None, num_well=None): - """Bounce integrate ∫ f(ℓ) dℓ. + def integrate(self, pitch_inv, integrand, f, weight=None, num_well=None): + """Bounce integrate ∫ f(λ, ℓ) dℓ. - Computes the bounce integral ∫ f(ℓ) dℓ for every specified field line + Computes the bounce integral ∫ f(λ, ℓ) dℓ for every specified field line for every λ value in ``pitch``. Parameters ---------- - pitch : jnp.ndarray - Shape (P, L). - λ values to evaluate the bounce integral at each field line. λ(ρ) is - specified by ``pitch[...,ρ]`` where in the latter the labels ρ are - interpreted as the index into the last axis that corresponds to that field - line. If two-dimensional, the first axis is the batch axis. + pitch_inv : jnp.ndarray + Shape (M, L, P). # TODO: right now set up is (P, L). + 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by + ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. integrand : callable The composition operator on the set of functions in ``f`` that maps the functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the @@ -514,7 +517,7 @@ def integrate(self, pitch, integrand, f, weight=None, num_well=None): weight : jnp.ndarray Shape (L, 1, m, n). If supplied, the bounce integral labeled by well j is weighted such that - the returned value is w(j) ∫ f(ℓ) dℓ, where w(j) is ``weight`` + the returned value is w(j) ∫ f(λ, ℓ) dℓ, where w(j) is ``weight`` interpolated to the deepest point in the magnetic well. Use the method ``self.reshape_data`` to reshape the data into the expected shape. num_well : int or None @@ -535,9 +538,9 @@ def integrate(self, pitch, integrand, f, weight=None, num_well=None): Last axis enumerates the bounce integrals. """ - pitch = jnp.atleast_2d(pitch) - z1, z2 = self.bounce_points(pitch, num_well) - result = self._integrate(z1, z2, pitch, integrand, f) + pitch_inv = jnp.atleast_2d(pitch_inv) + z1, z2 = self.bounce_points(pitch_inv, num_well) + result = self._integrate(z1, z2, pitch_inv, integrand, f) errorif(weight is not None, NotImplementedError) return result diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py index 3dbc5b14a0..efb3b916b6 100644 --- a/desc/integrals/interp_utils.py +++ b/desc/integrals/interp_utils.py @@ -20,6 +20,7 @@ # TODO: Boyd's method 𝒪(N²) instead of Chebyshev companion matrix 𝒪(N³). # John P. Boyd, Computing real roots of a polynomial in Chebyshev series # form through subdivision. https://doi.org/10.1016/j.apnum.2005.09.007. +# This is likely the bottleneck. chebroots_vec = jnp.vectorize(chebroots, signature="(m)->(n)") @@ -143,10 +144,11 @@ def harmonic_vander(x, M): # TODO: For inverse transforms, do multipoint evaluation with FFT. # FFT cost is 𝒪(M N log[M N]) while direct evaluation is 𝒪(M² N²). # Chapter 10, https://doi.org/10.1017/CBO9781139856065. -# Right now we just do an MMT with the Vandermode matrix. -# Multipoint is likely better than using NFFT to evaluate f(xq) given fourier -# coefficients because evaluation points are quadratically packed near edges as -# required by quadrature to avoid runge. NFFT is only approximation anyway. +# Right now we do an MMT with the Vandermode matrix. +# Multipoint is likely better than using NFFT (for strong singular bounce +# integrals) to evaluate f(xq) given fourier coefficients because evaluation +# points are quadratically packed near edges for efficient quadrature. For +# weak singularities (e.g. effective ripple) NFFT should work well. # https://github.com/flatironinstitute/jax-finufft. @@ -451,6 +453,23 @@ def polyval_vec(*, x, c): # TODO: Eventually do a PR to move this stuff into interpax. +def _subtract_first(c, k): + """Subtract ``k`` from first index of last axis of ``c``. + + Semantically same as ``return c.copy().at[...,0].add(-k)``, + but allows dimension to increase. + """ + c_0 = c[..., 0] - k + c = jnp.concatenate( + [ + c_0[..., jnp.newaxis], + jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)), + ], + axis=-1, + ) + return c + + def _subtract_last(c, k): """Subtract ``k`` from last index of last axis of ``c``. diff --git a/desc/utils.py b/desc/utils.py index 72dd10f975..63375075b8 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -742,6 +742,18 @@ def atleast_nd(ndmin, ary): return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary +def atleast_3d_mid(ary): + """Like np.atleast_3d but if adds dim at axis 1 for 2d arrays.""" + ary = jnp.atleast_2d(ary) + return ary[:, jnp.newaxis] if ary.ndim == 2 else ary + + +def atleast_2d_end(ary): + """Like np.atleast_2d but if adds dim at axis 1 for 1d arrays.""" + ary = jnp.atleast_1d(ary) + return ary[:, jnp.newaxis] if ary.ndim == 1 else ary + + PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text diff --git a/tests/test_integrals.py b/tests/test_integrals.py index ca4e7e6d91..f08b7e2159 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -1635,7 +1635,7 @@ def integrand_den(B, pitch): normalization = -np.sign(data["psi"]) * data["Bref"] * data["a"] ** 2 drift_numerical_num = bounce.integrate( - pitch=pitch[:, np.newaxis], + pitch_inv=pitch[:, np.newaxis], integrand=integrand_num, f=Bounce2D.reshape_data( grid, @@ -1645,7 +1645,7 @@ def integrand_den(B, pitch): num_well=1, ) drift_numerical_den = bounce.integrate( - pitch=pitch[:, np.newaxis], + pitch_inv=pitch[:, np.newaxis], integrand=integrand_den, f=[], num_well=1,