Skip to content

Commit

Permalink
Merge branch 'master' into yge/cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
ddudt authored Jul 23, 2024
2 parents 393d243 + 4b21abc commit 5bb27d2
Show file tree
Hide file tree
Showing 24 changed files with 1,345 additions and 545 deletions.
45 changes: 44 additions & 1 deletion desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@
switch = jax.lax.switch
while_loop = jax.lax.while_loop
vmap = jax.vmap
scan = jax.lax.scan
bincount = jnp.bincount
repeat = jnp.repeat
take = jnp.take
scan = jax.lax.scan
from jax import custom_jvp
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
Expand Down Expand Up @@ -659,6 +661,13 @@ def bincount(x, weights=None, minlength=None, length=None):
"""Same as np.bincount but with a dummy parameter to match jnp.bincount API."""
return np.bincount(x, weights, minlength)

def repeat(a, repeats, axis=None, total_repeat_length=None):
"""A numpy implementation of jnp.repeat."""
out = np.repeat(a, repeats, axis)
if total_repeat_length is not None:
out = out[:total_repeat_length]
return out

def custom_jvp(fun, *args, **kwargs):
"""Dummy function for custom_jvp without JAX."""
fun.defjvp = lambda *args, **kwargs: None
Expand Down Expand Up @@ -768,3 +777,37 @@ def root(
"""
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out

def take(
a,
indices,
axis=None,
out=None,
mode="fill",
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
):
"""A numpy implementation of jnp.take."""
if mode == "fill":
if fill_value is None:
# copy jax logic
# https://jax.readthedocs.io/en/latest/_modules/jax/_src/lax/slicing.html#gather
if np.issubdtype(a.dtype, np.inexact):
fill_value = np.nan
elif np.issubdtype(a.dtype, np.signedinteger):
fill_value = np.iinfo(a.dtype).min
elif np.issubdtype(a.dtype, np.unsignedinteger):
fill_value = np.iinfo(a.dtype).max
elif a.dtype == np.bool_:
fill_value = True
else:
raise ValueError(f"Unsupported dtype {a.dtype}.")
out = np.where(
(-a.size <= indices) & (indices < a.size),
np.take(a, indices, axis, out, mode="wrap"),
fill_value,
)
else:
out = np.take(a, indices, axis, out, mode)
return out
91 changes: 70 additions & 21 deletions desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,6 @@ def _e_sup_theta_rr(params, transforms, profiles, data, **kwargs):
/ data["sqrt(g)"] ** 2
+ 2 * temp.T * data["sqrt(g)_r"] * data["sqrt(g)_r"] / data["sqrt(g)"] ** 3
).T

return data


Expand Down Expand Up @@ -1140,7 +1139,6 @@ def _e_sup_zeta_rz(params, transforms, profiles, data, **kwargs):
/ data["sqrt(g)"] ** 2
+ 2 * temp.T * data["sqrt(g)_r"] * data["sqrt(g)_z"] / data["sqrt(g)"] ** 3
).T

return data


Expand Down Expand Up @@ -1184,7 +1182,6 @@ def _e_sup_zeta_t(params, transforms, profiles, data, **kwargs):
* safediv(data["sqrt(g)_rt"], data["sqrt(g)_r"] ** 2)
).T,
)

return data


Expand Down Expand Up @@ -1235,7 +1232,6 @@ def _e_sup_zeta_tt(params, transforms, profiles, data, **kwargs):
/ data["sqrt(g)"] ** 2
+ 2 * temp.T * data["sqrt(g)_t"] * data["sqrt(g)_t"] / data["sqrt(g)"] ** 3
).T

return data


Expand Down Expand Up @@ -1291,7 +1287,6 @@ def _e_sup_zeta_tz(params, transforms, profiles, data, **kwargs):
/ data["sqrt(g)"] ** 2
+ 2 * temp.T * data["sqrt(g)_t"] * data["sqrt(g)_z"] / data["sqrt(g)"] ** 3
).T

return data


Expand Down Expand Up @@ -1335,7 +1330,6 @@ def _e_sup_zeta_z(params, transforms, profiles, data, **kwargs):
* safediv(data["sqrt(g)_rz"], data["sqrt(g)_r"] ** 2)
).T,
)

return data


Expand Down Expand Up @@ -1386,16 +1380,15 @@ def _e_sup_zeta_zz(params, transforms, profiles, data, **kwargs):
/ data["sqrt(g)"] ** 2
+ 2 * temp.T * data["sqrt(g)_z"] * data["sqrt(g)_z"] / data["sqrt(g)"] ** 3
).T

return data


@register_compute_fun(
name="e_phi",
label="\\mathbf{e}_{\\phi}",
name="e_phi|r,t",
label="\\mathbf{e}_{\\phi} |_{\\rho, \\theta}",
units="m",
units_long="meters",
description="Covariant cylindrical toroidal basis vector",
description="Covariant toroidal basis vector in (ρ,θ,ϕ) coordinates",
dim=3,
params=[],
transforms={},
Expand All @@ -1407,11 +1400,13 @@ def _e_sup_zeta_zz(params, transforms, profiles, data, **kwargs):
"desc.geometry.surface.FourierRZToroidalSurface",
"desc.geometry.core.Surface",
],
aliases=["e_phi"],
# Our usual notation implies e_phi = (∂X/∂ϕ)|R,Z = R ϕ̂, but we need to alias e_phi
# to e_phi|r,t = (∂X/∂ϕ)|ρ,θ for compatibility with older versions of the code.
)
def _e_sub_phi(params, transforms, profiles, data, **kwargs):
# dX/dphi at const r,t = dX/dz * dz/dphi = dX/dz / (dphi/dz)
data["e_phi"] = (data["e_zeta"].T / data["phi_z"]).T

def _e_sub_phi_rt(params, transforms, profiles, data, **kwargs):
# (∂X/∂ϕ)|ρ,θ = (∂X/∂ζ)|ρ,θ / (∂ϕ/∂ζ)|ρ,θ
data["e_phi|r,t"] = (data["e_zeta"].T / data["phi_z"]).T
return data


Expand Down Expand Up @@ -2434,27 +2429,81 @@ def _e_sub_theta_over_sqrt_g(params, transforms, profiles, data, **kwargs):
safediv(data["e_theta"].T, data["sqrt(g)"]).T,
lambda: safediv(data["e_theta_r"].T, data["sqrt(g)_r"]).T,
)

return data


@register_compute_fun(
name="e_theta_PEST",
label="\\mathbf{e}_{\\theta_{PEST}}",
label="\\mathbf{e}_{\\vartheta} |_{\\rho, \\phi} = \\mathbf{e}_{\\theta_{PEST}}",
units="m",
units_long="meters",
description="Covariant straight field line (PEST) poloidal basis vector",
description="Covariant poloidal basis vector in (ρ,ϑ,ϕ) coordinates or"
" straight field line PEST coordinates. ϕ increases counterclockwise"
" when viewed from above (cylindrical R,ϕ plane with Z out of page).",
dim=3,
params=[],
transforms={},
profiles=[],
coordinates="rtz",
data=["e_theta", "theta_PEST_t"],
data=["e_theta", "theta_PEST_t", "e_zeta", "theta_PEST_z", "phi_t", "phi_z"],
aliases=["e_vartheta"],
)
def _e_sub_theta_pest(params, transforms, profiles, data, **kwargs):
# dX/dv at const r,z = dX/dt * dt/dv / dX/dt / dv/dt
data["e_theta_PEST"] = (data["e_theta"].T / data["theta_PEST_t"]).T
def _e_sub_vartheta_rp(params, transforms, profiles, data, **kwargs):
# constant ρ and ϕ
e_vartheta = (
data["e_theta"].T * data["phi_z"] - data["e_zeta"].T * data["phi_t"]
) / (data["theta_PEST_t"] * data["phi_z"] - data["theta_PEST_z"] * data["phi_t"])
data["e_theta_PEST"] = e_vartheta.T
return data


@register_compute_fun(
name="e_phi|r,v",
label="\\mathbf{e}_{\\phi} |_{\\rho, \\vartheta}",
units="m",
units_long="meters",
description="Covariant toroidal basis vector in (ρ,ϑ,ϕ) coordinates or"
" straight field line PEST coordinates. ϕ increases counterclockwise"
" when viewed from above (cylindrical R,ϕ plane with Z out of page).",
dim=3,
params=[],
transforms={},
profiles=[],
coordinates="rtz",
data=["e_theta", "theta_PEST_t", "e_zeta", "theta_PEST_z", "phi_t", "phi_z"],
)
def _e_sub_phi_rv(params, transforms, profiles, data, **kwargs):
# constant ρ and ϑ
e_phi = (
data["e_zeta"].T * data["theta_PEST_t"]
- data["e_theta"].T * data["theta_PEST_z"]
) / (data["theta_PEST_t"] * data["phi_z"] - data["theta_PEST_z"] * data["phi_t"])
data["e_phi|r,v"] = e_phi.T
return data


@register_compute_fun(
name="e_rho|v,p",
label="\\mathbf{e}_{\\rho} |_{\\vartheta, \\phi}",
units="m",
units_long="meters",
description="Covariant radial basis vector in (ρ,ϑ,ϕ) coordinates or"
" straight field line PEST coordinates. ϕ increases counterclockwise"
" when viewed from above (cylindrical R,ϕ plane with Z out of page).",
dim=3,
params=[],
transforms={},
profiles=[],
coordinates="rtz",
data=["e_rho", "e_vartheta", "e_phi|r,v", "theta_PEST_r", "phi_r"],
)
def _e_sub_rho_vp(params, transforms, profiles, data, **kwargs):
# constant ϑ and ϕ
data["e_rho|v,p"] = (
data["e_rho"].T
- data["e_vartheta"].T * data["theta_PEST_r"]
- data["e_phi|r,v"].T * data["phi_r"]
).T
return data


Expand Down
1 change: 1 addition & 0 deletions desc/compute/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
coordinates="r",
data=["sqrt(g)", "V_r(r)", "|B|", "<|B|^2>", "max_tz |B|"],
axis_limit_data=["sqrt(g)_r", "V_rr(r)"],
resolution_requirement="tz",
n_gauss="int: Number of quadrature points to use for estimating trapped fraction. "
+ "Default 20.",
)
Expand Down
6 changes: 6 additions & 0 deletions desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def _J_dot_B(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["J*sqrt(g)", "B", "V_r(r)"],
axis_limit_data=["(J*sqrt(g))_r", "V_rr(r)"],
resolution_requirement="tz",
)
def _J_dot_B_fsa(params, transforms, profiles, data, **kwargs):
J = transforms["grid"].replace_at_axis(
Expand Down Expand Up @@ -534,6 +535,7 @@ def _Fmag(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|F|", "sqrt(g)", "V"],
resolution_requirement="rtz",
)
def _Fmag_vol(params, transforms, profiles, data, **kwargs):
data["<|F|>_vol"] = (
Expand Down Expand Up @@ -655,6 +657,7 @@ def _F_anisotropic(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|B|", "sqrt(g)"],
resolution_requirement="rtz",
)
def _W_B(params, transforms, profiles, data, **kwargs):
data["W_B"] = jnp.sum(
Expand All @@ -675,6 +678,7 @@ def _W_B(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["B", "sqrt(g)"],
resolution_requirement="rtz",
)
def _W_Bpol(params, transforms, profiles, data, **kwargs):
data["W_Bpol"] = jnp.sum(
Expand All @@ -697,6 +701,7 @@ def _W_Bpol(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["B", "sqrt(g)"],
resolution_requirement="rtz",
)
def _W_Btor(params, transforms, profiles, data, **kwargs):
data["W_Btor"] = jnp.sum(
Expand All @@ -718,6 +723,7 @@ def _W_Btor(params, transforms, profiles, data, **kwargs):
coordinates="",
data=["p", "sqrt(g)"],
gamma="float: Adiabatic index. Default 0",
resolution_requirement="rtz",
)
def _W_p(params, transforms, profiles, data, **kwargs):
data["W_p"] = jnp.sum(data["p"] * data["sqrt(g)"] * transforms["grid"].weights) / (
Expand Down
14 changes: 12 additions & 2 deletions desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2598,6 +2598,7 @@ def _grad_B(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["sqrt(g)", "|B|", "V"],
resolution_requirement="rtz",
)
def _B_vol(params, transforms, profiles, data, **kwargs):
data["<|B|>_vol"] = (
Expand All @@ -2617,11 +2618,12 @@ def _B_vol(params, transforms, profiles, data, **kwargs):
transforms={"grid": []},
profiles=[],
coordinates="",
data=["sqrt(g)", "|B|", "V"],
data=["sqrt(g)", "|B|^2", "V"],
resolution_requirement="rtz",
)
def _B_rms(params, transforms, profiles, data, **kwargs):
data["<|B|>_rms"] = jnp.sqrt(
jnp.sum(data["|B|"] ** 2 * data["sqrt(g)"] * transforms["grid"].weights)
jnp.sum(data["|B|^2"] * data["sqrt(g)"] * transforms["grid"].weights)
/ data["V"]
)
return data
Expand All @@ -2640,6 +2642,7 @@ def _B_rms(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "|B|"],
axis_limit_data=["sqrt(g)_r"],
resolution_requirement="tz",
)
def _B_fsa(params, transforms, profiles, data, **kwargs):
data["<|B|>"] = surface_averages(
Expand All @@ -2665,6 +2668,7 @@ def _B_fsa(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "|B|^2"],
axis_limit_data=["sqrt(g)_r"],
resolution_requirement="tz",
)
def _B2_fsa(params, transforms, profiles, data, **kwargs):
data["<|B|^2>"] = surface_averages(
Expand All @@ -2690,6 +2694,7 @@ def _B2_fsa(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "|B|"],
axis_limit_data=["sqrt(g)_r"],
resolution_requirement="tz",
)
def _1_over_B_fsa(params, transforms, profiles, data, **kwargs):
data["<1/|B|>"] = surface_averages(
Expand All @@ -2715,6 +2720,7 @@ def _1_over_B_fsa(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "sqrt(g)_r", "B", "B_r", "|B|^2", "V_r(r)", "V_rr(r)"],
axis_limit_data=["sqrt(g)_rr", "V_rrr(r)"],
resolution_requirement="tz",
)
def _B2_fsa_r(params, transforms, profiles, data, **kwargs):
integrate = surface_integrals_map(transforms["grid"])
Expand Down Expand Up @@ -2877,6 +2883,7 @@ def _gradB2mag(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|grad(|B|^2)|/2mu0", "sqrt(g)", "V"],
resolution_requirement="rtz",
)
def _gradB2mag_vol(params, transforms, profiles, data, **kwargs):
data["<|grad(|B|^2)|/2mu0>_vol"] = (
Expand Down Expand Up @@ -3077,6 +3084,7 @@ def _B_dot_grad_B_mag(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|(B*grad)B|", "sqrt(g)", "V"],
resolution_requirement="rtz",
)
def _B_dot_grad_B_mag_vol(params, transforms, profiles, data, **kwargs):
data["<|(B*grad)B|>_vol"] = (
Expand Down Expand Up @@ -3214,6 +3222,7 @@ def _B_dot_gradB_z(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["|B|"],
resolution_requirement="tz",
)
def _min_tz_modB(params, transforms, profiles, data, **kwargs):
data["min_tz |B|"] = surface_min(transforms["grid"], data["|B|"])
Expand All @@ -3232,6 +3241,7 @@ def _min_tz_modB(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["|B|"],
resolution_requirement="tz",
)
def _max_tz_modB(params, transforms, profiles, data, **kwargs):
data["max_tz |B|"] = surface_max(transforms["grid"], data["|B|"])
Expand Down
Loading

0 comments on commit 5bb27d2

Please sign in to comment.