From 775ce6702b4168e909729466d465bae5059819e9 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 25 Jun 2024 20:50:39 -0400 Subject: [PATCH 1/4] Avoid NaN in reverse mode AD of Omnigenity --- desc/compute/_omnigenity.py | 18 ++++++++++++++---- desc/objectives/_omnigenity.py | 22 +++++++++++++++------- tests/test_objective_funs.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/desc/compute/_omnigenity.py b/desc/compute/_omnigenity.py index e8d74dd796..d09ada7d7a 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -443,15 +443,25 @@ def _omni_map(params, transforms, profiles, data, **kwargs): iota = kwargs.get("iota", 1) # coordinate mapping matrix from (alpha,h) to (theta_B,zeta_B) + # need a bunch of wheres to avoid division by zero causing NaN in backward pass + # this is fine since the incorrect values get ignored later, except in OT or OH + # where fieldlines are exactly parallel to |B| contours, but this is a degenerate + # case of measure 0 so this kludge shouldn't affect things too much. + mat_01 = jnp.array( + [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] + ) # OP + mat_10 = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) # OT + den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) + mat_11 = jnp.array([[N, M * iota / den], [M, M / den]]) # OH matrix = jnp.where( M == 0, - jnp.array([N, iota / N, 0, 1 / N]), # OP + mat_01, # OP jnp.where( N == 0, - jnp.array([0, -1, M, -1 / iota]), # OT - jnp.array([N, M * iota / (N - M * iota), M, M / (N - M * iota)]), # OH + mat_10, # OT + mat_11, # OH ), - ).reshape((2, 2)) + ) # solve for (theta_B,zeta_B) corresponding to (eta,alpha) booz = matrix @ jnp.vstack((data["alpha"], data["h"])) diff --git a/desc/objectives/_omnigenity.py b/desc/objectives/_omnigenity.py index a1acb54056..9ce55019ef 100644 --- a/desc/objectives/_omnigenity.py +++ b/desc/objectives/_omnigenity.py @@ -666,7 +666,7 @@ def __init__( normalize=True, normalize_target=True, loss_function=None, - deriv_mode="fwd", # FIXME: get it working with rev mode (see GH issue #943) + deriv_mode="auto", eq_grid=None, field_grid=None, M_booz=None, @@ -891,17 +891,25 @@ def compute(self, params_1=None, params_2=None, constants=None): # update theta_B and zeta_B with new iota from the equilibrium M, N = constants["helicity"] iota = jnp.mean(eq_data["iota"]) + # see comment in desc.compute._omnigenity for the explanation of these + # wheres + mat_01 = jnp.array( + [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] + ) # OP + mat_10 = jnp.array( + [[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]] + ) # OT + den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) + mat_11 = jnp.array([[N, M * iota / den], [M, M / den]]) # OH matrix = jnp.where( M == 0, - jnp.array([N, iota / N, 0, 1 / N]), # OP + mat_01, # OP jnp.where( N == 0, - jnp.array([0, -1, M, -1 / iota]), # OT - jnp.array( - [N, M * iota / (N - M * iota), M, M / (N - M * iota)] # OH - ), + mat_10, # OT + mat_11, # OH ), - ).reshape((2, 2)) + ) booz = matrix @ jnp.vstack((field_data["alpha"], field_data["h"])) field_data["theta_B"] = booz[0, :] field_data["zeta_B"] = booz[1, :] diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 1aad5056b3..ff7d3545e2 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -2230,6 +2230,35 @@ def test_objective_no_nangrad_coils(self, objective): g = obj.grad(obj.x()) assert not np.any(np.isnan(g)), str(objective) + @pytest.mark.unit + @pytest.mark.parametrize("helicity", [(1, 0), (1, 1), (0, 1)]) + def test_objective_no_nangrad_omnigenity(self, helicity): + """Omnigenity.""" + surf = FourierRZToroidalSurface.from_qp_model( + major_radius=1, + aspect_ratio=20, + elongation=6, + mirror_ratio=0.2, + torsion=0.1, + NFP=1, + sym=True, + ) + eq = Equilibrium(Psi=6e-3, M=4, N=4, surface=surf) + field = OmnigenousField( + L_B=0, + M_B=2, + L_x=0, + M_x=0, + N_x=0, + NFP=eq.NFP, + helicity=helicity, + B_lm=np.array([0.8, 1.2]), + ) + obj = ObjectiveFunction(Omnigenity(eq=eq, field=field)) + obj.build() + g = obj.grad(obj.x()) + assert not np.any(np.isnan(g)), str(helicity) + @pytest.mark.unit def test_asymmetric_normalization(): From 126dd207015742e1fd6455a9a59e43941e47a086 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 26 Jun 2024 14:57:10 -0400 Subject: [PATCH 2/4] Rename matrices in omnigenity transformation --- desc/compute/_omnigenity.py | 14 +++++++------- desc/objectives/_omnigenity.py | 16 +++++++--------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/desc/compute/_omnigenity.py b/desc/compute/_omnigenity.py index d09ada7d7a..4c8ec74731 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -447,19 +447,19 @@ def _omni_map(params, transforms, profiles, data, **kwargs): # this is fine since the incorrect values get ignored later, except in OT or OH # where fieldlines are exactly parallel to |B| contours, but this is a degenerate # case of measure 0 so this kludge shouldn't affect things too much. - mat_01 = jnp.array( + mat_OP = jnp.array( [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] - ) # OP - mat_10 = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) # OT + ) + mat_OT = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) - mat_11 = jnp.array([[N, M * iota / den], [M, M / den]]) # OH + mat_OH = jnp.array([[N, M * iota / den], [M, M / den]]) matrix = jnp.where( M == 0, - mat_01, # OP + mat_OP, jnp.where( N == 0, - mat_10, # OT - mat_11, # OH + mat_OT, + mat_OH, ), ) diff --git a/desc/objectives/_omnigenity.py b/desc/objectives/_omnigenity.py index 9ce55019ef..deff459aae 100644 --- a/desc/objectives/_omnigenity.py +++ b/desc/objectives/_omnigenity.py @@ -893,21 +893,19 @@ def compute(self, params_1=None, params_2=None, constants=None): iota = jnp.mean(eq_data["iota"]) # see comment in desc.compute._omnigenity for the explanation of these # wheres - mat_01 = jnp.array( + mat_OP = jnp.array( [[N, iota / jnp.where(N == 0, 1, N)], [0, 1 / jnp.where(N == 0, 1, N)]] - ) # OP - mat_10 = jnp.array( - [[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]] - ) # OT + ) + mat_OT = jnp.array([[0, -1], [M, -1 / jnp.where(iota == 0, 1.0, iota)]]) den = jnp.where((N - M * iota) == 0, 1.0, (N - M * iota)) - mat_11 = jnp.array([[N, M * iota / den], [M, M / den]]) # OH + mat_OH = jnp.array([[N, M * iota / den], [M, M / den]]) matrix = jnp.where( M == 0, - mat_01, # OP + mat_OP, jnp.where( N == 0, - mat_10, # OT - mat_11, # OH + mat_OT, + mat_OH, ), ) booz = matrix @ jnp.vstack((field_data["alpha"], field_data["h"])) From 73524ee00c78c6d2d3020c1f18b374eb06d64afb Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 26 Jun 2024 14:57:36 -0400 Subject: [PATCH 3/4] Fix escaped tracer error --- desc/objectives/_omnigenity.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/desc/objectives/_omnigenity.py b/desc/objectives/_omnigenity.py index deff459aae..8653d4a770 100644 --- a/desc/objectives/_omnigenity.py +++ b/desc/objectives/_omnigenity.py @@ -909,8 +909,8 @@ def compute(self, params_1=None, params_2=None, constants=None): ), ) booz = matrix @ jnp.vstack((field_data["alpha"], field_data["h"])) - field_data["theta_B"] = booz[0, :] - field_data["zeta_B"] = booz[1, :] + theta_B = booz[0, :] + zeta_B = booz[1, :] else: field_data = compute_fun( "desc.magnetic_fields._core.OmnigenousField", @@ -921,13 +921,15 @@ def compute(self, params_1=None, params_2=None, constants=None): helicity=constants["helicity"], iota=jnp.mean(eq_data["iota"]), ) + theta_B = field_data["theta_B"] + zeta_B = field_data["zeta_B"] # additional computations that cannot be part of the regular compute API nodes = jnp.vstack( ( - jnp.zeros_like(field_data["theta_B"]), - field_data["theta_B"], - field_data["zeta_B"], + jnp.zeros_like(theta_B), + theta_B, + zeta_B, ) ).T B_eta_alpha = jnp.matmul( From cb34f59ab3912aea85c3abd7f9a75ab3b859d650 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Thu, 27 Jun 2024 12:47:10 -0400 Subject: [PATCH 4/4] Increase omnigenity resolution for nangrad test --- tests/test_objective_funs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 8cbf2ad18e..b315be5d18 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -2435,9 +2435,9 @@ def test_objective_no_nangrad_omnigenity(self, helicity): field = OmnigenousField( L_B=0, M_B=2, - L_x=0, - M_x=0, - N_x=0, + L_x=1, + M_x=1, + N_x=1, NFP=eq.NFP, helicity=helicity, B_lm=np.array([0.8, 1.2]),