From 76fa5bf6853020fb13503c52306a505876a5d9e3 Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 26 Jul 2024 14:54:17 -0400 Subject: [PATCH 01/14] Add map_clebsch_coords --- desc/equilibrium/coords.py | 155 ++++++++++++++++++++++++++++++++----- 1 file changed, 136 insertions(+), 19 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 33d2bfd497..34d7c5846b 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -11,7 +11,7 @@ from desc.compute import data_index, get_data_deps, get_profiles, get_transforms from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid from desc.transform import Transform -from desc.utils import setdefault +from desc.utils import parse_argname_change, setdefault def map_coordinates( # noqa: C901 @@ -162,6 +162,10 @@ def fixup(y, *args): vecroot = jit( vmap(lambda x0, *p: root(residual, x0, jac=jac, args=p, fixup=fixup, **kwargs)) ) + # See description here + # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532 + # except we make sure properly mod the function on which the root finding is + # done to handle periodic coordinates. yk, (res, niter) = vecroot(yk, coords) out = compute(yk, outbasis) @@ -231,24 +235,27 @@ def _distance_body(i, idx): return yg[idx] -def compute_theta_coords( - eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs +def compute_theta_coords( # TODO: change name to map_sfl_coords. + eq, sfl_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs ): - """Find theta_DESC for given straight field line theta_PEST. + """Find θ (theta_DESC) for given straight field line ϑ (theta_PEST). + + Assumes ζ = ϕ. Parameters ---------- eq : Equilibrium - Equilibrium to use - flux_coords : ndarray, shape(k,3) - 2d array of flux coordinates [rho,theta*,zeta]. Each row is a different - point in space. + Equilibrium to use. + sfl_coords : ndarray + Shape (k, 3). + Straight field line PEST coordinates [ρ, ϑ, ϕ]. + Each row is a different point in space. L_lmn : ndarray - spectral coefficients for lambda. Defaults to eq.L_lmn + Spectral coefficients for lambda. Defaults to those of the equilibrium. tol : float Stopping tolerance. maxiter : int > 0 - maximum number of Newton iterations + Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from the root finding and the second is the number of iterations. @@ -258,18 +265,22 @@ def compute_theta_coords( Returns ------- - coords : ndarray, shape(k,3) - coordinates [rho,theta,zeta]. + coords : ndarray + Shape (k, 3). + DESC computational coordinates [ρ, θ, ζ]. info : tuple - 2 element tuple containing residuals and number of iterations - for each point. Only returned if ``full_output`` is True + 2 element tuple containing residuals and number of iterations for each point. + Only returned if ``full_output`` is True. """ + parse_argname_change( + sfl_coords, kwargs, oldname="flux_coords", newname="sfl_coords" + ) kwargs.setdefault("maxiter", maxiter) kwargs.setdefault("tol", tol) if L_lmn is None: L_lmn = eq.L_lmn - rho, theta_star, zeta = flux_coords.T + rho, theta_PEST, zeta = sfl_coords.T def rootfun(theta_DESC, theta_PEST, rho, zeta): nodes = jnp.atleast_2d( @@ -277,14 +288,15 @@ def rootfun(theta_DESC, theta_PEST, rho, zeta): ) A = eq.L_basis.evaluate(nodes) lmbda = A @ L_lmn - theta_PESTk = theta_DESC + lmbda - r = (theta_PESTk % (2 * np.pi)) - (theta_PEST % (2 * np.pi)) - # r should be between -pi and pi + theta_PEST_k = (theta_DESC + lmbda) % (2 * np.pi) + r = theta_PEST_k - theta_PEST + # r should be between -pi and pi to minimize |r| r = jnp.where(r > np.pi, r - 2 * np.pi, r) r = jnp.where(r < -np.pi, r + 2 * np.pi, r) return r.squeeze() def jacfun(theta_DESC, theta_PEST, rho, zeta): + # Valid everywhere except θ such that θ+λ = k 2π where k ∈ ℤ. nodes = jnp.atleast_2d( jnp.array([rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()]) ) @@ -302,7 +314,7 @@ def fixup(x, *args): ) ) ) - theta_DESC, (res, niter) = vecroot(theta_star, theta_star, rho, zeta) + theta_DESC, (res, niter) = vecroot(theta_PEST, theta_PEST % (2 * np.pi), rho, zeta) nodes = jnp.array([rho, jnp.atleast_1d(theta_DESC.squeeze()), zeta]).T @@ -312,6 +324,111 @@ def fixup(x, *args): return out +def map_clebsch_coords( + eq, + clebsch_coords, + iota, + L_lmn=None, + tol=1e-6, + maxiter=20, + full_output=False, + **kwargs, +): + """Find θ for given Clebsch field line poloidal label α. + + Assumes ζ = ϕ. + + Parameters + ---------- + eq : Equilibrium + Equilibrium to use. + clebsch_coords : ndarray + Shape (k, 3). + Clebsch field line coordinates [ρ, α, ζ]. + Each row is a different point in space. + iota : ndarray + Shape (k, ) + Rotational transform on each node. + L_lmn : ndarray + Spectral coefficients for lambda. Defaults to those of the equilibrium. + tol : float + Stopping tolerance. + maxiter : int > 0 + Maximum number of Newton iterations. + full_output : bool, optional + If True, also return a tuple where the first element is the residual from + the root finding and the second is the number of iterations. + kwargs : dict, optional + Additional keyword arguments to pass to ``root_scalar`` such as ``maxiter_ls``, + ``alpha``. + + Returns + ------- + coords : ndarray + Shape (k, 3). + DESC computational coordinates [ρ, θ, ζ]. + info : tuple + 2 element tuple containing residuals and number of iterations for each point. + Only returned if ``full_output`` is True. + """ + kwargs.setdefault("maxiter", maxiter) + kwargs.setdefault("tol", tol) + + kwargs.setdefault("maxiter", maxiter) + kwargs.setdefault("tol", tol) + + if L_lmn is None: + L_lmn = eq.L_lmn + rho, alpha, zeta = clebsch_coords.T + alpha = alpha % (2 * np.pi) + + # Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0. + def rootfun(theta, alpha, rho, zeta, iota): + nodes = jnp.atleast_2d( + jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()]) + ) + A = eq.L_basis.evaluate(nodes) + lmbda = A @ L_lmn + # TODO: generalize for toroidal angle + alpha_k = ((theta + lmbda) % (2 * np.pi) - iota * zeta) % (2 * np.pi) + r = alpha_k - alpha + # r should be between -pi and pi to minimize |r| + r = jnp.where(r > np.pi, r - 2 * np.pi, r) + r = jnp.where(r < -np.pi, r + 2 * np.pi, r) + return r.squeeze() + + def jacfun(theta, alpha, rho, zeta, iota): + # Valid everywhere except θ such that θ+λ = k 2π where k ∈ ℤ. + nodes = jnp.atleast_2d( + jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()]) + ) + A1 = eq.L_basis.evaluate(nodes, (0, 1, 0)) + lmbda_t = jnp.dot(A1, L_lmn) + return 1 + lmbda_t.squeeze() + + def fixup(x, *args): + return x % (2 * np.pi) + + vecroot = jit( + vmap( + lambda x0, *p: root_scalar( + rootfun, x0, jac=jacfun, args=p, fixup=fixup, **kwargs + ) + ) + ) + # Assume λ is small for initial guess. + theta, (res, niter) = vecroot( + (alpha + iota * zeta) % (2 * np.pi), alpha, rho, zeta, iota + ) + + nodes = jnp.array([rho, jnp.atleast_1d(theta.squeeze()), zeta]).T + + out = nodes + if full_output: + return out, (res, niter) + return out + + def is_nested(eq, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): """Check that an equilibrium has properly nested flux surfaces in a plane. From e5609e7518bd110bff48d9cf48d7b8bb3efed8d6 Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 26 Jul 2024 19:31:41 -0400 Subject: [PATCH 02/14] Add failing test --- desc/equilibrium/coords.py | 46 +++++++--------- desc/equilibrium/equilibrium.py | 96 ++++++++++++++++++++++++++++----- tests/test_equilibrium.py | 6 +++ 3 files changed, 109 insertions(+), 39 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 34d7c5846b..f7f97579ef 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -11,7 +11,7 @@ from desc.compute import data_index, get_data_deps, get_profiles, get_transforms from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid from desc.transform import Transform -from desc.utils import parse_argname_change, setdefault +from desc.utils import setdefault def map_coordinates( # noqa: C901 @@ -44,20 +44,20 @@ def map_coordinates( # noqa: C901 inbasis, outbasis : tuple of str Labels for input and output coordinates, eg ("R", "phi", "Z") or ("rho", "alpha", "zeta") or any combination thereof. Labels should be the - same as the compute function data key + same as the compute function data key. guess : None or ndarray, shape(k,3) Initial guess for the computational coordinates ['rho', 'theta', 'zeta'] corresponding to coords in inbasis. If None, heuristics are used based on inbasis or a nearest neighbor search on a grid. params : dict - Values of equilibrium parameters to use, eg eq.params_dict + Values of equilibrium parameters to use, e.g. ``eq.params_dict``. period : tuple of float Assumed periodicity for each quantity in inbasis. Use np.inf to denote no periodicity. tol : float Stopping tolerance. maxiter : int > 0 - Maximum number of Newton iterations + Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from the root finding and the second is the number of iterations. @@ -164,8 +164,7 @@ def fixup(y, *args): ) # See description here # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532 - # except we make sure properly mod the function on which the root finding is - # done to handle periodic coordinates. + # except we make sure properly handle periodic coordinates. yk, (res, niter) = vecroot(yk, coords) out = compute(yk, outbasis) @@ -235,20 +234,18 @@ def _distance_body(i, idx): return yg[idx] -def compute_theta_coords( # TODO: change name to map_sfl_coords. - eq, sfl_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs +def compute_theta_coords( + eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs ): """Find θ (theta_DESC) for given straight field line ϑ (theta_PEST). - Assumes ζ = ϕ. - Parameters ---------- eq : Equilibrium Equilibrium to use. - sfl_coords : ndarray + flux_coords : ndarray Shape (k, 3). - Straight field line PEST coordinates [ρ, ϑ, ϕ]. + Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. Each row is a different point in space. L_lmn : ndarray Spectral coefficients for lambda. Defaults to those of the equilibrium. @@ -271,17 +268,17 @@ def compute_theta_coords( # TODO: change name to map_sfl_coords. info : tuple 2 element tuple containing residuals and number of iterations for each point. Only returned if ``full_output`` is True. + """ - parse_argname_change( - sfl_coords, kwargs, oldname="flux_coords", newname="sfl_coords" - ) kwargs.setdefault("maxiter", maxiter) kwargs.setdefault("tol", tol) if L_lmn is None: L_lmn = eq.L_lmn - rho, theta_PEST, zeta = sfl_coords.T + rho, theta_PEST, zeta = flux_coords.T + theta_PEST = theta_PEST % (2 * np.pi) + # Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0. def rootfun(theta_DESC, theta_PEST, rho, zeta): nodes = jnp.atleast_2d( jnp.array([rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()]) @@ -290,7 +287,7 @@ def rootfun(theta_DESC, theta_PEST, rho, zeta): lmbda = A @ L_lmn theta_PEST_k = (theta_DESC + lmbda) % (2 * np.pi) r = theta_PEST_k - theta_PEST - # r should be between -pi and pi to minimize |r| + # r should be between -pi and pi r = jnp.where(r > np.pi, r - 2 * np.pi, r) r = jnp.where(r < -np.pi, r + 2 * np.pi, r) return r.squeeze() @@ -314,7 +311,8 @@ def fixup(x, *args): ) ) ) - theta_DESC, (res, niter) = vecroot(theta_PEST, theta_PEST % (2 * np.pi), rho, zeta) + # Assume λ=0 for initial guess. + theta_DESC, (res, niter) = vecroot(theta_PEST, theta_PEST, rho, zeta) nodes = jnp.array([rho, jnp.atleast_1d(theta_DESC.squeeze()), zeta]).T @@ -336,15 +334,13 @@ def map_clebsch_coords( ): """Find θ for given Clebsch field line poloidal label α. - Assumes ζ = ϕ. - Parameters ---------- eq : Equilibrium Equilibrium to use. clebsch_coords : ndarray Shape (k, 3). - Clebsch field line coordinates [ρ, α, ζ]. + Clebsch field line coordinates [ρ, α, ζ]. Assumes ζ = ϕ. Each row is a different point in space. iota : ndarray Shape (k, ) @@ -370,10 +366,8 @@ def map_clebsch_coords( info : tuple 2 element tuple containing residuals and number of iterations for each point. Only returned if ``full_output`` is True. - """ - kwargs.setdefault("maxiter", maxiter) - kwargs.setdefault("tol", tol) + """ kwargs.setdefault("maxiter", maxiter) kwargs.setdefault("tol", tol) @@ -392,7 +386,7 @@ def rootfun(theta, alpha, rho, zeta, iota): # TODO: generalize for toroidal angle alpha_k = ((theta + lmbda) % (2 * np.pi) - iota * zeta) % (2 * np.pi) r = alpha_k - alpha - # r should be between -pi and pi to minimize |r| + # r should be between -pi and pi r = jnp.where(r > np.pi, r - 2 * np.pi, r) r = jnp.where(r < -np.pi, r + 2 * np.pi, r) return r.squeeze() @@ -416,7 +410,7 @@ def fixup(x, *args): ) ) ) - # Assume λ is small for initial guess. + # Assume λ=0 for initial guess. theta, (res, niter) = vecroot( (alpha + iota * zeta) % (2 * np.pi), alpha, rho, zeta, iota ) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 22d5947d1a..bc103026b7 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -48,12 +48,19 @@ copy_coeffs, errorif, only1, + parse_argname_change, setdefault, warnif, ) from ..compute.data_index import is_0d_vol_grid, is_1dr_rad_grid, is_1dz_tor_grid -from .coords import compute_theta_coords, is_nested, map_coordinates, to_sfl +from .coords import ( + compute_theta_coords, + is_nested, + map_clebsch_coords, + map_coordinates, + to_sfl, +) from .initial_guess import set_initial_guess from .utils import parse_axis, parse_profile, parse_surface @@ -1219,21 +1226,20 @@ def map_coordinates( # noqa: C901 def compute_theta_coords( self, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs ): - """Find theta_DESC for given straight field line theta_PEST. + """Find θ (theta_DESC) for given straight field line ϑ (theta_PEST). Parameters ---------- - eq : Equilibrium - Equilibrium to use - flux_coords : ndarray, shape(k,3) - 2d array of flux coordinates [rho,theta*,zeta]. Each row is a different - point in space. + flux_coords : ndarray + Shape (k, 3). + Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. + Each row is a different point in space. L_lmn : ndarray - spectral coefficients for lambda. Defaults to eq.L_lmn + Spectral coefficients for lambda. Defaults to those of the equilibrium. tol : float Stopping tolerance. maxiter : int > 0 - maximum number of Newton iterations + Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from the root finding and the second is the number of iterations. @@ -1243,12 +1249,17 @@ def compute_theta_coords( Returns ------- - coords : ndarray, shape(k,3) - coordinates [rho,theta,zeta]. + coords : ndarray + Shape (k, 3). + DESC computational coordinates [ρ, θ, ζ]. info : tuple - 2 element tuple containing residuals and number of iterations - for each point. Only returned if ``full_output`` is True + 2 element tuple containing residuals and number of iterations for each + point. Only returned if ``full_output`` is True. + """ + flux_coords = parse_argname_change( + flux_coords, kwargs, oldname="flux_coords", newname="sfl_coords" + ) return compute_theta_coords( self, flux_coords, @@ -1259,6 +1270,65 @@ def compute_theta_coords( **kwargs, ) + def map_clebsch_coords( + self, + clebsch_coords, + iota, + L_lmn=None, + tol=1e-6, + maxiter=20, + full_output=False, + **kwargs, + ): + """Find θ for given Clebsch field line poloidal label α. + + Assumes ζ = ϕ. + + Parameters + ---------- + eq : Equilibrium + Equilibrium to use. + clebsch_coords : ndarray + Shape (k, 3). + Clebsch field line coordinates [ρ, α, ζ]. + Each row is a different point in space. + iota : ndarray + Shape (k, ) + Rotational transform on each node. + L_lmn : ndarray + Spectral coefficients for lambda. Defaults to those of the equilibrium. + tol : float + Stopping tolerance. + maxiter : int > 0 + Maximum number of Newton iterations. + full_output : bool, optional + If True, also return a tuple where the first element is the residual from + the root finding and the second is the number of iterations. + kwargs : dict, optional + Additional keyword arguments to pass to ``root_scalar`` such as + ``maxiter_ls``, ``alpha``. + + Returns + ------- + coords : ndarray + Shape (k, 3). + DESC computational coordinates [ρ, θ, ζ]. + info : tuple + 2 element tuple containing residuals and number of iterations for each + point. Only returned if ``full_output`` is True. + + """ + return map_clebsch_coords( + self, + clebsch_coords, + iota, + L_lmn=L_lmn, + maxiter=maxiter, + tol=tol, + full_output=full_output, + **kwargs, + ) + @execute_on_cpu def is_nested(self, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): """Check that an equilibrium has properly nested flux surfaces in a plane. diff --git a/tests/test_equilibrium.py b/tests/test_equilibrium.py index 0b1af2572e..bd084d7ac9 100644 --- a/tests/test_equilibrium.py +++ b/tests/test_equilibrium.py @@ -49,6 +49,9 @@ def test_map_coordinates(): """Test root finding for (rho,theta,zeta) for common use cases.""" # finding coordinates along a single field line eq = get("NCSX") + grid = LinearGrid(rho=1, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) + iota = grid.compress(eq.compute("iota", grid=grid)["iota"]) + with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(3, 3, 3, 6, 6, 6) n = 100 @@ -61,6 +64,9 @@ def test_map_coordinates(): ) assert not np.any(np.isnan(out)) + iota = np.broadcast_to(iota, shape=(n,)) + np.testing.assert_allclose(eq.map_clebsch_coords(coords, iota), out) + eq = get("DSHAPE") inbasis = ["R", "phi", "Z"] From f1644a90222983e699403289aaa6f44cd1a276f3 Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 26 Jul 2024 19:35:16 -0400 Subject: [PATCH 03/14] Remove old code --- desc/equilibrium/equilibrium.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index bc103026b7..d4f0b14ecb 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -48,7 +48,6 @@ copy_coeffs, errorif, only1, - parse_argname_change, setdefault, warnif, ) @@ -1257,9 +1256,6 @@ def compute_theta_coords( point. Only returned if ``full_output`` is True. """ - flux_coords = parse_argname_change( - flux_coords, kwargs, oldname="flux_coords", newname="sfl_coords" - ) return compute_theta_coords( self, flux_coords, From 97400214da1731910da79387aa1955822aab4080 Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 26 Jul 2024 22:35:06 -0400 Subject: [PATCH 04/14] Eq.change_resolution causes 3d coordinate mapping to give bad results New Clebsch map which does 1d newton iteration over surface is invariant to radial resolution. So remove eq.change_resolution from test so that two methods converge to correct result --- desc/equilibrium/coords.py | 16 ++++++---------- desc/equilibrium/equilibrium.py | 4 +--- tests/test_equilibrium.py | 20 ++++++++++---------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index f7f97579ef..862edd2157 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -280,8 +280,8 @@ def compute_theta_coords( # Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0. def rootfun(theta_DESC, theta_PEST, rho, zeta): - nodes = jnp.atleast_2d( - jnp.array([rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()]) + nodes = jnp.array( + [rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2 ) A = eq.L_basis.evaluate(nodes) lmbda = A @ L_lmn @@ -294,8 +294,8 @@ def rootfun(theta_DESC, theta_PEST, rho, zeta): def jacfun(theta_DESC, theta_PEST, rho, zeta): # Valid everywhere except θ such that θ+λ = k 2π where k ∈ ℤ. - nodes = jnp.atleast_2d( - jnp.array([rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()]) + nodes = jnp.array( + [rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2 ) A1 = eq.L_basis.evaluate(nodes, (0, 1, 0)) lmbda_t = jnp.dot(A1, L_lmn) @@ -378,9 +378,7 @@ def map_clebsch_coords( # Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0. def rootfun(theta, alpha, rho, zeta, iota): - nodes = jnp.atleast_2d( - jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()]) - ) + nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2) A = eq.L_basis.evaluate(nodes) lmbda = A @ L_lmn # TODO: generalize for toroidal angle @@ -393,9 +391,7 @@ def rootfun(theta, alpha, rho, zeta, iota): def jacfun(theta, alpha, rho, zeta, iota): # Valid everywhere except θ such that θ+λ = k 2π where k ∈ ℤ. - nodes = jnp.atleast_2d( - jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()]) - ) + nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2) A1 = eq.L_basis.evaluate(nodes, (0, 1, 0)) lmbda_t = jnp.dot(A1, L_lmn) return 1 + lmbda_t.squeeze() diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index d4f0b14ecb..0e3213d522 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1278,15 +1278,13 @@ def map_clebsch_coords( ): """Find θ for given Clebsch field line poloidal label α. - Assumes ζ = ϕ. - Parameters ---------- eq : Equilibrium Equilibrium to use. clebsch_coords : ndarray Shape (k, 3). - Clebsch field line coordinates [ρ, α, ζ]. + Clebsch field line coordinates [ρ, α, ζ]. Assumes ζ = ϕ. Each row is a different point in space. iota : ndarray Shape (k, ) diff --git a/tests/test_equilibrium.py b/tests/test_equilibrium.py index bd084d7ac9..b67c723038 100644 --- a/tests/test_equilibrium.py +++ b/tests/test_equilibrium.py @@ -49,23 +49,23 @@ def test_map_coordinates(): """Test root finding for (rho,theta,zeta) for common use cases.""" # finding coordinates along a single field line eq = get("NCSX") + n = 100 + coords = np.array([np.ones(n), np.zeros(n), np.linspace(0, 10 * np.pi, n)]).T grid = LinearGrid(rho=1, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) iota = grid.compress(eq.compute("iota", grid=grid)["iota"]) + iota = np.broadcast_to(iota, shape=(n,)) - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(3, 3, 3, 6, 6, 6) - n = 100 - coords = np.array([np.ones(n), np.zeros(n), np.linspace(0, 10 * np.pi, n)]).T - out = eq.map_coordinates( + out_1 = eq.map_clebsch_coords(coords, iota) + assert np.isfinite(out_1).all() + out_2 = eq.map_coordinates( coords, ["rho", "alpha", "zeta"], ["rho", "theta", "zeta"], - period=(np.inf, 2 * np.pi, 10 * np.pi), + period=(np.inf, 2 * np.pi, np.inf), ) - assert not np.any(np.isnan(out)) - - iota = np.broadcast_to(iota, shape=(n,)) - np.testing.assert_allclose(eq.map_clebsch_coords(coords, iota), out) + assert np.isfinite(out_2).all() + diff = (out_1 - out_2) % (2 * np.pi) + assert np.all(np.isclose(diff, 0) | np.isclose(diff, 2 * np.pi)) eq = get("DSHAPE") From 03df62f4fbcf643b7bb471951c7ad68bff82e0aa Mon Sep 17 00:00:00 2001 From: unalmis Date: Sat, 27 Jul 2024 10:47:27 -0400 Subject: [PATCH 05/14] Remove unneeded modulo --- desc/equilibrium/coords.py | 6 +++--- desc/equilibrium/equilibrium.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 862edd2157..a5ffdbfdc6 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -248,7 +248,7 @@ def compute_theta_coords( Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. Each row is a different point in space. L_lmn : ndarray - Spectral coefficients for lambda. Defaults to those of the equilibrium. + Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. tol : float Stopping tolerance. maxiter : int > 0 @@ -346,7 +346,7 @@ def map_clebsch_coords( Shape (k, ) Rotational transform on each node. L_lmn : ndarray - Spectral coefficients for lambda. Defaults to those of the equilibrium. + Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. tol : float Stopping tolerance. maxiter : int > 0 @@ -382,7 +382,7 @@ def rootfun(theta, alpha, rho, zeta, iota): A = eq.L_basis.evaluate(nodes) lmbda = A @ L_lmn # TODO: generalize for toroidal angle - alpha_k = ((theta + lmbda) % (2 * np.pi) - iota * zeta) % (2 * np.pi) + alpha_k = (theta + lmbda - iota * zeta) % (2 * np.pi) r = alpha_k - alpha # r should be between -pi and pi r = jnp.where(r > np.pi, r - 2 * np.pi, r) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 0e3213d522..5841692f89 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1234,7 +1234,7 @@ def compute_theta_coords( Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. Each row is a different point in space. L_lmn : ndarray - Spectral coefficients for lambda. Defaults to those of the equilibrium. + Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. tol : float Stopping tolerance. maxiter : int > 0 @@ -1290,7 +1290,7 @@ def map_clebsch_coords( Shape (k, ) Rotational transform on each node. L_lmn : ndarray - Spectral coefficients for lambda. Defaults to those of the equilibrium. + Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. tol : float Stopping tolerance. maxiter : int > 0 From ca62476a49fcc28f79a001393cb80066bda5c4c6 Mon Sep 17 00:00:00 2001 From: unalmis Date: Mon, 5 Aug 2024 23:49:26 -0400 Subject: [PATCH 06/14] Detach map clebsch coordinates api from equilibrium object --- desc/equilibrium/coords.py | 14 +++++++------- desc/equilibrium/equilibrium.py | 11 +++++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index a5ffdbfdc6..5ca8611cae 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -323,10 +323,10 @@ def fixup(x, *args): def map_clebsch_coords( - eq, clebsch_coords, iota, - L_lmn=None, + L_lmn, + L_basis, tol=1e-6, maxiter=20, full_output=False, @@ -346,7 +346,9 @@ def map_clebsch_coords( Shape (k, ) Rotational transform on each node. L_lmn : ndarray - Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. + Spectral coefficients for lambda. + L_basis : Basis + Spectral basis for lambda. tol : float Stopping tolerance. maxiter : int > 0 @@ -371,15 +373,13 @@ def map_clebsch_coords( kwargs.setdefault("maxiter", maxiter) kwargs.setdefault("tol", tol) - if L_lmn is None: - L_lmn = eq.L_lmn rho, alpha, zeta = clebsch_coords.T alpha = alpha % (2 * np.pi) # Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0. def rootfun(theta, alpha, rho, zeta, iota): nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2) - A = eq.L_basis.evaluate(nodes) + A = L_basis.evaluate(nodes) lmbda = A @ L_lmn # TODO: generalize for toroidal angle alpha_k = (theta + lmbda - iota * zeta) % (2 * np.pi) @@ -392,7 +392,7 @@ def rootfun(theta, alpha, rho, zeta, iota): def jacfun(theta, alpha, rho, zeta, iota): # Valid everywhere except θ such that θ+λ = k 2π where k ∈ ℤ. nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2) - A1 = eq.L_basis.evaluate(nodes, (0, 1, 0)) + A1 = L_basis.evaluate(nodes, (0, 1, 0)) lmbda_t = jnp.dot(A1, L_lmn) return 1 + lmbda_t.squeeze() diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 5841692f89..758ade21e5 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1260,8 +1260,8 @@ def compute_theta_coords( self, flux_coords, L_lmn=L_lmn, - maxiter=maxiter, tol=tol, + maxiter=maxiter, full_output=full_output, **kwargs, ) @@ -1271,6 +1271,7 @@ def map_clebsch_coords( clebsch_coords, iota, L_lmn=None, + L_basis=None, tol=1e-6, maxiter=20, full_output=False, @@ -1291,6 +1292,8 @@ def map_clebsch_coords( Rotational transform on each node. L_lmn : ndarray Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. + L_basis : Basis + Spectral basis for lambda. Defaults to ``eq.L_basis``. tol : float Stopping tolerance. maxiter : int > 0 @@ -1313,12 +1316,12 @@ def map_clebsch_coords( """ return map_clebsch_coords( - self, clebsch_coords, iota, - L_lmn=L_lmn, - maxiter=maxiter, + L_lmn=setdefault(L_lmn, self.L_lmn), + L_basis=setdefault(L_basis, self.L_basis), tol=tol, + maxiter=maxiter, full_output=full_output, **kwargs, ) From 616099d840157f263385dae03196d93c487ce5e2 Mon Sep 17 00:00:00 2001 From: unalmis Date: Mon, 5 Aug 2024 23:54:29 -0400 Subject: [PATCH 07/14] Fix definition of alpha for general toroidal angle --- desc/compute/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/desc/compute/_core.py b/desc/compute/_core.py index a157d91412..70f70b73a2 100644 --- a/desc/compute/_core.py +++ b/desc/compute/_core.py @@ -1475,10 +1475,10 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["theta_PEST", "zeta", "iota"], + data=["theta_PEST", "phi", "iota"], ) def _alpha(params, transforms, profiles, data, **kwargs): - data["alpha"] = (data["theta_PEST"] - data["iota"] * data["zeta"]) % (2 * jnp.pi) + data["alpha"] = (data["theta_PEST"] - data["iota"] * data["phi"]) % (2 * jnp.pi) return data From 350dc92a78e4eff86b435516f6f5410c0afd36e8 Mon Sep 17 00:00:00 2001 From: unalmis Date: Thu, 8 Aug 2024 18:58:24 -0400 Subject: [PATCH 08/14] Compactify coordinate mapping API --- desc/compute/utils.py | 34 ++-- desc/equilibrium/coords.py | 287 ++++++++++++++++++++------------ desc/equilibrium/equilibrium.py | 140 ++++------------ desc/utils.py | 4 +- desc/vmec.py | 2 +- tests/test_equilibrium.py | 17 +- tests/test_plotting.py | 2 +- tests/utils.py | 2 +- 8 files changed, 247 insertions(+), 241 deletions(-) diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 4567637be3..92c41a000f 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -4,7 +4,6 @@ import inspect import numpy as np -from termcolor import colored from desc.backend import cond, execute_on_cpu, fori_loop, jnp, put from desc.grid import ConcentricGrid, Grid, LinearGrid @@ -43,7 +42,7 @@ def compute(parameterization, names, params, transforms, profiles, data=None, ** Type of object to compute for, eg Equilibrium, Curve, etc. names : str or array-like of str Name(s) of the quantity(s) to compute. - params : dict of ndarray + params : dict[str, jnp.ndarray] Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc. Defaults to attributes of self. transforms : dict of Transform @@ -51,7 +50,7 @@ def compute(parameterization, names, params, transforms, profiles, data=None, ** profiles : dict of Profile Profile objects for pressure, iota, current, etc. Defaults to attributes of self - data : dict of ndarray + data : dict[str, jnp.ndarray] Data computed so far, generally output from other compute functions. Any vector v = v¹ R̂ + v² ϕ̂ + v³ Ẑ should be given in components v = [v¹, v², v³] where R̂, ϕ̂, Ẑ are the normalized basis vectors @@ -212,7 +211,7 @@ def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None): Whether the grid to compute on has a node on the magnetic axis. basis : {"rpz", "xyz"} Basis of computed quantities. - data : dict of ndarray + data : dict[str, jnp.ndarray] Data computed so far, generally output from other compute functions Returns @@ -287,7 +286,7 @@ def _get_deps(parameterization, names, deps, data=None, has_axis=False, check_fu Name(s) of the quantity(s) to compute. deps : set[str] Dependencies gathered so far. - data : dict of ndarray or None + data : dict[str, jnp.ndarray] Data computed so far, generally output from other compute functions. has_axis : bool Whether the grid to compute on has a node on the magnetic axis. @@ -375,7 +374,7 @@ def get_derivs(keys, obj, has_axis=False, basis="rpz"): Returns ------- - derivs : dict of list of int + derivs : dict[list, str] Orders of derivatives needed to compute key. Keys for R, Z, L, etc @@ -465,7 +464,7 @@ def get_params(keys, obj, has_axis=False, basis="rpz"): Returns ------- - params : list of str or dict of ndarray + params : list[str] or dict[str, jnp.ndarray] Parameters needed to compute key. If eq is None, returns a list of the names of params needed otherwise, returns a dict of ndarray with keys for R_lmn, Z_lmn, etc. @@ -624,13 +623,13 @@ def has_dependencies(parameterization, qty, params, transforms, profiles, data): Type of thing we're checking dependencies for. eg desc.equilibrium.Equilibrium qty : str Name of something from the data index. - params : dict of ndarray + params : dict[str, jnp.ndarray] Dictionary of parameters we have. - transforms : dict of Transform + transforms : dict[str, Transform] Dictionary of transforms we have. - profiles : dict of Profile + profiles : dict[str, Profile] Dictionary of profiles we have. - data : dict of ndarray + data : dict[str, jnp.ndarray] Dictionary of what we've computed so far. Returns @@ -988,8 +987,10 @@ def line_integrals( line_label != "poloidal" and isinstance(grid, ConcentricGrid), msg="ConcentricGrid should only be used for poloidal line integrals.", ) - msg = colored("Correctness not guaranteed on grids with duplicate nodes.", "yellow") - warnif(isinstance(grid, LinearGrid) and grid.endpoint, msg=msg) + warnif( + isinstance(grid, LinearGrid) and grid.endpoint, + msg="Correctness not guaranteed on grids with duplicate nodes.", + ) # Generate a new quantity q_prime which is zero everywhere # except on the fixed surface, on which q_prime takes the value of q. # Then forward the computation to surface_integrals(). @@ -1075,11 +1076,8 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14) surface_label = grid.get_label(surface_label) warnif( surface_label == "poloidal" and isinstance(grid, ConcentricGrid), - msg=colored( - "Integrals over constant poloidal surfaces" - " are poorly defined for ConcentricGrid.", - "yellow", - ), + msg="Integrals over constant poloidal surfaces" + " are poorly defined for ConcentricGrid.", ) unique_size, inverse_idx, spacing, has_endpoint_dupe, has_idx = _get_grid_surface( grid, surface_label diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 5ca8611cae..15f69ce843 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -1,17 +1,19 @@ """Functions for mapping between flux, sfl, and real space coordinates.""" import functools -import warnings import numpy as np -from termcolor import colored from desc.backend import fori_loop, jit, jnp, put, root, root_scalar, vmap from desc.compute import compute as compute_fun from desc.compute import data_index, get_data_deps, get_profiles, get_transforms from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid from desc.transform import Transform -from desc.utils import setdefault +from desc.utils import check_posint, errorif, setdefault, warnif + + +def _periodic(x, period): + return jnp.where(jnp.isfinite(period), x % period, x) def map_coordinates( # noqa: C901 @@ -21,7 +23,7 @@ def map_coordinates( # noqa: C901 outbasis=("rho", "theta", "zeta"), guess=None, params=None, - period=(np.inf, np.inf, np.inf), + period=None, tol=1e-6, maxiter=30, full_output=False, @@ -38,7 +40,7 @@ def map_coordinates( # noqa: C901 Parameters ---------- eq : Equilibrium - Equilibrium to use + Equilibrium to use. coords : ndarray, shape(k,3) 2D array of input coordinates. Each row is a different point in space. inbasis, outbasis : tuple of str @@ -49,6 +51,7 @@ def map_coordinates( # noqa: C901 Initial guess for the computational coordinates ['rho', 'theta', 'zeta'] corresponding to coords in inbasis. If None, heuristics are used based on inbasis or a nearest neighbor search on a grid. + In most cases, this must be given to be compatible with JIT. params : dict Values of equilibrium parameters to use, e.g. ``eq.params_dict``. period : tuple of float @@ -56,7 +59,7 @@ def map_coordinates( # noqa: C901 Use np.inf to denote no periodicity. tol : float Stopping tolerance. - maxiter : int > 0 + maxiter : int Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from @@ -67,7 +70,7 @@ def map_coordinates( # noqa: C901 Returns ------- - coords : ndarray, shape(k,3) + out : ndarray, shape(k,3) Coordinates mapped from inbasis to outbasis. Values of NaN will be returned for coordinates where root finding did not succeed, possibly because the coordinate is not in the plasma volume. @@ -75,36 +78,57 @@ def map_coordinates( # noqa: C901 2 element tuple containing residuals and number of iterations for each point. Only returned if ``full_output`` is True - Notes - ----- - ``guess`` must be given for this function to be compatible with ``jit``. - """ + check_posint(maxiter, allow_none=False) + errorif( + not np.isfinite(tol) or tol <= 0, + ValueError, + f"tol must be a positive float, got {tol}", + ) + params = setdefault(params, eq.params_dict) inbasis = tuple(inbasis) outbasis = tuple(outbasis) - assert ( - np.isfinite(maxiter) and maxiter > 0 - ), f"maxiter must be a positive integer, got {maxiter}" - assert np.isfinite(tol) and tol > 0, f"tol must be a positive float, got {tol}" - basis_derivs = tuple([f"{X}_{d}" for X in inbasis for d in ("r", "t", "z")]) + if outbasis == ("rho", "theta", "zeta"): + # TODO: get iota if not supplied using below logic + if inbasis == ("rho", "alpha", "zeta") and "iota" in kwargs: + return _map_clebsch_coordinates( + coords, + kwargs.pop("iota"), + params["L_lmn"], + eq.L_basis, + guess, + tol, + maxiter, + full_output, + **kwargs, + ) + if inbasis == ("rho", "theta_PEST", "zeta"): + return _map_PEST_coordinates( + coords, + params["L_lmn"], + eq.L_basis, + guess, + tol, + maxiter, + full_output, + **kwargs, + ) + + basis_derivs = tuple(f"{X}_{d}" for X in inbasis for d in ("r", "t", "z")) for key in basis_derivs: - assert ( - key in data_index["desc.equilibrium.equilibrium.Equilibrium"] - ), f"don't have recipe to compute partial derivative {key}" + errorif( + key not in data_index["desc.equilibrium.equilibrium.Equilibrium"], + NotImplementedError, + f"don't have recipe to compute partial derivative {key}", + ) rhomin = kwargs.pop("rhomin", tol / 10) - kwargs.setdefault("tol", tol) - kwargs.setdefault("maxiter", maxiter) - period = np.asarray(period) + warnif(period is None, msg="Assuming no periodicity.") + period = np.asarray(setdefault(period, (np.inf, np.inf, np.inf))) + coords = _periodic(coords, period) - def periodic(x): - return jnp.where(jnp.isfinite(period), x % period, x) - - coords = periodic(coords) - - params = setdefault(params, eq.params_dict) - profiles = get_profiles(inbasis + basis_derivs, eq, None) + profiles = get_profiles(inbasis + basis_derivs, eq) p = "desc.equilibrium.equilibrium.Equilibrium" names = inbasis + basis_derivs + outbasis deps = list(set(get_data_deps(names, obj=p) + list(names))) @@ -132,7 +156,7 @@ def compute(y, basis): @jit def residual(y, coords): xk = compute(y, inbasis) - r = periodic(xk) - periodic(coords) + r = _periodic(xk, period) - _periodic(coords, period) return jnp.where((r > period / 2) & jnp.isfinite(period), -period + r, r) @jit @@ -155,12 +179,23 @@ def fixup(y, *args): if yk is None: yk = _initial_guess_heuristic(yk, coords, inbasis, eq, profiles) if yk is None: - yk = _initial_guess_nn_search(yk, coords, inbasis, eq, period, compute) + yk = _initial_guess_nn_search(coords, inbasis, eq, period, compute) yk = fixup(yk) vecroot = jit( - vmap(lambda x0, *p: root(residual, x0, jac=jac, args=p, fixup=fixup, **kwargs)) + vmap( + lambda x0, *p: root( + residual, + x0, + jac=jac, + args=p, + fixup=fixup, + tol=tol, + maxiter=maxiter, + **kwargs, + ) + ) ) # See description here # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532 @@ -208,22 +243,19 @@ def _initial_guess_heuristic(yk, coords, inbasis, eq, profiles): iota = profiles["iota"](rho) theta = (alpha + iota * zeta) % (2 * jnp.pi) - yk = jnp.array([rho, theta, zeta]).T + yk = jnp.column_stack([rho, theta, zeta]) return yk -def _initial_guess_nn_search(yk, coords, inbasis, eq, period, compute): +def _initial_guess_nn_search(coords, inbasis, eq, period, compute): # nearest neighbor search on dense grid yg = ConcentricGrid(eq.L_grid, eq.M_grid, max(eq.N_grid, eq.M_grid)).nodes xg = compute(yg, inbasis) idx = jnp.zeros(len(coords)).astype(int) coords = jnp.asarray(coords) - def periodic(x): - return jnp.where(jnp.isfinite(period), x % period, x) - def _distance_body(i, idx): - d = periodic(coords[i]) - periodic(xg) + d = _periodic(coords[i], period) - _periodic(xg, period) d = jnp.where((d > period / 2) & jnp.isfinite(period), period - d, d) distance = jnp.linalg.norm(d, axis=-1) k = jnp.argmin(distance) @@ -234,24 +266,32 @@ def _distance_body(i, idx): return yg[idx] -def compute_theta_coords( - eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs +# TODO: decide later whether to assume given phi instead of zeta. +def _map_PEST_coordinates( + PEST_coords, + L_lmn, + L_basis, + guess, + tol=1e-6, + maxiter=30, + full_output=False, + **kwargs, ): """Find θ (theta_DESC) for given straight field line ϑ (theta_PEST). Parameters ---------- - eq : Equilibrium - Equilibrium to use. - flux_coords : ndarray + PEST_coords : ndarray Shape (k, 3). Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. Each row is a different point in space. - L_lmn : ndarray - Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. + L_lmn : jnp.ndarray + Spectral coefficients for lambda. + L_basis : Basis + Spectral basis for lambda. tol : float Stopping tolerance. - maxiter : int > 0 + maxiter : int Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from @@ -262,7 +302,7 @@ def compute_theta_coords( Returns ------- - coords : ndarray + out : ndarray Shape (k, 3). DESC computational coordinates [ρ, θ, ζ]. info : tuple @@ -270,20 +310,17 @@ def compute_theta_coords( Only returned if ``full_output`` is True. """ - kwargs.setdefault("maxiter", maxiter) - kwargs.setdefault("tol", tol) - - if L_lmn is None: - L_lmn = eq.L_lmn - rho, theta_PEST, zeta = flux_coords.T + rho, theta_PEST, zeta = PEST_coords.T theta_PEST = theta_PEST % (2 * np.pi) + # Assume λ=0 for initial guess. + guess = setdefault(guess, theta_PEST) # Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0. def rootfun(theta_DESC, theta_PEST, rho, zeta): nodes = jnp.array( [rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2 ) - A = eq.L_basis.evaluate(nodes) + A = L_basis.evaluate(nodes) lmbda = A @ L_lmn theta_PEST_k = (theta_DESC + lmbda) % (2 * np.pi) r = theta_PEST_k - theta_PEST @@ -297,7 +334,7 @@ def jacfun(theta_DESC, theta_PEST, rho, zeta): nodes = jnp.array( [rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2 ) - A1 = eq.L_basis.evaluate(nodes, (0, 1, 0)) + A1 = L_basis.evaluate(nodes, (0, 1, 0)) lmbda_t = jnp.dot(A1, L_lmn) return 1 + lmbda_t.squeeze() @@ -307,28 +344,36 @@ def fixup(x, *args): vecroot = jit( vmap( lambda x0, *p: root_scalar( - rootfun, x0, jac=jacfun, args=p, fixup=fixup, **kwargs + rootfun, + x0, + jac=jacfun, + args=p, + fixup=fixup, + tol=tol, + maxiter=maxiter, + **kwargs, ) ) ) # Assume λ=0 for initial guess. - theta_DESC, (res, niter) = vecroot(theta_PEST, theta_PEST, rho, zeta) + theta_DESC, (res, niter) = vecroot(guess, theta_PEST, rho, zeta) - nodes = jnp.array([rho, jnp.atleast_1d(theta_DESC.squeeze()), zeta]).T + out = jnp.column_stack([rho, jnp.atleast_1d(theta_DESC.squeeze()), zeta]) - out = nodes if full_output: return out, (res, niter) return out -def map_clebsch_coords( +# TODO: decide later whether to assume given phi instead of zeta. +def _map_clebsch_coordinates( clebsch_coords, iota, L_lmn, L_basis, + guess=None, tol=1e-6, - maxiter=20, + maxiter=30, full_output=False, **kwargs, ): @@ -336,8 +381,6 @@ def map_clebsch_coords( Parameters ---------- - eq : Equilibrium - Equilibrium to use. clebsch_coords : ndarray Shape (k, 3). Clebsch field line coordinates [ρ, α, ζ]. Assumes ζ = ϕ. @@ -345,13 +388,15 @@ def map_clebsch_coords( iota : ndarray Shape (k, ) Rotational transform on each node. - L_lmn : ndarray + L_lmn : jnp.ndarray Spectral coefficients for lambda. L_basis : Basis Spectral basis for lambda. + guess : ndarray, shape(k,3) + Optional initial guess for the computational coordinates. tol : float Stopping tolerance. - maxiter : int > 0 + maxiter : int Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from @@ -362,7 +407,7 @@ def map_clebsch_coords( Returns ------- - coords : ndarray + out : ndarray Shape (k, 3). DESC computational coordinates [ρ, θ, ζ]. info : tuple @@ -370,20 +415,18 @@ def map_clebsch_coords( Only returned if ``full_output`` is True. """ - kwargs.setdefault("maxiter", maxiter) - kwargs.setdefault("tol", tol) - rho, alpha, zeta = clebsch_coords.T - alpha = alpha % (2 * np.pi) + if guess is None: + # Assume λ=0 for initial guess. + guess = (alpha + iota * zeta) % (2 * np.pi) # Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0. def rootfun(theta, alpha, rho, zeta, iota): nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2) A = L_basis.evaluate(nodes) lmbda = A @ L_lmn - # TODO: generalize for toroidal angle - alpha_k = (theta + lmbda - iota * zeta) % (2 * np.pi) - r = alpha_k - alpha + alpha_k = theta + lmbda - iota * zeta + r = (alpha_k - alpha) % (2 * np.pi) # r should be between -pi and pi r = jnp.where(r > np.pi, r - 2 * np.pi, r) r = jnp.where(r < -np.pi, r + 2 * np.pi, r) @@ -402,18 +445,20 @@ def fixup(x, *args): vecroot = jit( vmap( lambda x0, *p: root_scalar( - rootfun, x0, jac=jacfun, args=p, fixup=fixup, **kwargs + rootfun, + x0, + jac=jacfun, + args=p, + fixup=fixup, + tol=tol, + maxiter=maxiter, + **kwargs, ) ) ) - # Assume λ=0 for initial guess. - theta, (res, niter) = vecroot( - (alpha + iota * zeta) % (2 * np.pi), alpha, rho, zeta, iota - ) + theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota) + out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta]) - nodes = jnp.array([rho, jnp.atleast_1d(theta.squeeze()), zeta]).T - - out = nodes if full_output: return out, (res, niter) return out @@ -461,11 +506,7 @@ def is_nested(eq, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): data = compute_fun( "desc.equilibrium.equilibrium.Equilibrium", "sqrt(g)_PEST", - params={ - "R_lmn": R_lmn, - "Z_lmn": Z_lmn, - "L_lmn": L_lmn, - }, + params={"R_lmn": R_lmn, "Z_lmn": Z_lmn, "L_lmn": L_lmn}, transforms=transforms, profiles={}, # no profiles needed ) @@ -473,25 +514,20 @@ def is_nested(eq, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): nested = jnp.all( jnp.sign(data["sqrt(g)_PEST"][0]) == jnp.sign(data["sqrt(g)_PEST"]) ) - if not nested: - if msg == "auto": - warnings.warn( - colored( - "WARNING: Flux surfaces are no longer nested, exiting early. " - + "Automatic continuation method failed, consider specifying " - + "continuation steps manually", - "yellow", - ) - ) - elif msg == "manual": - warnings.warn( - colored( - "WARNING: Flux surfaces are no longer nested, exiting early." - + "Consider taking smaller perturbation/resolution steps " - + "or reducing trust radius", - "yellow", - ) - ) + warnif( + not nested, + RuntimeWarning, + ( + "Flux surfaces are no longer nested, exiting early. " + + { + "auto": "Automatic continuation method failed, consider specifying " + "continuation steps manually.", + "manual": "Consider taking smaller perturbation/resolution steps " + "or reducing trust radius.", + None: "", + }[msg] + ), + ) return nested @@ -672,3 +708,44 @@ def get_rtz_grid(eq, radial, poloidal, toroidal, coordinates, period, jitable=Tr _inverse_rho_idx=grid.inverse_rho_idx, ) return desc_grid + + +def compute_theta_coords( + eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs +): + """Find θ (theta_DESC) for given straight field line ϑ (theta_PEST). + + Parameters + ---------- + eq : Equilibrium + Equilibrium to use. + flux_coords : ndarray + Shape (k, 3). + Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. + Each row is a different point in space. + L_lmn : ndarray + Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. + tol : float + Stopping tolerance. + maxiter : int + Maximum number of Newton iterations. + full_output : bool, optional + If True, also return a tuple where the first element is the residual from + the root finding and the second is the number of iterations. + kwargs : dict, optional + Additional keyword arguments to pass to ``root_scalar`` such as + ``maxiter_ls``, ``alpha``. + + Returns + ------- + coords : ndarray + Shape (k, 3). + DESC computational coordinates [ρ, θ, ζ]. + info : tuple + 2 element tuple containing residuals and number of iterations for each + point. Only returned if ``full_output`` is True. + + """ + return eq.compute_theta_coords( + flux_coords, L_lmn, tol, maxiter, full_output, **kwargs + ) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 758ade21e5..ad5318fa04 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -2,13 +2,11 @@ import copy import numbers -import warnings from collections.abc import MutableSequence import numpy as np from scipy import special from scipy.constants import mu_0 -from termcolor import colored from desc.backend import execute_on_cpu, jnp from desc.basis import FourierZernikeBasis, fourier, zernike_radial @@ -53,13 +51,7 @@ ) from ..compute.data_index import is_0d_vol_grid, is_1dr_rad_grid, is_1dz_tor_grid -from .coords import ( - compute_theta_coords, - is_nested, - map_clebsch_coords, - map_coordinates, - to_sfl, -) +from .coords import is_nested, map_coordinates, to_sfl from .initial_guess import set_initial_guess from .utils import parse_axis, parse_profile, parse_surface @@ -361,10 +353,10 @@ def __init__( p = getattr(self, profile) if hasattr(p, "change_resolution"): p.change_resolution(max(p.basis.L, self.L)) - if isinstance(p, PowerSeriesProfile) and p.sym != "even": - warnings.warn( - colored(f"{profile} profile is not an even power series.", "yellow") - ) + warnif( + isinstance(p, PowerSeriesProfile) and p.sym != "even", + msg=f"{profile} profile is not an even power series.", + ) # ensure number of field periods agree before setting guesses eq_NFP = self.NFP @@ -830,7 +822,7 @@ def compute( # noqa: C901 profiles : dict of Profile Profile objects for pressure, iota, current, etc. Defaults to attributes of self - data : dict of ndarray + data : dict[str, jnp.ndarray] Data computed so far, generally output from other compute functions. Any vector v = v¹ R̂ + v² ϕ̂ + v³ Ẑ should be given in components v = [v¹, v², v³] where R̂, ϕ̂, Ẑ are the normalized basis vectors @@ -983,10 +975,9 @@ def need_src(name): for dep in deps: req = data_index[p][dep]["resolution_requirement"] coords = data_index[p][dep]["coordinates"] - msg = lambda direction: colored( + msg = lambda direction: ( f"Dependency {dep} may require more {direction}" - " resolution to compute accurately.", - "yellow", + " resolution to compute accurately." ) warnif( # if need more radial resolution @@ -1148,14 +1139,14 @@ def need_src(name): ) return data - def map_coordinates( # noqa: C901 + def map_coordinates( self, coords, inbasis, outbasis=("rho", "theta", "zeta"), guess=None, params=None, - period=(np.inf, np.inf, np.inf), + period=None, tol=1e-6, maxiter=30, full_output=False, @@ -1179,6 +1170,7 @@ def map_coordinates( # noqa: C901 Initial guess for the computational coordinates ['rho', 'theta', 'zeta'] corresponding to coords in inbasis. If None, heuristics are used based on in basis and a nearest neighbor search on a coarse grid. + In most cases, this must be given to be compatible with JIT. params : dict Values of equilibrium parameters to use, eg eq.params_dict period : tuple of float @@ -1186,7 +1178,7 @@ def map_coordinates( # noqa: C901 Use np.inf to denote no periodicity. tol : float Stopping tolerance. - maxiter : int > 0 + maxiter : int Maximum number of Newton iterations full_output : bool, optional If True, also return a tuple where the first element is the residual from @@ -1197,15 +1189,11 @@ def map_coordinates( # noqa: C901 Returns ------- - coords : ndarray, shape(k,3) + out : ndarray, shape(k,3) Coordinates mapped from inbasis to outbasis. info : tuple 2 element tuple containing residuals and number of iterations - for each point. Only returned if ``full_output`` is True - - Notes - ----- - ``guess`` must be given for this function to be compatible with ``jit``. + for each point. Only returned if ``full_output`` is True. """ return map_coordinates( @@ -1237,7 +1225,7 @@ def compute_theta_coords( Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. tol : float Stopping tolerance. - maxiter : int > 0 + maxiter : int Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from @@ -1256,70 +1244,14 @@ def compute_theta_coords( point. Only returned if ``full_output`` is True. """ - return compute_theta_coords( + warnif(True, DeprecationWarning, msg="Use map_coordinates instead.") + return map_coordinates( self, flux_coords, - L_lmn=L_lmn, - tol=tol, - maxiter=maxiter, - full_output=full_output, - **kwargs, - ) - - def map_clebsch_coords( - self, - clebsch_coords, - iota, - L_lmn=None, - L_basis=None, - tol=1e-6, - maxiter=20, - full_output=False, - **kwargs, - ): - """Find θ for given Clebsch field line poloidal label α. - - Parameters - ---------- - eq : Equilibrium - Equilibrium to use. - clebsch_coords : ndarray - Shape (k, 3). - Clebsch field line coordinates [ρ, α, ζ]. Assumes ζ = ϕ. - Each row is a different point in space. - iota : ndarray - Shape (k, ) - Rotational transform on each node. - L_lmn : ndarray - Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. - L_basis : Basis - Spectral basis for lambda. Defaults to ``eq.L_basis``. - tol : float - Stopping tolerance. - maxiter : int > 0 - Maximum number of Newton iterations. - full_output : bool, optional - If True, also return a tuple where the first element is the residual from - the root finding and the second is the number of iterations. - kwargs : dict, optional - Additional keyword arguments to pass to ``root_scalar`` such as - ``maxiter_ls``, ``alpha``. - - Returns - ------- - coords : ndarray - Shape (k, 3). - DESC computational coordinates [ρ, θ, ζ]. - info : tuple - 2 element tuple containing residuals and number of iterations for each - point. Only returned if ``full_output`` is True. - - """ - return map_clebsch_coords( - clebsch_coords, - iota, - L_lmn=setdefault(L_lmn, self.L_lmn), - L_basis=setdefault(L_basis, self.L_basis), + inbasis=("rho", "theta_PEST", "zeta"), + outbasis=("rho", "theta", "zeta"), + params=setdefault({"L_lmn": L_lmn}, self.params_dict, L_lmn), + basis={"L_basis": self.L_basis}, tol=tol, maxiter=maxiter, full_output=full_output, @@ -2098,22 +2030,20 @@ def solve( if not isinstance(constraints, (list, tuple)): constraints = tuple([constraints]) - if self.N > self.N_grid or self.M > self.M_grid or self.L > self.L_grid: - warnings.warn( - colored( - "Equilibrium has one or more spectral resolutions " - + "greater than the corresponding collocation grid resolution! " - + "This is not recommended and may result in poor convergence. " - + "Set grid resolutions to be higher, (i.e. eq.N_grid=2*eq.N) " - + "to avoid this warning.", - "yellow", - ) - ) - if self.bdry_mode == "poincare": - raise NotImplementedError( - "Solving equilibrium with poincare XS as BC is not supported yet " - + "on master branch." - ) + warnif( + self.N > self.N_grid or self.M > self.M_grid or self.L > self.L_grid, + msg="Equilibrium has one or more spectral resolutions " + + "greater than the corresponding collocation grid resolution! " + + "This is not recommended and may result in poor convergence. " + + "Set grid resolutions to be higher, (i.e. eq.N_grid=2*eq.N) " + + "to avoid this warning.", + ) + errorif( + self.bdry_mode == "poincare", + NotImplementedError, + "Solving equilibrium with poincare XS as BC is not supported yet " + + "on master branch.", + ) things, result = optimizer.optimize( self, diff --git a/desc/utils.py b/desc/utils.py index 36b58f2281..a1df551229 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -532,7 +532,7 @@ def errorif(cond, err=ValueError, msg=""): just AssertionError. """ if cond: - raise err(msg) + raise err(colored(msg, "red")) class ResolutionWarning(UserWarning): @@ -544,7 +544,7 @@ class ResolutionWarning(UserWarning): def warnif(cond, err=UserWarning, msg=""): """Throw a warning if condition is met.""" if cond: - warnings.warn(msg, err) + warnings.warn(colored(msg, "yellow"), err) def check_nonnegint(x, name="", allow_none=True): diff --git a/desc/vmec.py b/desc/vmec.py index 7a68f52de9..47507ebd97 100644 --- a/desc/vmec.py +++ b/desc/vmec.py @@ -1687,7 +1687,7 @@ def root_fun(theta): axis=-1, ) theta_star_k = theta + lmbda # theta* = theta + lambda - err = theta_star - theta_star_k + err = theta_star - theta_star_k # FIXME: mod by 2pi return err out = optimize.root( diff --git a/tests/test_equilibrium.py b/tests/test_equilibrium.py index b67c723038..834775ecae 100644 --- a/tests/test_equilibrium.py +++ b/tests/test_equilibrium.py @@ -20,8 +20,8 @@ @pytest.mark.unit -def test_compute_theta_coords(): - """Test root finding for theta(theta*,lambda(theta)).""" +def test_map_PEST_coordinates(): + """Test root finding for theta(theta_PEST,lambda(theta)).""" eq = get("DSHAPE_CURRENT") with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(3, 3, 0, 6, 6, 0) @@ -34,7 +34,7 @@ def test_compute_theta_coords(): flux_coords = nodes.copy() flux_coords[:, 1] += coords["lambda"] - geom_coords = eq.compute_theta_coords(flux_coords) + geom_coords = eq.map_coordinates(flux_coords, inbasis=("rho", "theta_PEST", "zeta")) geom_coords = np.array(geom_coords) # catch difference between 0 and 2*pi @@ -55,17 +55,16 @@ def test_map_coordinates(): iota = grid.compress(eq.compute("iota", grid=grid)["iota"]) iota = np.broadcast_to(iota, shape=(n,)) - out_1 = eq.map_clebsch_coords(coords, iota) + out_1 = eq.map_coordinates(coords, inbasis=["rho", "alpha", "zeta"], iota=iota) assert np.isfinite(out_1).all() out_2 = eq.map_coordinates( coords, - ["rho", "alpha", "zeta"], - ["rho", "theta", "zeta"], + inbasis=["rho", "alpha", "zeta"], period=(np.inf, 2 * np.pi, np.inf), ) assert np.isfinite(out_2).all() diff = (out_1 - out_2) % (2 * np.pi) - assert np.all(np.isclose(diff, 0) | np.isclose(diff, 2 * np.pi)) + assert np.all(np.isclose(diff, 0) | np.isclose(np.abs(diff), 2 * np.pi)) eq = get("DSHAPE") @@ -142,7 +141,9 @@ def foo(params, in_coords): @jax.jit def bar(L_lmn): - geom_coords = eq.compute_theta_coords(flux_coords, L_lmn) + geom_coords = eq.map_coordinates( + flux_coords, inbasis=("rho", "theta_PEST", "zeta") + ) return geom_coords J1 = jax.jit(jax.jacfwd(bar))(eq.params_dict["L_lmn"]) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 843e125319..697c9dd893 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -839,7 +839,7 @@ def test_plot_b_mag(): rhoa = rho * np.ones_like(zeta) c = np.vstack([rhoa, thetas, zeta]).T - coords = eq.compute_theta_coords(c) + coords = eq._map_coordinates(c, inbasis=("rho", "theta_PEST", "zeta")) grid = Grid(coords) # compute |B| normalized in the usual flux tube way diff --git a/tests/utils.py b/tests/utils.py index 85960f804d..56077b1271 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,7 +36,7 @@ def compute_coords(equil, Nr=10, Nt=8, Nz=None): # angle in PEST-like flux coordinates # find theta angles corresponding to desired theta* angles - v_grid = Grid(equil.compute_theta_coords(t_grid.nodes)) + v_grid = Grid(equil._map_PEST_coordinates(t_grid.nodes)) r_coords = equil.compute(["R", "Z"], grid=r_grid) v_coords = equil.compute(["R", "Z"], grid=v_grid) From eab2117addf9a211f1d41767e47f27ba3afeab8a Mon Sep 17 00:00:00 2001 From: unalmis Date: Thu, 8 Aug 2024 19:08:20 -0400 Subject: [PATCH 09/14] Add kwargs to get_rtz_grid --- desc/equilibrium/coords.py | 10 ++++++++-- desc/grid.py | 5 ++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 15f69ce843..19e30af09c 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -648,7 +648,9 @@ def to_sfl( return eq_sfl -def get_rtz_grid(eq, radial, poloidal, toroidal, coordinates, period, jitable=True): +def get_rtz_grid( + eq, radial, poloidal, toroidal, coordinates, period, jitable=True, **kwargs +): """Return DESC grid in rtz (rho, theta, zeta) coordinates from given coordinates. Create a tensor-product grid from the given coordinates, and return the same grid @@ -691,12 +693,15 @@ def get_rtz_grid(eq, radial, poloidal, toroidal, coordinates, period, jitable=Tr "v": "theta_PEST", "a": "alpha", "z": "zeta", + "p": "phi", } - rtz_nodes = eq.map_coordinates( + rtz_nodes = map_coordinates( + eq, grid.nodes, inbasis=[inbasis[char] for char in coordinates], outbasis=("rho", "theta", "zeta"), period=period, + **kwargs, ) desc_grid = Grid( nodes=rtz_nodes, @@ -704,6 +709,7 @@ def get_rtz_grid(eq, radial, poloidal, toroidal, coordinates, period, jitable=Tr source_grid=grid, sort=False, jitable=jitable, + # Assumes that in basis radial coordinate is single variable function of rho. _unique_rho_idx=grid.unique_rho_idx, _inverse_rho_idx=grid.inverse_rho_idx, ) diff --git a/desc/grid.py b/desc/grid.py index aad2c86f33..06ce1329ae 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -715,9 +715,8 @@ def __init__( self._N = self.num_nodes errorif(len(kwargs), ValueError, f"Got unexpected kwargs {kwargs.keys()}") - @classmethod + @staticmethod def create_meshgrid( - cls, nodes, spacing=None, coordinates="rtz", @@ -790,7 +789,7 @@ def create_meshgrid( a.size, ) inverse_c_idx = jnp.tile(unique_c_idx, a.size * b.size) - return cls( + return Grid( nodes=nodes, spacing=spacing, weights=weights, From bf795505ff188b338b3f7cb06e5d78c212f4ca61 Mon Sep 17 00:00:00 2001 From: unalmis Date: Thu, 8 Aug 2024 19:23:39 -0400 Subject: [PATCH 10/14] Replace some calls to compute_theta_coords with new function --- tests/test_plotting.py | 2 +- tests/utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 697c9dd893..8252d349d8 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -839,7 +839,7 @@ def test_plot_b_mag(): rhoa = rho * np.ones_like(zeta) c = np.vstack([rhoa, thetas, zeta]).T - coords = eq._map_coordinates(c, inbasis=("rho", "theta_PEST", "zeta")) + coords = eq.map_coordinates(c, inbasis=("rho", "theta_PEST", "zeta")) grid = Grid(coords) # compute |B| normalized in the usual flux tube way diff --git a/tests/utils.py b/tests/utils.py index 56077b1271..8799bfdbf2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,7 +36,9 @@ def compute_coords(equil, Nr=10, Nt=8, Nz=None): # angle in PEST-like flux coordinates # find theta angles corresponding to desired theta* angles - v_grid = Grid(equil._map_PEST_coordinates(t_grid.nodes)) + v_grid = Grid( + equil.map_coordinates(t_grid.nodes, inbasis=("rho", "theta_PEST", "zeta")) + ) r_coords = equil.compute(["R", "Z"], grid=r_grid) v_coords = equil.compute(["R", "Z"], grid=v_grid) From b31c26404fe18444a072155cfa920953e3e8668a Mon Sep 17 00:00:00 2001 From: unalmis Date: Thu, 8 Aug 2024 19:44:24 -0400 Subject: [PATCH 11/14] Clean up docstring --- desc/equilibrium/coords.py | 54 ++++++++++++++++----------------- desc/equilibrium/equilibrium.py | 42 ++++++++++++++----------- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 19e30af09c..7528e0e0fb 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -29,33 +29,34 @@ def map_coordinates( # noqa: C901 full_output=False, **kwargs, ): - """Given coordinates in inbasis, compute corresponding coordinates in outbasis. + """Transform coordinates given in ``inbasis`` to ``outbasis``. - First solves for the computational coordinates that correspond to inbasis, then - evaluates outbasis at those locations. + Solves for the computational coordinates that correspond to ``inbasis``, + then evaluates ``outbasis`` at those locations. - Speed can often be significantly improved by providing a reasonable initial guess. - The default is a nearest neighbor search on a grid. + Performance can often improve significantly given a reasonable initial guess. Parameters ---------- eq : Equilibrium Equilibrium to use. - coords : ndarray, shape(k,3) + coords : ndarray + Shape (k, 3). 2D array of input coordinates. Each row is a different point in space. inbasis, outbasis : tuple of str - Labels for input and output coordinates, eg ("R", "phi", "Z") or + Labels for input and output coordinates, e.g. ("R", "phi", "Z") or ("rho", "alpha", "zeta") or any combination thereof. Labels should be the same as the compute function data key. - guess : None or ndarray, shape(k,3) + guess : jnp.ndarray + Shape (k, 3). Initial guess for the computational coordinates ['rho', 'theta', 'zeta'] - corresponding to coords in inbasis. If None, heuristics are used based on - inbasis or a nearest neighbor search on a grid. - In most cases, this must be given to be compatible with JIT. + corresponding to ``coords`` in ``inbasis``. If not given, then heuristics + based on ``inbasis`` or a nearest neighbor search on a grid may be used. + In general, this must be given to be compatible with JIT. params : dict Values of equilibrium parameters to use, e.g. ``eq.params_dict``. period : tuple of float - Assumed periodicity for each quantity in inbasis. + Assumed periodicity for each quantity in ``inbasis``. Use np.inf to denote no periodicity. tol : float Stopping tolerance. @@ -70,13 +71,14 @@ def map_coordinates( # noqa: C901 Returns ------- - out : ndarray, shape(k,3) - Coordinates mapped from inbasis to outbasis. Values of NaN will be returned - for coordinates where root finding did not succeed, possibly because the - coordinate is not in the plasma volume. + out : jnp.ndarray + Shape (k, 3). + Coordinates mapped from ``inbasis`` to ``outbasis``. Values of NaN will be + returned for coordinates where root finding did not succeed, possibly + because the coordinate is not in the plasma volume. info : tuple 2 element tuple containing residuals and number of iterations - for each point. Only returned if ``full_output`` is True + for each point. Only returned if ``full_output`` is True. """ check_posint(maxiter, allow_none=False) @@ -517,16 +519,14 @@ def is_nested(eq, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): warnif( not nested, RuntimeWarning, - ( - "Flux surfaces are no longer nested, exiting early. " - + { - "auto": "Automatic continuation method failed, consider specifying " - "continuation steps manually.", - "manual": "Consider taking smaller perturbation/resolution steps " - "or reducing trust radius.", - None: "", - }[msg] - ), + "Flux surfaces are no longer nested, exiting early. " + + { + "auto": "Automatic continuation method failed, consider specifying " + "continuation steps manually.", + "manual": "Consider taking smaller perturbation/resolution steps " + "or reducing trust radius.", + None: "", + }[msg], ) return nested diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index ad5318fa04..c0bff9db1e 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1152,34 +1152,39 @@ def map_coordinates( full_output=False, **kwargs, ): - """Given coordinates in inbasis, compute corresponding coordinates in outbasis. + """Transform coordinates given in ``inbasis`` to ``outbasis``. - First solves for the computational coordinates that correspond to inbasis, then - evaluates outbasis at those locations. + Solves for the computational coordinates that correspond to ``inbasis``, + then evaluates ``outbasis`` at those locations. + + Performance can often improve significantly given a reasonable initial guess. Parameters ---------- - coords : ndarray, shape(k,3) - 2D array of input coordinates. Each row is a different - point in space. + eq : Equilibrium + Equilibrium to use. + coords : ndarray + Shape (k, 3). + 2D array of input coordinates. Each row is a different point in space. inbasis, outbasis : tuple of str - Labels for input and output coordinates, eg ("R", "phi", "Z") or + Labels for input and output coordinates, e.g. ("R", "phi", "Z") or ("rho", "alpha", "zeta") or any combination thereof. Labels should be the - same as the compute function data key - guess : None or ndarray, shape(k,3) + same as the compute function data key. + guess : jnp.ndarray + Shape (k, 3). Initial guess for the computational coordinates ['rho', 'theta', 'zeta'] - corresponding to coords in inbasis. If None, heuristics are used based on - in basis and a nearest neighbor search on a coarse grid. - In most cases, this must be given to be compatible with JIT. + corresponding to ``coords`` in ``inbasis``. If not given, then heuristics + based on ``inbasis`` or a nearest neighbor search on a grid may be used. + In general, this must be given to be compatible with JIT. params : dict - Values of equilibrium parameters to use, eg eq.params_dict + Values of equilibrium parameters to use, e.g. ``eq.params_dict``. period : tuple of float - Assumed periodicity for each quantity in inbasis. + Assumed periodicity for each quantity in ``inbasis``. Use np.inf to denote no periodicity. tol : float Stopping tolerance. maxiter : int - Maximum number of Newton iterations + Maximum number of Newton iterations. full_output : bool, optional If True, also return a tuple where the first element is the residual from the root finding and the second is the number of iterations. @@ -1189,8 +1194,11 @@ def map_coordinates( Returns ------- - out : ndarray, shape(k,3) - Coordinates mapped from inbasis to outbasis. + out : jnp.ndarray + Shape (k, 3). + Coordinates mapped from ``inbasis`` to ``outbasis``. Values of NaN will be + returned for coordinates where root finding did not succeed, possibly + because the coordinate is not in the plasma volume. info : tuple 2 element tuple containing residuals and number of iterations for each point. Only returned if ``full_output`` is True. From f81b79190eafbcd955738079a8a1d2683ececd2d Mon Sep 17 00:00:00 2001 From: unalmis Date: Thu, 8 Aug 2024 19:51:39 -0400 Subject: [PATCH 12/14] Add assumption to docstring --- desc/equilibrium/coords.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 7528e0e0fb..1d1eaea123 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -357,7 +357,6 @@ def fixup(x, *args): ) ) ) - # Assume λ=0 for initial guess. theta_DESC, (res, niter) = vecroot(guess, theta_PEST, rho, zeta) out = jnp.column_stack([rho, jnp.atleast_1d(theta_DESC.squeeze()), zeta]) @@ -662,6 +661,9 @@ def get_rtz_grid( Equilibrium on which to perform coordinate mapping. radial : ndarray Sorted unique radial coordinates. + These coordinates are assumed to be a single variable function of rho. + Create a GitHub issue if you have a use-case where this assumption + cannot be made. poloidal : ndarray Sorted unique poloidal coordinates. toroidal : ndarray From d2a4f6e4aaf44e1f5ccd7613ae158308250b5d8b Mon Sep 17 00:00:00 2001 From: unalmis Date: Thu, 8 Aug 2024 22:20:58 -0400 Subject: [PATCH 13/14] Fix some issues with previous commits --- desc/compute/_metric.py | 2 +- desc/equilibrium/coords.py | 43 +++++++++++++++++++-------------- desc/equilibrium/equilibrium.py | 5 ++-- desc/vmec.py | 4 ++- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 754116a98d..96228ffc06 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1803,7 +1803,7 @@ def _gradrho(params, transforms, profiles, data, **kwargs): @register_compute_fun( name="<|grad(rho)|>", # same as S(r) / V_r(r) - label="\\langle \\vert \\nabla \\rho \\vert \\rangle|", + label="\\langle \\vert \\nabla \\rho \\vert \\rangle", units="m^{-1}", units_long="inverse meters", description="Magnitude of contravariant radial basis vector, flux surface average", diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 1d1eaea123..2fda119f04 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -57,7 +57,7 @@ def map_coordinates( # noqa: C901 Values of equilibrium parameters to use, e.g. ``eq.params_dict``. period : tuple of float Assumed periodicity for each quantity in ``inbasis``. - Use np.inf to denote no periodicity. + Use ``np.inf`` to denote no periodicity. tol : float Stopping tolerance. maxiter : int @@ -91,6 +91,7 @@ def map_coordinates( # noqa: C901 inbasis = tuple(inbasis) outbasis = tuple(outbasis) + # TODO: make this work for permutations of in/out basis if outbasis == ("rho", "theta", "zeta"): # TODO: get iota if not supplied using below logic if inbasis == ("rho", "alpha", "zeta") and "iota" in kwargs: @@ -99,7 +100,7 @@ def map_coordinates( # noqa: C901 kwargs.pop("iota"), params["L_lmn"], eq.L_basis, - guess, + guess[:, 1] if guess is not None else None, tol, maxiter, full_output, @@ -110,7 +111,7 @@ def map_coordinates( # noqa: C901 coords, params["L_lmn"], eq.L_basis, - guess, + guess[:, 1] if guess is not None else None, tol, maxiter, full_output, @@ -270,7 +271,7 @@ def _distance_body(i, idx): # TODO: decide later whether to assume given phi instead of zeta. def _map_PEST_coordinates( - PEST_coords, + coords, L_lmn, L_basis, guess, @@ -283,7 +284,7 @@ def _map_PEST_coordinates( Parameters ---------- - PEST_coords : ndarray + coords : ndarray Shape (k, 3). Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. Each row is a different point in space. @@ -291,6 +292,9 @@ def _map_PEST_coordinates( Spectral coefficients for lambda. L_basis : Basis Spectral basis for lambda. + guess : jnp.ndarray + Shape (k, ). + Optional initial guess for the computational coordinates. tol : float Stopping tolerance. maxiter : int @@ -312,7 +316,7 @@ def _map_PEST_coordinates( Only returned if ``full_output`` is True. """ - rho, theta_PEST, zeta = PEST_coords.T + rho, theta_PEST, zeta = coords.T theta_PEST = theta_PEST % (2 * np.pi) # Assume λ=0 for initial guess. guess = setdefault(guess, theta_PEST) @@ -368,7 +372,7 @@ def fixup(x, *args): # TODO: decide later whether to assume given phi instead of zeta. def _map_clebsch_coordinates( - clebsch_coords, + coords, iota, L_lmn, L_basis, @@ -382,18 +386,19 @@ def _map_clebsch_coordinates( Parameters ---------- - clebsch_coords : ndarray + coords : ndarray Shape (k, 3). Clebsch field line coordinates [ρ, α, ζ]. Assumes ζ = ϕ. Each row is a different point in space. iota : ndarray - Shape (k, ) + Shape (k, ). Rotational transform on each node. L_lmn : jnp.ndarray Spectral coefficients for lambda. L_basis : Basis Spectral basis for lambda. - guess : ndarray, shape(k,3) + guess : jnp.ndarray + Shape (k, ). Optional initial guess for the computational coordinates. tol : float Stopping tolerance. @@ -416,7 +421,7 @@ def _map_clebsch_coordinates( Only returned if ``full_output`` is True. """ - rho, alpha, zeta = clebsch_coords.T + rho, alpha, zeta = coords.T if guess is None: # Assume λ=0 for initial guess. guess = (alpha + iota * zeta) % (2 * np.pi) @@ -516,7 +521,7 @@ def is_nested(eq, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): jnp.sign(data["sqrt(g)_PEST"][0]) == jnp.sign(data["sqrt(g)_PEST"]) ) warnif( - not nested, + not nested and msg is not None, RuntimeWarning, "Flux surfaces are no longer nested, exiting early. " + { @@ -661,9 +666,6 @@ def get_rtz_grid( Equilibrium on which to perform coordinate mapping. radial : ndarray Sorted unique radial coordinates. - These coordinates are assumed to be a single variable function of rho. - Create a GitHub issue if you have a use-case where this assumption - cannot be made. poloidal : ndarray Sorted unique poloidal coordinates. toroidal : ndarray @@ -705,19 +707,24 @@ def get_rtz_grid( period=period, **kwargs, ) + idx = {} + if inbasis[coordinates[0]] == "rho": + # Should work as long as inbasis radial coordinate is + # single variable, monotonic increasing function of rho. + idx["_unique_rho_idx"] = grid.unique_rho_idx + idx["_inverse_rho_idx"] = grid.inverse_rho_idx desc_grid = Grid( nodes=rtz_nodes, coordinates="rtz", source_grid=grid, sort=False, jitable=jitable, - # Assumes that in basis radial coordinate is single variable function of rho. - _unique_rho_idx=grid.unique_rho_idx, - _inverse_rho_idx=grid.inverse_rho_idx, + **idx, ) return desc_grid +# TODO: deprecated, remove eventually def compute_theta_coords( eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs ): diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index c0bff9db1e..60efde70a5 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1180,7 +1180,7 @@ def map_coordinates( Values of equilibrium parameters to use, e.g. ``eq.params_dict``. period : tuple of float Assumed periodicity for each quantity in ``inbasis``. - Use np.inf to denote no periodicity. + Use ``np.inf`` to denote no periodicity. tol : float Stopping tolerance. maxiter : int @@ -1258,8 +1258,7 @@ def compute_theta_coords( flux_coords, inbasis=("rho", "theta_PEST", "zeta"), outbasis=("rho", "theta", "zeta"), - params=setdefault({"L_lmn": L_lmn}, self.params_dict, L_lmn), - basis={"L_basis": self.L_basis}, + params=setdefault({"L_lmn": L_lmn}, self.params_dict, L_lmn is not None), tol=tol, maxiter=maxiter, full_output=full_output, diff --git a/desc/vmec.py b/desc/vmec.py index 47507ebd97..14c4d2346c 100644 --- a/desc/vmec.py +++ b/desc/vmec.py @@ -1751,7 +1751,9 @@ def compute_coord_surfaces(cls, equil, vmec_data, Nr=10, Nt=8, Nz=None, **kwargs # angle in PEST-like flux coordinates # find theta angles corresponding to desired theta* angles - v_grid = Grid(equil.compute_theta_coords(t_grid.nodes)) + v_grid = Grid( + equil.map_coordinates(t_grid.nodes, inbasis=("rho", "theta_PEST", "zeta")) + ) r_coords_desc = equil.compute(["R", "Z"], grid=r_grid) v_coords_desc = equil.compute(["R", "Z"], grid=v_grid) From 52e715337ed5c79dc31041f09cb00f24113e2a9f Mon Sep 17 00:00:00 2001 From: unalmis Date: Thu, 8 Aug 2024 22:25:22 -0400 Subject: [PATCH 14/14] Use simpler conditional --- desc/equilibrium/equilibrium.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 60efde70a5..41d205a595 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1258,7 +1258,7 @@ def compute_theta_coords( flux_coords, inbasis=("rho", "theta_PEST", "zeta"), outbasis=("rho", "theta", "zeta"), - params=setdefault({"L_lmn": L_lmn}, self.params_dict, L_lmn is not None), + params=self.params_dict if L_lmn is None else {"L_lmn": L_lmn}, tol=tol, maxiter=maxiter, full_output=full_output,