diff --git a/desc/compute/_omnigenity.py b/desc/compute/_omnigenity.py index 25b720dc9f..949ee9c21f 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -429,10 +429,10 @@ def _omni_angle(params, transforms, profiles, data, **kwargs): @register_compute_fun( name="theta_B", - label="(\\theta_{B},\\zeta_{B})", + label="\\theta_{B}", units="rad", units_long="radians", - description="Boozer angular coordinates", + description="Boozer poloidal angle", dim=1, params=[], transforms={}, @@ -477,49 +477,19 @@ def _omni_map_theta_B(params, transforms, profiles, data, **kwargs): @register_compute_fun( name="zeta_B", - label="(\\theta_{B},\\zeta_{B})", + label="\\zeta_{B}", units="rad", units_long="radians", - description="Boozer angular coordinates", + description="Boozer toroidal angle", dim=1, params=[], transforms={}, profiles=[], coordinates="rtz", - data=["alpha", "h"], + data=["theta_B"], parameterization="desc.magnetic_fields._core.OmnigenousField", - helicity="tuple: Type of quasisymmetry, (M,N). Default (1,0)", - iota="float: Value of rotational transform on the Omnigenous surface. Default 1.0", ) def _omni_map_zeta_B(params, transforms, profiles, data, **kwargs): - M, N = kwargs.get("helicity", (1, 0)) - iota = kwargs.get("iota", 1) - - # coordinate mapping matrix from (alpha,h) to (theta_B,zeta_B) - # need a bunch of wheres to avoid division by zero causing NaN in backward pass - # this is fine since the incorrect values get ignored later, except in OT or OH - # where fieldlines are exactly parallel to |B| contours, but this is a degenerate - # case of measure 0 so this kludge shouldn't affect things too much. - mat_OP = jnp.array( - [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] - ) - mat_OT = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) - den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) - mat_OH = jnp.array([[N, M * iota / den], [M, M / den]]) - matrix = jnp.where( - M == 0, - mat_OP, - jnp.where( - N == 0, - mat_OT, - mat_OH, - ), - ) - - # solve for (theta_B,zeta_B) corresponding to (eta,alpha) - booz = matrix @ jnp.vstack((data["alpha"], data["h"])) - data["theta_B"] = booz[0, :] - data["zeta_B"] = booz[1, :] return data