diff --git a/desc/utils.py b/desc/utils.py index abf5e5666..0bc4a0bfb 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -424,9 +424,7 @@ def svd_inv_null(A): num = jnp.sum(large, dtype=int) uk = u[:, :K] vhk = vh[:K, :] - s = jnp.where(large, 1 / s, s) - s.shape - s = s.at[(~large,)].set(0) + s = jnp.where(large, 1 / s, 0) Ainv = vhk.T @ jnp.diag(s) @ uk.T Z = vh[num:, :].T.conj() return Ainv, Z