diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 7967ef4ee7..12d3a16c41 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1035,7 +1035,7 @@ def jvp_scaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_scaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * len(x) + "->(k)" + sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def jvp_unscaled(self, v, x, constants=None): @@ -1061,7 +1061,7 @@ def jvp_unscaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_unscaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * len(x) + "->(k)" + sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def print_value(self, *args, **kwargs):