From 99ff2e2d807702850b98e697da5482f65b3d4e88 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 28 Feb 2024 20:05:22 -0500 Subject: [PATCH] Fix signature --- desc/objectives/objective_funs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 12d3a16c41..ffff8335ce 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1035,7 +1035,7 @@ def jvp_scaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_scaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" + sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def jvp_unscaled(self, v, x, constants=None): @@ -1061,7 +1061,7 @@ def jvp_unscaled(self, v, x, constants=None): jvpfun = lambda *dx: Derivative.compute_jvp( compute_unscaled, tuple(range(len(x))), dx, *x ) - sig = "(n)," * (len(x) - 1) + "(n)" + "->(k)" + sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) def print_value(self, *args, **kwargs):