Skip to content

Commit

Permalink
Improve coordinate mapping between DESC and Clebsch coords (#1153)
Browse files Browse the repository at this point in the history
It's likely a significant performance improvement to do scalar root
finding directly on the $\lambda$ coefficients rather than 3d root
finding and building transforms unnecessarily. This should help for
objectives that do coordinate mapping in compute functions, like
neoclassical stuff.
  • Loading branch information
dpanici authored Aug 9, 2024
2 parents 44bac68 + 52e7153 commit cbf9ebb
Show file tree
Hide file tree
Showing 11 changed files with 439 additions and 228 deletions.
4 changes: 2 additions & 2 deletions desc/compute/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,10 +1475,10 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs):
transforms={},
profiles=[],
coordinates="rtz",
data=["theta_PEST", "zeta", "iota"],
data=["theta_PEST", "phi", "iota"],
)
def _alpha(params, transforms, profiles, data, **kwargs):
data["alpha"] = (data["theta_PEST"] - data["iota"] * data["zeta"]) % (2 * jnp.pi)
data["alpha"] = (data["theta_PEST"] - data["iota"] * data["phi"]) % (2 * jnp.pi)
return data


Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,7 +1803,7 @@ def _gradrho(params, transforms, profiles, data, **kwargs):

@register_compute_fun(
name="<|grad(rho)|>", # same as S(r) / V_r(r)
label="\\langle \\vert \\nabla \\rho \\vert \\rangle|",
label="\\langle \\vert \\nabla \\rho \\vert \\rangle",
units="m^{-1}",
units_long="inverse meters",
description="Magnitude of contravariant radial basis vector, flux surface average",
Expand Down
34 changes: 16 additions & 18 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import inspect

import numpy as np
from termcolor import colored

from desc.backend import cond, execute_on_cpu, fori_loop, jnp, put
from desc.grid import ConcentricGrid, Grid, LinearGrid
Expand Down Expand Up @@ -43,15 +42,15 @@ def compute(parameterization, names, params, transforms, profiles, data=None, **
Type of object to compute for, eg Equilibrium, Curve, etc.
names : str or array-like of str
Name(s) of the quantity(s) to compute.
params : dict of ndarray
params : dict[str, jnp.ndarray]
Parameters from the equilibrium, such as R_lmn, Z_lmn, i_l, p_l, etc.
Defaults to attributes of self.
transforms : dict of Transform
Transforms for R, Z, lambda, etc. Default is to build from grid
profiles : dict of Profile
Profile objects for pressure, iota, current, etc. Defaults to attributes
of self
data : dict of ndarray
data : dict[str, jnp.ndarray]
Data computed so far, generally output from other compute functions.
Any vector v = v¹ R̂ + v² ϕ̂ + v³ Ẑ should be given in components
v = [v¹, v², v³] where R̂, ϕ̂, Ẑ are the normalized basis vectors
Expand Down Expand Up @@ -212,7 +211,7 @@ def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None):
Whether the grid to compute on has a node on the magnetic axis.
basis : {"rpz", "xyz"}
Basis of computed quantities.
data : dict of ndarray
data : dict[str, jnp.ndarray]
Data computed so far, generally output from other compute functions
Returns
Expand Down Expand Up @@ -287,7 +286,7 @@ def _get_deps(parameterization, names, deps, data=None, has_axis=False, check_fu
Name(s) of the quantity(s) to compute.
deps : set[str]
Dependencies gathered so far.
data : dict of ndarray or None
data : dict[str, jnp.ndarray]
Data computed so far, generally output from other compute functions.
has_axis : bool
Whether the grid to compute on has a node on the magnetic axis.
Expand Down Expand Up @@ -375,7 +374,7 @@ def get_derivs(keys, obj, has_axis=False, basis="rpz"):
Returns
-------
derivs : dict of list of int
derivs : dict[list, str]
Orders of derivatives needed to compute key.
Keys for R, Z, L, etc
Expand Down Expand Up @@ -465,7 +464,7 @@ def get_params(keys, obj, has_axis=False, basis="rpz"):
Returns
-------
params : list of str or dict of ndarray
params : list[str] or dict[str, jnp.ndarray]
Parameters needed to compute key.
If eq is None, returns a list of the names of params needed
otherwise, returns a dict of ndarray with keys for R_lmn, Z_lmn, etc.
Expand Down Expand Up @@ -624,13 +623,13 @@ def has_dependencies(parameterization, qty, params, transforms, profiles, data):
Type of thing we're checking dependencies for. eg desc.equilibrium.Equilibrium
qty : str
Name of something from the data index.
params : dict of ndarray
params : dict[str, jnp.ndarray]
Dictionary of parameters we have.
transforms : dict of Transform
transforms : dict[str, Transform]
Dictionary of transforms we have.
profiles : dict of Profile
profiles : dict[str, Profile]
Dictionary of profiles we have.
data : dict of ndarray
data : dict[str, jnp.ndarray]
Dictionary of what we've computed so far.
Returns
Expand Down Expand Up @@ -988,8 +987,10 @@ def line_integrals(
line_label != "poloidal" and isinstance(grid, ConcentricGrid),
msg="ConcentricGrid should only be used for poloidal line integrals.",
)
msg = colored("Correctness not guaranteed on grids with duplicate nodes.", "yellow")
warnif(isinstance(grid, LinearGrid) and grid.endpoint, msg=msg)
warnif(
isinstance(grid, LinearGrid) and grid.endpoint,
msg="Correctness not guaranteed on grids with duplicate nodes.",
)
# Generate a new quantity q_prime which is zero everywhere
# except on the fixed surface, on which q_prime takes the value of q.
# Then forward the computation to surface_integrals().
Expand Down Expand Up @@ -1075,11 +1076,8 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14)
surface_label = grid.get_label(surface_label)
warnif(
surface_label == "poloidal" and isinstance(grid, ConcentricGrid),
msg=colored(
"Integrals over constant poloidal surfaces"
" are poorly defined for ConcentricGrid.",
"yellow",
),
msg="Integrals over constant poloidal surfaces"
" are poorly defined for ConcentricGrid.",
)
unique_size, inverse_idx, spacing, has_endpoint_dupe, has_idx = _get_grid_surface(
grid, surface_label
Expand Down
Loading

0 comments on commit cbf9ebb

Please sign in to comment.