From aa570d4cb5608072be5a83af558c27521f6f106f Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Fri, 26 Jul 2024 13:18:02 -0600 Subject: [PATCH] scale FD derivatives by tangent norm --- desc/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/utils.py b/desc/utils.py index 96fe1ea7b3..c1f191f74f 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -766,7 +766,7 @@ def func_jvp(primals, tangents): def f(x): return jax.flatten_util.ravel_pytree(func(*unflatx(x)))[0] - tangent_out = (f(x + fd_step * vh) - y) / fd_step + tangent_out = (f(x + fd_step * vh) - y) / fd_step * normv tangent_out = unflaty(tangent_out) return primal_out, tangent_out