diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 9d722ce52d..c7b51b24ab 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -685,12 +685,12 @@ def get_rtz_grid( rvp : rho, theta_PEST, phi rtz : rho, theta, zeta period : tuple of float - Assumed periodicity for each quantity in inbasis. + Assumed periodicity for functions of the given coordinates. Use ``np.inf`` to denote no periodicity. jitable : bool, optional If false the returned grid has additional attributes. Required to be false to retain nodes at magnetic axis. - kwargs : dict + kwargs Additional parameters to supply to the coordinate mapping function. See ``desc.equilibrium.coords.map_coordinates``. diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index defc4b3b91..f02dbc942f 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -255,7 +255,7 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs): specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels are interpreted as the indices that correspond to that field line. plot : bool - Whether to plot stuff. + Whether to plot the field lines and bounce points of the given pitch angles. kwargs Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``. @@ -285,6 +285,7 @@ def integrate( method="cubic", batch=True, check=False, + plot=False, ): """Bounce integrate ∫ f(ℓ) dℓ. @@ -337,6 +338,9 @@ def integrate( Whether to perform computation in a batched manner. Default is true. check : bool Flag for debugging. Must be false for JAX transformations. + plot : bool + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. Returns ------- @@ -361,6 +365,7 @@ def integrate( method=method, batch=batch, check=check, + plot=plot, ) if weight is not None: result *= interp_to_argmin( diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index 8e45c9d44c..2a8adfcdb1 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -30,7 +30,7 @@ def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): - """Return 1/λ values uniformly spaced between ``min_B`` and ``max_B``. + """Return 1/λ values for quadrature between ``min_B`` and ``max_B``. Parameters ---------- @@ -262,6 +262,7 @@ def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): z1=_z1, z2=_z2, k=pitch_inv[idx], + title=kwargs.pop("title") + f", (p,m,l)={idx}", **kwargs, ) @@ -350,7 +351,8 @@ def bounce_quadrature( Flag for debugging. Must be false for JAX transformations. Ignored if ``batch`` is false. plot : bool - Whether to plot stuff if ``check`` is true. Default is false. + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. Returns ------- @@ -418,8 +420,8 @@ def _interpolate_and_integrate( data, knots, method, - check=False, - plot=False, + check, + plot, ): """Interpolate given functions to points ``Q`` and perform quadrature. @@ -526,12 +528,7 @@ def _check_interp(shape, Q, f, b_sup_z, B, result, plot): def _plot_check_interp(Q, V, name=""): - """Plot V[λ, α, ρ, (ζ₁, ζ₂)](Q). - - These are pretty, but likely only useful for developers - doing debugging, so we don't include an option to plot these - in the public API of Bounce1D. - """ + """Plot V[λ, α, ρ, (ζ₁, ζ₂)](Q).""" for idx in np.ndindex(Q.shape[:3]): marked = jnp.nonzero(jnp.any(Q[idx] != 0.0, axis=-1))[0] if marked.size == 0: @@ -539,15 +536,10 @@ def _plot_check_interp(Q, V, name=""): fig, ax = plt.subplots() ax.set_xlabel(r"$\zeta$") ax.set_ylabel(name) - ax.set_title(f"Interpolation of {name} to quadrature points. Index {idx}.") + ax.set_title(f"Interpolation of {name} to quadrature points, (p,m,l)={idx}") for i in marked: ax.plot(Q[(*idx, i)], V[(*idx, i)], marker="o") - fig.text( - 0.01, - 0.01, - f"Each color specifies {name} interpolated to the quadrature " - "points of a particular integral.", - ) + fig.text(0.01, 0.01, "Each color specifies a particular integral.") plt.tight_layout() plt.show() @@ -765,7 +757,7 @@ def plot_ppoly( start=None, stop=None, include_knots=False, - knot_transparency=0.1, + knot_transparency=0.2, include_legend=True, ): """Plot the piecewise polynomial ``ppoly``. @@ -805,6 +797,8 @@ def plot_ppoly( Whether to plot vertical lines at the knots. knot_transparency : float Transparency of knot lines. + include_legend : bool + Whether to include the legend in the plot. Default is true. Returns ------- diff --git a/desc/utils.py b/desc/utils.py index 0f6553b67f..6ead7a5078 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -739,21 +739,17 @@ def flatten_matrix(y): # https://github.com/numpy/numpy/issues/25805 def atleast_nd(ndmin, ary): """Adds dimensions to front if necessary.""" - if ndmin == 1: - return jnp.atleast_1d(ary) - if ndmin == 2: - return jnp.atleast_2d(ary) return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary def atleast_3d_mid(ary): - """Like np.atleast3d but if adds dim at axis 1 for 2d arrays.""" + """Like np.atleast_3d but if adds dim at axis 1 for 2d arrays.""" ary = jnp.atleast_2d(ary) return ary[:, jnp.newaxis] if ary.ndim == 2 else ary def atleast_2d_end(ary): - """Like np.atleast2d but if adds dim at axis 1 for 1d arrays.""" + """Like np.atleast_2d but if adds dim at axis 1 for 1d arrays.""" ary = jnp.atleast_1d(ary) return ary[:, jnp.newaxis] if ary.ndim == 1 else ary diff --git a/tests/test_integrals.py b/tests/test_integrals.py index cb59fc9f67..a368987853 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -40,11 +40,11 @@ get_pitch_inv, interp_to_argmin, interp_to_argmin_hard, - plot_ppoly, ) from desc.integrals.quad_utils import ( automorphism_sin, bijection_from_disc, + get_quadrature, grad_automorphism_sin, grad_bijection_from_disc, leggauss_lob, @@ -738,7 +738,9 @@ def test_z1_first(self): B = CubicHermiteSpline(knots, np.cos(knots), -np.sin(knots)) pitch_inv = 0.5 intersect = B.solve(pitch_inv, extrapolate=False) - z1, z2 = bounce_points(pitch_inv, knots, B.c.T, B.derivative().c.T, check=True) + z1, z2 = bounce_points( + pitch_inv, knots, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size np.testing.assert_allclose(z1, intersect[0::2]) @@ -753,7 +755,9 @@ def test_z2_first(self): B = CubicHermiteSpline(k, np.cos(k), -np.sin(k)) pitch_inv = 0.5 intersect = B.solve(pitch_inv, extrapolate=False) - z1, z2 = bounce_points(pitch_inv, k, B.c.T, B.derivative().c.T, check=True) + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size np.testing.assert_allclose(z1, intersect[1:-1:2]) @@ -772,7 +776,9 @@ def test_z1_before_extrema(self): ) dB_dz = B.derivative() pitch_inv = B(dB_dz.roots(extrapolate=False))[3] - 1e-13 - z1, z2 = bounce_points(pitch_inv, k, B.c.T, dB_dz.c.T, check=True) + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size intersect = B.solve(pitch_inv, extrapolate=False) @@ -797,7 +803,9 @@ def test_z2_before_extrema(self): ) dB_dz = B.derivative() pitch_inv = B(dB_dz.roots(extrapolate=False))[2] - z1, z2 = bounce_points(pitch_inv, k, B.c.T, dB_dz.c.T, check=True) + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size intersect = B.solve(pitch_inv, extrapolate=False) @@ -819,9 +827,14 @@ def test_extrema_first_and_before_z1(self): dB_dz = B.derivative() pitch_inv = B(dB_dz.roots(extrapolate=False))[2] + 1e-13 z1, z2 = bounce_points( - pitch_inv, k[2:], B.c[:, 2:].T, dB_dz.c[:, 2:].T, check=True, plot=False + pitch_inv, + k[2:], + B.c[:, 2:].T, + dB_dz.c[:, 2:].T, + check=True, + start=k[2], + include_knots=True, ) - plot_ppoly(B, z1=z1, z2=z2, k=pitch_inv, start=k[2]) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size intersect = B.solve(pitch_inv, extrapolate=False) @@ -844,7 +857,9 @@ def test_extrema_first_and_before_z2(self): ) dB_dz = B.derivative() pitch_inv = B(dB_dz.roots(extrapolate=False))[1] - 1e-13 - z1, z2 = bounce_points(pitch_inv, k, B.c.T, dB_dz.c.T, check=True) + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True + ) z1, z2 = TestBounce1DPoints.filter(z1, z2) assert z1.size and z2.size # Our routine correctly detects intersection, while scipy, jnp.root fails. @@ -937,7 +952,7 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism): check=True, **kwargs, ) - result = bounce.integrate(pitch_inv, integrand, check=True) + result = bounce.integrate(pitch_inv, integrand, check=True, plot=True) assert np.count_nonzero(result) == 1 np.testing.assert_allclose(result.sum(), truth, rtol=1e-4) @@ -950,14 +965,10 @@ def _adaptive_elliptic(integrand, k): @staticmethod def _fixed_elliptic(integrand, k, deg): - # Can use this test to benchmark quadrature performance. - # Just k = np.atleast_1d(k) a = np.zeros_like(k) b = 2 * np.arcsin(k) - x, w = leggauss(deg) - w = w * grad_automorphism_sin(x) - x = automorphism_sin(x) + x, w = get_quadrature(leggauss(deg), (automorphism_sin, grad_automorphism_sin)) Z = bijection_from_disc(x, a[..., np.newaxis], b[..., np.newaxis]) k = k[..., np.newaxis] quad = np.dot(integrand(Z, k), w) * grad_bijection_from_disc(a, b) @@ -1118,7 +1129,10 @@ def test_bounce1d_checks(self): nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes[:, :2], "arz") print("(α, ρ):", nodes[m, l, 0]) - # 7. Plotting + # 7. Optionally check for correctness of bounce points + bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) + + # 8. Plotting fig, ax = bounce.plot(m, l, pitch_inv[..., l], include_legend=False, show=False) return fig @@ -1343,6 +1357,8 @@ def test_binormal_drift_bounce1d(self): Lref=data["a"], check=True, ) + bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) + f = Bounce1D.reshape_data(grid.source_grid, cvdrift, gbdrift) drift_numerical_num = bounce.integrate( pitch_inv=pitch_inv, @@ -1389,8 +1405,8 @@ def _test_bounce_autodiff(bounce, integrand, **kwargs): If the AD tool works properly, then these operations should be assigned zero gradients while the gradients wrt parameters of our physics computations accumulate correctly. Less mature AD tools may have subtle bugs that cause - the gradients to not accumulate correctly. (There's more than a few - GitHub issues that JAX has fixed related to this in the past!) + the gradients to not accumulate correctly. (There's a few + GitHub issues that JAX has fixed related to this in the past.) This test first confirms the gradients computed by reverse mode AD matches the analytic approximation of the true gradient. Then we confirm that the