diff --git a/desc/compute/_omnigenity.py b/desc/compute/_omnigenity.py index e8d74dd796..4c8ec74731 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -443,15 +443,25 @@ def _omni_map(params, transforms, profiles, data, **kwargs): iota = kwargs.get("iota", 1) # coordinate mapping matrix from (alpha,h) to (theta_B,zeta_B) + # need a bunch of wheres to avoid division by zero causing NaN in backward pass + # this is fine since the incorrect values get ignored later, except in OT or OH + # where fieldlines are exactly parallel to |B| contours, but this is a degenerate + # case of measure 0 so this kludge shouldn't affect things too much. + mat_OP = jnp.array( + [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] + ) + mat_OT = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) + den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) + mat_OH = jnp.array([[N, M * iota / den], [M, M / den]]) matrix = jnp.where( M == 0, - jnp.array([N, iota / N, 0, 1 / N]), # OP + mat_OP, jnp.where( N == 0, - jnp.array([0, -1, M, -1 / iota]), # OT - jnp.array([N, M * iota / (N - M * iota), M, M / (N - M * iota)]), # OH + mat_OT, + mat_OH, ), - ).reshape((2, 2)) + ) # solve for (theta_B,zeta_B) corresponding to (eta,alpha) booz = matrix @ jnp.vstack((data["alpha"], data["h"])) diff --git a/desc/objectives/_omnigenity.py b/desc/objectives/_omnigenity.py index a1acb54056..8653d4a770 100644 --- a/desc/objectives/_omnigenity.py +++ b/desc/objectives/_omnigenity.py @@ -666,7 +666,7 @@ def __init__( normalize=True, normalize_target=True, loss_function=None, - deriv_mode="fwd", # FIXME: get it working with rev mode (see GH issue #943) + deriv_mode="auto", eq_grid=None, field_grid=None, M_booz=None, @@ -891,20 +891,26 @@ def compute(self, params_1=None, params_2=None, constants=None): # update theta_B and zeta_B with new iota from the equilibrium M, N = constants["helicity"] iota = jnp.mean(eq_data["iota"]) + # see comment in desc.compute._omnigenity for the explanation of these + # wheres + mat_OP = jnp.array( + [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] + ) + mat_OT = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) + den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) + mat_OH = jnp.array([[N, M * iota / den], [M, M / den]]) matrix = jnp.where( M == 0, - jnp.array([N, iota / N, 0, 1 / N]), # OP + mat_OP, jnp.where( N == 0, - jnp.array([0, -1, M, -1 / iota]), # OT - jnp.array( - [N, M * iota / (N - M * iota), M, M / (N - M * iota)] # OH - ), + mat_OT, + mat_OH, ), - ).reshape((2, 2)) + ) booz = matrix @ jnp.vstack((field_data["alpha"], field_data["h"])) - field_data["theta_B"] = booz[0, :] - field_data["zeta_B"] = booz[1, :] + theta_B = booz[0, :] + zeta_B = booz[1, :] else: field_data = compute_fun( "desc.magnetic_fields._core.OmnigenousField", @@ -915,13 +921,15 @@ def compute(self, params_1=None, params_2=None, constants=None): helicity=constants["helicity"], iota=jnp.mean(eq_data["iota"]), ) + theta_B = field_data["theta_B"] + zeta_B = field_data["zeta_B"] # additional computations that cannot be part of the regular compute API nodes = jnp.vstack( ( - jnp.zeros_like(field_data["theta_B"]), - field_data["theta_B"], - field_data["zeta_B"], + jnp.zeros_like(theta_B), + theta_B, + zeta_B, ) ).T B_eta_alpha = jnp.matmul( diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index a611a9b2cc..b315be5d18 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -2418,6 +2418,35 @@ def test_objective_no_nangrad_coils(self, objective): g = obj.grad(obj.x()) assert not np.any(np.isnan(g)), str(objective) + @pytest.mark.unit + @pytest.mark.parametrize("helicity", [(1, 0), (1, 1), (0, 1)]) + def test_objective_no_nangrad_omnigenity(self, helicity): + """Omnigenity.""" + surf = FourierRZToroidalSurface.from_qp_model( + major_radius=1, + aspect_ratio=20, + elongation=6, + mirror_ratio=0.2, + torsion=0.1, + NFP=1, + sym=True, + ) + eq = Equilibrium(Psi=6e-3, M=4, N=4, surface=surf) + field = OmnigenousField( + L_B=0, + M_B=2, + L_x=1, + M_x=1, + N_x=1, + NFP=eq.NFP, + helicity=helicity, + B_lm=np.array([0.8, 1.2]), + ) + obj = ObjectiveFunction(Omnigenity(eq=eq, field=field)) + obj.build() + g = obj.grad(obj.x()) + assert not np.any(np.isnan(g)), str(helicity) + @pytest.mark.unit def test_asymmetric_normalization():