diff --git a/desc/objectives/_coils.py b/desc/objectives/_coils.py index 778fddcd18..d75572c9b5 100644 --- a/desc/objectives/_coils.py +++ b/desc/objectives/_coils.py @@ -16,7 +16,7 @@ from desc.compute.utils import safenorm from desc.grid import LinearGrid, _Grid from desc.integrals import compute_B_plasma -from desc.utils import Timer, errorif, warnif +from desc.utils import Timer, broadcast_tree, errorif, warnif from .normalization import compute_scaling_factors from .objective_funs import _Objective @@ -1606,13 +1606,15 @@ def build(self, use_jit=True, verbose=1): self._dim_f = 1 self._data_keys = ["G"] + all_params = tree_map(lambda dim: np.arange(dim), coil.dimensions) + current_params = tree_map(lambda idx: {"current": idx}, True) + self._indices = tree_leaves(broadcast_tree(current_params, all_params)) + self._num_coils = coil.num_coils + profiles = get_profiles(self._data_keys, obj=eq, grid=grid) transforms = get_transforms(self._data_keys, obj=eq, grid=grid) self._constants = { - "eq": eq, - "coil": coil, - "grid": grid, "profiles": profiles, "transforms": transforms, "quad_weights": 1.0, @@ -1653,9 +1655,12 @@ def compute(self, eq_params, coil_params, constants=None): profiles=constants["profiles"], ) eq_linking_current = 2 * jnp.pi * data["G"][0] / mu_0 - coil_linking_current = jnp.sum( + coil_linking_current = self._num_coils * jnp.mean( jnp.concatenate( - [jnp.atleast_1d(param["current"]) for param in tree_leaves(coil_params)] + [ + jnp.atleast_1d(param[idx]) + for param, idx in zip(tree_leaves(coil_params), self._indices) + ] ) ) return eq_linking_current - coil_linking_current