diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index f02dbc942f..31a8ab9d91 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -1,15 +1,14 @@ """Methods for computing bounce integrals (singular or otherwise).""" -import numpy as np from interpax import CubicHermiteSpline, PPoly from orthax.legendre import leggauss from desc.backend import jnp from desc.integrals.bounce_utils import ( + _bounce_quadrature, _check_bounce_points, _set_default_plot_kwargs, bounce_points, - bounce_quadrature, get_pitch_inv, interp_to_argmin, plot_ppoly, @@ -21,7 +20,7 @@ grad_automorphism_sin, ) from desc.io import IOAble -from desc.utils import atleast_nd, errorif, setdefault, warnif +from desc.utils import errorif, setdefault, warnif class Bounce1D(IOAble): @@ -108,6 +107,8 @@ def __init__( automorphism=(automorphism_sin, grad_automorphism_sin), Bref=1.0, Lref=1.0, + *, + is_reshaped=False, check=False, **kwargs, ): @@ -137,6 +138,13 @@ def __init__( Optional. Reference magnetic field strength for normalization. Lref : float Optional. Reference length scale for normalization. + is_reshaped : bool + Whether the arrays in ``data`` are already reshaped to the expected form of + shape (..., N) or (..., L, N) or (M, L, N). This option can be used to + iteratively compute bounce integrals one field line or one flux surface + at a time, respectively, potentially reducing memory usage. To do so, + set to true and provide only those axes of the reshaped data. + Default is false. check : bool Flag for debugging. Must be false for JAX transformations. @@ -159,7 +167,11 @@ def __init__( "|B|": data["|B|"] / Bref, "|B|_z|r,a": data["|B|_z|r,a"] / Bref, # This is already the correct sign. } - self._data = dict(zip(data.keys(), Bounce1D.reshape_data(grid, *data.values()))) + self._data = ( + data + if is_reshaped + else dict(zip(data.keys(), Bounce1D.reshape_data(grid, *data.values()))) + ) self._x, self._w = get_quadrature(quad, automorphism) # Compute local splines. @@ -176,8 +188,10 @@ def __init__( destination=(-1, -2), ) self._dB_dz = polyder_vec(self.B) - assert self.B.shape == (grid.num_alpha, grid.num_rho, grid.num_zeta - 1, 4) - assert self._dB_dz.shape == (grid.num_alpha, grid.num_rho, grid.num_zeta - 1, 3) + + # Add axis here instead of in ``_bounce_quadrature``. + for name in self._data: + self._data[name] = self._data[name][..., jnp.newaxis, :] @staticmethod def reshape_data(grid, *arys): @@ -192,26 +206,23 @@ def reshape_data(grid, *arys): Returns ------- - f : list[jnp.ndarray] - List of reshaped data which may be given to ``integrate``. + f : jnp.ndarray + Shape (M, L, N). + Reshaped data which may be given to ``integrate``. """ f = [grid.meshgrid_reshape(d, "arz") for d in arys] - return f + return f if len(f) > 1 else f[0] - def points(self, pitch_inv, num_well=None): + def points(self, pitch_inv, *, num_well=None): """Compute bounce points. - Notes - ----- - Only the dimensions following L are required. The leading axes are batch axes. - Parameters ---------- pitch_inv : jnp.ndarray - Shape (P, M, L). + Shape (M, L, P). 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is - specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels + specified by ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted as the indices that correspond to that field line. num_well : int or None Specify to return the first ``num_well`` pairs of bounce points for each @@ -227,7 +238,7 @@ def points(self, pitch_inv, num_well=None): Returns ------- z1, z2 : (jnp.ndarray, jnp.ndarray) - Shape (P, M, L, num_well). + Shape (M, L, P, num_well). ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. @@ -239,20 +250,20 @@ def points(self, pitch_inv, num_well=None): """ return bounce_points(pitch_inv, self._zeta, self.B, self._dB_dz, num_well) - def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs): + def check_points(self, z1, z2, pitch_inv, *, plot=True, **kwargs): """Check that bounce points are computed correctly. Parameters ---------- z1, z2 : (jnp.ndarray, jnp.ndarray) - Shape (P, M, L, num_well). + Shape (M, L, P, num_well). ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. pitch_inv : jnp.ndarray - Shape (P, M, L). + Shape (M, L, P). 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is - specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels + specified by ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted as the indices that correspond to that field line. plot : bool Whether to plot the field lines and bounce points of the given pitch angles. @@ -268,7 +279,7 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs): return _check_bounce_points( z1=z1, z2=z2, - pitch_inv=atleast_nd(3, pitch_inv), + pitch_inv=pitch_inv, knots=self._zeta, B=self.B, plot=plot, @@ -277,10 +288,11 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs): def integrate( self, - pitch_inv, integrand, + pitch_inv, f=None, weight=None, + *, num_well=None, method="cubic", batch=True, @@ -291,24 +303,20 @@ def integrate( Computes the bounce integral ∫ f(ℓ) dℓ for every field line and pitch. - Notes - ----- - Only the dimensions following L are required. The leading axes are batch axes. - Parameters ---------- - pitch_inv : jnp.ndarray - Shape (P, M, L). - 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by - ``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted - as the indices that correspond to that field line. integrand : callable The composition operator on the set of functions in ``f`` that maps the functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the arrays in ``f`` as arguments as well as the additional keyword arguments: ``B`` and ``pitch``. A quadrature will be performed to approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. - f : list[jnp.ndarray] + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by + ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. + f : list[jnp.ndarray] or jnp.ndarray Shape (M, L, N). Real scalar-valued functions evaluated on the ``grid`` supplied to construct this object. These functions should be arguments to the callable @@ -345,20 +353,19 @@ def integrate( Returns ------- result : jnp.ndarray - Shape (P, M, L, num_well). - Last axis enumerates the bounce integrals for a given pitch, field line, - and flux surface. + Shape (M, L, P, num_well). + Last axis enumerates the bounce integrals for a given field line, + flux surface, and pitch value. """ - pitch_inv = atleast_nd(3, pitch_inv) - z1, z2 = self.points(pitch_inv, num_well) - result = bounce_quadrature( + z1, z2 = self.points(pitch_inv, num_well=num_well) + result = _bounce_quadrature( x=self._x, w=self._w, z1=z1, z2=z2, - pitch_inv=pitch_inv, integrand=integrand, + pitch_inv=pitch_inv, f=setdefault(f, []), data=self._data, knots=self._zeta, @@ -377,11 +384,10 @@ def integrate( self._dB_dz, method, ) - assert result.shape[0] == pitch_inv.shape[0] - assert result.shape[-1] == setdefault(num_well, np.prod(self._dB_dz.shape[-2:])) + assert result.shape == z1.shape return result - def plot(self, m, l, pitch_inv=None, **kwargs): + def plot(self, m, l, pitch_inv=None, /, **kwargs): """Plot the field line and bounce points of the given pitch angles. Parameters @@ -402,22 +408,21 @@ def plot(self, m, l, pitch_inv=None, **kwargs): Matplotlib (fig, ax) tuple. """ + B, dB_dz = self.B, self._dB_dz + if B.ndim == 4: + B = B[m, l] + dB_dz = dB_dz[m, l] + elif B.ndim == 3: + B = B[l] + dB_dz = dB_dz[l] if pitch_inv is not None: - pitch_inv = jnp.atleast_1d(jnp.squeeze(pitch_inv)) errorif( - pitch_inv.ndim != 1, + pitch_inv.ndim > 1, msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.", ) - z1, z2 = bounce_points( - pitch_inv[:, jnp.newaxis, jnp.newaxis], - self._zeta, - self.B[m, l], - self._dB_dz[m, l], - ) + z1, z2 = bounce_points(pitch_inv, self._zeta, B, dB_dz) kwargs["z1"] = z1 kwargs["z2"] = z2 kwargs["k"] = pitch_inv - fig, ax = plot_ppoly( - PPoly(self.B[m, l].T, self._zeta), **_set_default_plot_kwargs(kwargs) - ) + fig, ax = plot_ppoly(PPoly(B.T, self._zeta), **_set_default_plot_kwargs(kwargs)) return fig, ax diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index 2a8adfcdb1..e6fe1f9657 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -13,13 +13,11 @@ polyval_vec, ) from desc.integrals.quad_utils import ( + _composite_linspace, bijection_from_disc, - composite_linspace, grad_bijection_from_disc, ) from desc.utils import ( - atleast_2d_end, - atleast_3d_mid, atleast_nd, errorif, flatten_matrix, @@ -35,10 +33,8 @@ def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): Parameters ---------- min_B : jnp.ndarray - Shape (..., L). Minimum |B| value. max_B : jnp.ndarray - Shape (..., L). Maximum |B| value. num : int Number of values, not including endpoints. @@ -49,7 +45,7 @@ def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): Returns ------- pitch_inv : jnp.ndarray - Shape (num + 2, ..., L) with ndim > 2. + Shape (*min_B.shape, num + 2). 1/λ values. """ @@ -58,14 +54,13 @@ def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): min_B = (1 + relative_shift) * min_B max_B = (1 - relative_shift) * max_B # Samples should be uniformly spaced in |B| and not λ (GitHub issue #1228). - pitch_inv = atleast_3d_mid( - atleast_2d_end(composite_linspace(jnp.stack([min_B, max_B]), num)) - ) + pitch_inv = jnp.moveaxis(_composite_linspace(jnp.stack([min_B, max_B]), num), 0, -1) + assert pitch_inv.shape == (*min_B.shape, num + 2) return pitch_inv def _check_spline_shape(knots, g, dg_dz, pitch_inv=None): - """Ensure inputs have compatible shape, and return them with full dimension. + """Ensure inputs have compatible shape. Parameters ---------- @@ -73,18 +68,18 @@ def _check_spline_shape(knots, g, dg_dz, pitch_inv=None): Shape (N, ). ζ coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1]). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1] - 1). + Shape (..., N - 1, g.shape[-1] - 1). Polynomial coefficients of the spline of ∂g/∂ζ in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. pitch_inv : jnp.ndarray - Shape (P, M, L). - 1/λ values. 1/λ(α,ρ) is specified by ``pitch_inv[...,α,ρ]`` where in + Shape (..., P). + 1/λ values. 1/λ(α,ρ) is specified by ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted as the indices that correspond to that field line. @@ -102,12 +97,12 @@ def _check_spline_shape(knots, g, dg_dz, pitch_inv=None): or g.shape != (*dg_dz.shape[:-1], dg_dz.shape[-1] + 1), msg=f"Invalid shape {g.shape} for spline and derivative {dg_dz.shape}.", ) - g = atleast_nd(4, g) - dg_dz = atleast_nd(4, dg_dz) + g, dg_dz = jnp.atleast_2d(g, dg_dz) if pitch_inv is not None: - pitch_inv = atleast_nd(3, pitch_inv) + pitch_inv = jnp.atleast_1d(pitch_inv) errorif( - pitch_inv.ndim > 3 or not is_broadcastable(pitch_inv.shape, g.shape[:2]), + pitch_inv.ndim > 3 + or not is_broadcastable(pitch_inv.shape[:-1], g.shape[:-2]), msg=f"Invalid shape {pitch_inv.shape} for pitch angles.", ) return g, dg_dz, pitch_inv @@ -118,27 +113,21 @@ def bounce_points( ): """Compute the bounce points given spline of |B| and pitch λ. - Notes - ----- - Only the dimensions following L are required. The leading axes are batch axes. - Parameters ---------- pitch_inv : jnp.ndarray - Shape (P, M, L). - 1/λ values to compute the bounce points. 1/λ(α,ρ) is specified by - ``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted - as the indices that correspond to that field line. + Shape (..., P). + 1/λ values to compute the bounce points. knots : jnp.ndarray Shape (N, ). ζ coordinates of spline knots. Must be strictly increasing. B : jnp.ndarray - Shape (M, L, N - 1, B.shape[-1]). + Shape (..., N - 1, B.shape[-1]). Polynomial coefficients of the spline of |B| in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. dB_dz : jnp.ndarray - Shape (M, L, N - 1, B.shape[-1] - 1). + Shape (..., N - 1, B.shape[-1] - 1). Polynomial coefficients of the spline of (∂|B|/∂ζ)|(ρ,α) in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. @@ -162,7 +151,7 @@ def bounce_points( Returns ------- z1, z2 : (jnp.ndarray, jnp.ndarray) - Shape (P, M, L, num_well). + Shape (..., P, num_well). ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. @@ -174,25 +163,23 @@ def bounce_points( """ B, dB_dz, pitch_inv = _check_spline_shape(knots, B, dB_dz, pitch_inv) intersect = polyroot_vec( - c=B, - k=pitch_inv[..., jnp.newaxis], + c=B[..., jnp.newaxis, :, :], # Add P axis + k=pitch_inv[..., jnp.newaxis], # Add N axis a_min=jnp.array([0.0]), a_max=jnp.diff(knots), sort=True, sentinel=-1.0, distinct=True, ) - assert intersect.shape == ( - pitch_inv.shape[0], - B.shape[0], - B.shape[1], + assert intersect.shape[-3:] == ( + pitch_inv.shape[-1], knots.size - 1, B.shape[-1] - 1, ) # Reshape so that last axis enumerates intersects of a pitch along a field line. dB_sign = flatten_matrix( - jnp.sign(polyval_vec(x=intersect, c=dB_dz[..., jnp.newaxis, :])) + jnp.sign(polyval_vec(x=intersect, c=dB_dz[..., jnp.newaxis, :, jnp.newaxis, :])) ) # Only consider intersect if it is within knots that bound that polynomial. is_intersect = flatten_matrix(intersect) >= 0 @@ -234,6 +221,11 @@ def _set_default_plot_kwargs(kwargs): def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): """Check that bounce points are computed correctly.""" + z1 = atleast_nd(4, z1) + z2 = atleast_nd(4, z2) + pitch_inv = atleast_nd(3, pitch_inv) + B = atleast_nd(4, B) + kwargs = _set_default_plot_kwargs(kwargs) plots = [] @@ -248,8 +240,8 @@ def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): eps = kwargs.pop("eps", jnp.finfo(jnp.array(1.0).dtype).eps * 10) for ml in np.ndindex(B.shape[:-2]): ppoly = PPoly(B[ml].T, knots) - for p in range(pitch_inv.shape[0]): - idx = (p, *ml) + for p in range(pitch_inv.shape[-1]): + idx = (*ml, p) B_midpoint = ppoly((z1[idx] + z2[idx]) / 2) err_3 = jnp.any(B_midpoint > pitch_inv[idx] + eps) if not (err_1[idx] or err_2[idx] or err_3): @@ -262,7 +254,7 @@ def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): z1=_z1, z2=_z2, k=pitch_inv[idx], - title=kwargs.pop("title") + f", (p,m,l)={idx}", + title=kwargs.pop("title") + f", (m,l,p)={idx}", **kwargs, ) @@ -276,26 +268,25 @@ def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): "bounce points is in hypograph(|B|). Use more knots.\n" ) if plot: - idx = (slice(None), *ml) plots.append( plot_ppoly( ppoly=ppoly, - z1=z1[idx], - z2=z2[idx], - k=pitch_inv[idx], + z1=z1[ml], + z2=z2[ml], + k=pitch_inv[ml], **kwargs, ) ) return plots -def bounce_quadrature( +def _bounce_quadrature( x, w, z1, z2, - pitch_inv, integrand, + pitch_inv, f, data, knots, @@ -315,27 +306,25 @@ def bounce_quadrature( Shape (w.size, ). Quadrature weights. z1, z2 : jnp.ndarray - Shape (P, M, L, num_well). + Shape (..., P, num_well). ζ coordinates of bounce points. The points are ordered and grouped such that the straight line path between ``z1`` and ``z2`` resides in the epigraph of |B|. - pitch_inv : jnp.ndarray - Shape (P, M, L). - 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by - ``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted - as the indices that correspond to that field line. integrand : callable The composition operator on the set of functions in ``f`` that maps the functions in ``f`` to the integrand f(ℓ) in ∫ f(ℓ) dℓ. It should accept the arrays in ``f`` as arguments as well as the additional keyword arguments: ``B`` and ``pitch``. A quadrature will be performed to approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values to compute the bounce integrals. f : list[jnp.ndarray] - Shape (M, L, N). + Shape (..., N). Real scalar-valued functions evaluated on the ``knots``. These functions should be arguments to the callable ``integrand``. data : dict[str, jnp.ndarray] - Shape (M, L, N). + Shape (..., 1, N). Required data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. Must include names in ``Bounce1D.required_names``. knots : jnp.ndarray @@ -357,14 +346,14 @@ def bounce_quadrature( Returns ------- result : jnp.ndarray - Shape (P, M, L, num_well). - Last axis enumerates the bounce integrals for a given pitch, field line, - and flux surface. + Shape (..., P, num_well). + Last axis enumerates the bounce integrals for a field line, + flux surface, and pitch. """ errorif(x.ndim != 1 or x.shape != w.shape) - errorif(z1.ndim != 4 or z1.shape != z2.shape) - errorif(pitch_inv.ndim != 3) + errorif(z1.ndim < 2 or z1.shape != z2.shape) + pitch_inv = jnp.atleast_1d(pitch_inv) if not isinstance(f, (list, tuple)): f = [f] if isinstance(f, (jnp.ndarray, np.ndarray)) else list(f) @@ -384,7 +373,7 @@ def bounce_quadrature( ) else: # TODO: Use batched vmap. - def loop(z): + def loop(z): # over num well axis z1, z2 = z # Need to return tuple because input was tuple; artifact of JAX map. return None, _interpolate_and_integrate( @@ -398,6 +387,7 @@ def loop(z): method=method, check=False, plot=False, + batch=True, ) result = jnp.moveaxis( @@ -406,9 +396,7 @@ def loop(z): destination=-1, ) - result = result * grad_bijection_from_disc(z1, z2) - assert result.shape == z1.shape - return result + return result * grad_bijection_from_disc(z1, z2) def _interpolate_and_integrate( @@ -422,6 +410,7 @@ def _interpolate_and_integrate( method, check, plot, + batch=False, ): """Interpolate given functions to points ``Q`` and perform quadrature. @@ -431,11 +420,8 @@ def _interpolate_and_integrate( Shape (w.size, ). Quadrature weights. Q : jnp.ndarray - Shape (P, M, L, Q.shape[-2], w.size). + Shape (..., P, Q.shape[-2], w.size). Quadrature points in ζ coordinates. - data : dict[str, jnp.ndarray] - Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. - Must include names in ``Bounce1D.required_names``. Returns ------- @@ -444,12 +430,13 @@ def _interpolate_and_integrate( Quadrature result. """ - assert w.ndim == 1 - assert 3 < Q.ndim < 6 and Q.shape[0] == pitch_inv.shape[0] and Q.shape[-1] == w.size + assert w.ndim == 1 and Q.shape[-1] == w.size + assert Q.shape[-3 + batch] == pitch_inv.shape[-1] assert data["|B|"].shape[-1] == knots.size shape = Q.shape - Q = Q.reshape(*Q.shape[:3], -1) + if not batch: + Q = flatten_matrix(Q) b_sup_z = interp1d_Hermite_vec( Q, knots, @@ -460,17 +447,11 @@ def _interpolate_and_integrate( B = interp1d_Hermite_vec(Q, knots, data["|B|"], data["|B|_z|r,a"]) # Spline each function separately so that operations in the integrand # that do not preserve smoothness can be captured. - f = [interp1d_vec(Q, knots, f_i, method=method) for f_i in f] - result = jnp.dot( - ( - integrand( - *f, - B=B, - pitch=1 / pitch_inv[..., jnp.newaxis], - ) - / b_sup_z - ).reshape(shape), - w, + f = [interp1d_vec(Q, knots, f_i[..., jnp.newaxis, :], method=method) for f_i in f] + result = ( + (integrand(*f, B=B, pitch=1 / pitch_inv[..., jnp.newaxis]) / b_sup_z) + .reshape(shape) + .dot(w) ) if check: _check_interp(shape, Q, f, b_sup_z, B, result, plot) @@ -484,7 +465,7 @@ def _check_interp(shape, Q, f, b_sup_z, B, result, plot): Parameters ---------- shape : tuple - (P, M, L, Q.shape[-2], w.size). + (..., P, Q.shape[-2], w.size). Q : jnp.ndarray Quadrature points in ζ coordinates. f : list[jnp.ndarray] @@ -528,7 +509,7 @@ def _check_interp(shape, Q, f, b_sup_z, B, result, plot): def _plot_check_interp(Q, V, name=""): - """Plot V[λ, α, ρ, (ζ₁, ζ₂)](Q).""" + """Plot V[..., λ, (ζ₁, ζ₂)](Q).""" for idx in np.ndindex(Q.shape[:3]): marked = jnp.nonzero(jnp.any(Q[idx] != 0.0, axis=-1))[0] if marked.size == 0: @@ -536,7 +517,7 @@ def _plot_check_interp(Q, V, name=""): fig, ax = plt.subplots() ax.set_xlabel(r"$\zeta$") ax.set_ylabel(name) - ax.set_title(f"Interpolation of {name} to quadrature points, (p,m,l)={idx}") + ax.set_title(f"Interpolation of {name} to quadrature points, (m,l,p)={idx}") for i in marked: ax.plot(Q[(*idx, i)], V[(*idx, i)], marker="o") fig.text(0.01, 0.01, "Each color specifies a particular integral.") @@ -547,22 +528,18 @@ def _plot_check_interp(Q, V, name=""): def _get_extrema(knots, g, dg_dz, sentinel=jnp.nan): """Return extrema (z*, g(z*)). - Notes - ----- - Only the dimensions following L are required. The leading axes are batch axes. - Parameters ---------- knots : jnp.ndarray Shape (N, ). ζ coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1]). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1] - 1). + Shape (..., N - 1, g.shape[-1] - 1). Polynomial coefficients of the spline of ∂g/∂z in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. @@ -572,7 +549,7 @@ def _get_extrema(knots, g, dg_dz, sentinel=jnp.nan): Returns ------- ext, g_ext : jnp.ndarray - Shape (M, L, (N - 1) * (g.shape[-1] - 2)). + Shape (..., (N - 1) * (g.shape[-1] - 2)). First array enumerates z*. Second array enumerates g(z*) Sorting order of extrema is arbitrary. @@ -584,15 +561,15 @@ def _get_extrema(knots, g, dg_dz, sentinel=jnp.nan): g_ext = flatten_matrix(polyval_vec(x=ext, c=g[..., jnp.newaxis, :])) # Transform out of local power basis expansion. ext = flatten_matrix(ext + knots[:-1, jnp.newaxis]) + assert ext.shape == g_ext.shape and ext.shape[-1] == g.shape[-2] * (g.shape[-1] - 2) return ext, g_ext def _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel): - assert z1.shape[1:3] == z2.shape[1:3] == ext.shape[:2] == g_ext.shape[:2] return jnp.where( - (z1[..., jnp.newaxis] < ext[:, :, jnp.newaxis]) - & (ext[:, :, jnp.newaxis] < z2[..., jnp.newaxis]), - g_ext[:, :, jnp.newaxis], + (z1[..., jnp.newaxis] < ext[..., jnp.newaxis, jnp.newaxis, :]) + & (ext[..., jnp.newaxis, jnp.newaxis, :] < z2[..., jnp.newaxis]), + g_ext[..., jnp.newaxis, jnp.newaxis, :], upper_sentinel, ) @@ -604,28 +581,24 @@ def interp_to_argmin( Let E = {ζ ∣ ζ₁ < ζ < ζ₂} and A = argmin_E g(ζ). Returns mean_A h(ζ). - Notes - ----- - Only the dimensions following L are required. The leading axes are batch axes. - Parameters ---------- h : jnp.ndarray - Shape (M, L, N). + Shape (..., N). Values evaluated on ``knots`` to interpolate. z1, z2 : jnp.ndarray - Shape (P, M, L, num_well). + Shape (..., P, W). Boundaries to detect argmin between. knots : jnp.ndarray Shape (N, ). z coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1]). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1] - 1). + Shape (..., N - 1, g.shape[-1] - 1). Polynomial coefficients of the spline of ∂g/∂z in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. @@ -651,12 +624,10 @@ def interp_to_argmin( Returns ------- h : jnp.ndarray - Shape (P, M, L, num_well). - mean_A h(ζ) + Shape (..., P, W). """ - z1 = atleast_nd(4, z1) - z2 = atleast_nd(4, z2) + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) # Our softargmax(x) does the proper shift to compute softargmax(x - max(x)), # but it's still not a good idea to compute over a large length scale, so we @@ -667,9 +638,9 @@ def interp_to_argmin( ) h = jnp.linalg.vecdot( argmin, - interp1d_vec(ext, knots, h, method=method)[:, :, jnp.newaxis], + interp1d_vec(ext, knots, h, method=method)[..., jnp.newaxis, jnp.newaxis, :], ) - assert h.shape == z1.shape or h.shape == z2.shape + assert h.shape == z1.shape return h @@ -684,28 +655,24 @@ def interp_to_argmin_hard(h, z1, z2, knots, g, dg_dz, method="cubic"): Accomplishes the same task, but handles the case of non-unique global minima more correctly. It is also more efficient if P >> 1. - Notes - ----- - Only the dimensions following L are required. The leading axes are batch axes. - Parameters ---------- h : jnp.ndarray - Shape (M, L, N). + Shape (..., N). Values evaluated on ``knots`` to interpolate. z1, z2 : jnp.ndarray - Shape (P, M, L, num_well). + Shape (..., P, W). Boundaries to detect argmin between. knots : jnp.ndarray Shape (N, ). z coordinates of spline knots. Must be strictly increasing. g : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1]). + Shape (..., N - 1, g.shape[-1]). Polynomial coefficients of the spline of g in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. dg_dz : jnp.ndarray - Shape (M, L, N - 1, g.shape[-1] - 1). + Shape (..., N - 1, g.shape[-1] - 1). Polynomial coefficients of the spline of ∂g/∂z in local power basis. Last axis enumerates the coefficients of power series. Second to last axis enumerates the polynomials that compose a particular spline. @@ -717,12 +684,10 @@ def interp_to_argmin_hard(h, z1, z2, knots, g, dg_dz, method="cubic"): Returns ------- h : jnp.ndarray - Shape (P, M, L, num_well). - h(A) + Shape (..., P, W). """ - z1 = atleast_nd(4, z1) - z2 = atleast_nd(4, z2) + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) # We can use the non-differentiable max because we actually want the gradients # to accumulate through only the minimum since we are differentiating how our @@ -735,10 +700,10 @@ def interp_to_argmin_hard(h, z1, z2, knots, g, dg_dz, method="cubic"): h = interp1d_vec( jnp.take_along_axis(ext[jnp.newaxis], argmin, axis=-1), knots, - h, + h[..., jnp.newaxis, :], method=method, ) - assert h.shape == z1.shape or h.shape == z2.shape + assert h.shape == z1.shape, h.shape return h diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py index b7bfbb6a6c..2a00801d8c 100644 --- a/desc/integrals/quad_utils.py +++ b/desc/integrals/quad_utils.py @@ -220,7 +220,7 @@ def get_quadrature(quad, automorphism): return x, w -def composite_linspace(x, num): +def _composite_linspace(x, num): """Returns linearly spaced values between every pair of values in ``x``. Parameters diff --git a/tests/test_integrals.py b/tests/test_integrals.py index a368987853..23f7539dd3 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -952,7 +952,7 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism): check=True, **kwargs, ) - result = bounce.integrate(pitch_inv, integrand, check=True, plot=True) + result = bounce.integrate(integrand, pitch_inv, check=True, plot=True) assert np.count_nonzero(result) == 1 np.testing.assert_allclose(result.sum(), truth, rtol=1e-4) @@ -1099,14 +1099,14 @@ def test_bounce1d_checks(self): grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 ) num = bounce.integrate( - pitch_inv, integrand=TestBounce1D._example_numerator, + pitch_inv=pitch_inv, f=Bounce1D.reshape_data(grid.source_grid, data["g_zz"]), check=True, ) den = bounce.integrate( - pitch_inv, integrand=TestBounce1D._example_denominator, + pitch_inv=pitch_inv, check=True, batch=False, ) @@ -1117,14 +1117,12 @@ def test_bounce1d_checks(self): # Sum all bounce averages across a particular field line, for every field line. result = avg.sum(axis=-1) # Group the result by pitch and flux surface. - result = result.reshape(pitch_inv.shape[0], alpha.size, rho.size) + result = result.reshape(alpha.size, rho.size, pitch_inv.shape[-1]) # The result stored at - p, m, l = 3, 0, 1 - print("Result(λ, α, ρ):", result[p, m, l]) + m, l, p = 0, 1, 3 + print("Result(α, ρ, λ):", result[m, l, p]) # corresponds to the 1/λ value - print( - "1/λ(α, ρ):", pitch_inv[p, m % pitch_inv.shape[1], l % pitch_inv.shape[-1]] - ) + print("1/λ(α, ρ):", pitch_inv[l, p]) # for the Clebsch-type field line coordinates nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes[:, :2], "arz") print("(α, ρ):", nodes[m, l, 0]) @@ -1133,7 +1131,7 @@ def test_bounce1d_checks(self): bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) # 8. Plotting - fig, ax = bounce.plot(m, l, pitch_inv[..., l], include_legend=False, show=False) + fig, ax = bounce.plot(m, l, pitch_inv[l], include_legend=False, show=False) return fig @pytest.mark.unit @@ -1166,20 +1164,20 @@ def dg_dz(z): "|B|_z|r,a": dg_dz(zeta), }, ) - np.testing.assert_allclose(bounce._zeta, zeta) + z1 = np.array(0, ndmin=4) + z2 = np.array(2 * np.pi, ndmin=4) argmin = 5.61719 - np.testing.assert_allclose( - h(argmin), - func( - h=h(zeta), - z1=np.array(0, ndmin=3), - z2=np.array(2 * np.pi, ndmin=3), - knots=zeta, - g=bounce.B, - dg_dz=bounce._dB_dz, - ), - rtol=1e-3, - ) + h_min = h(argmin) + result = func( + h=h(zeta), + z1=z1, + z2=z2, + knots=zeta, + g=bounce.B, + dg_dz=bounce._dB_dz, + ) + assert result.shape == z1.shape + np.testing.assert_allclose(h_min, result, rtol=1e-3) @staticmethod def drift_analytic(data): @@ -1262,7 +1260,7 @@ def drift_analytic(data): # Exclude singularity not captured by analytic approximation for pitch near # the maximum |B|. (This is captured by the numerical integration). - pitch_inv = get_pitch_inv(np.min(B), np.max(B), 100).squeeze()[:-1] + pitch_inv = get_pitch_inv(np.min(B), np.max(B), 100)[:-1] k2 = 0.5 * ((1 - B0 / pitch_inv) / (epsilon * B0 / pitch_inv) + 1) I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 = ( TestBounce1DQuadrature.elliptic_incomplete(k2) @@ -1283,7 +1281,7 @@ def drift_analytic(data): ) / G0 drift_analytic_den = I_0 / G0 drift_analytic = drift_analytic_num / drift_analytic_den - return drift_analytic, cvdrift, gbdrift, pitch_inv.reshape(-1, 1, 1) + return drift_analytic, cvdrift, gbdrift, pitch_inv @staticmethod def drift_num_integrand(cvdrift, gbdrift, B, pitch): @@ -1361,15 +1359,15 @@ def test_binormal_drift_bounce1d(self): f = Bounce1D.reshape_data(grid.source_grid, cvdrift, gbdrift) drift_numerical_num = bounce.integrate( - pitch_inv=pitch_inv, integrand=TestBounce1D.drift_num_integrand, + pitch_inv=pitch_inv, f=f, num_well=1, check=True, ) drift_numerical_den = bounce.integrate( - pitch_inv=pitch_inv, integrand=TestBounce1D.drift_den_integrand, + pitch_inv=pitch_inv, num_well=1, weight=np.ones(zeta.size), check=True, @@ -1389,8 +1387,8 @@ def test_binormal_drift_bounce1d(self): ) fig, ax = plt.subplots() - ax.plot(pitch_inv.squeeze(), drift_analytic) - ax.plot(pitch_inv.squeeze(), drift_numerical) + ax.plot(pitch_inv, drift_analytic) + ax.plot(pitch_inv, drift_numerical) return fig @staticmethod @@ -1442,11 +1440,11 @@ def integrand_grad(*args, **kwargs2): return grad_fun(*args, *kwargs2.values()) def fun1(pitch): - return bounce.integrate(1 / pitch, integrand, check=False, **kwargs).sum() + return bounce.integrate(integrand, 1 / pitch, check=False, **kwargs).sum() def fun2(pitch): return bounce.integrate( - 1 / pitch, integrand_grad, check=True, **kwargs + integrand_grad, 1 / pitch, check=True, **kwargs ).sum() pitch = 1.0 diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py index 5a7c3d00e7..ce9408f12a 100644 --- a/tests/test_quad_utils.py +++ b/tests/test_quad_utils.py @@ -6,11 +6,11 @@ from desc.backend import jnp from desc.integrals.quad_utils import ( + _composite_linspace, automorphism_arcsin, automorphism_sin, bijection_from_disc, bijection_to_disc, - composite_linspace, grad_automorphism_arcsin, grad_automorphism_sin, grad_bijection_from_disc, @@ -26,7 +26,7 @@ def test_composite_linspace(): B_min_tz = np.array([0.1, 0.2]) B_max_tz = np.array([1, 3]) breaks = np.linspace(B_min_tz, B_max_tz, num=5) - b = composite_linspace(breaks, num=3) + b = _composite_linspace(breaks, num=3) for i in range(breaks.shape[0]): for j in range(breaks.shape[1]): assert only1(np.isclose(breaks[i, j], b[:, j]).tolist())