Skip to content

Commit

Permalink
clean up jax svd version
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Nov 21, 2024
1 parent ca3d967 commit b116a8a
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions desc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,7 @@ def svd_inv_null(A):
num = jnp.sum(large, dtype=int)
uk = u[:, :K]
vhk = vh[:K, :]
s = jnp.where(large, 1 / s, s)
s.shape
s = s.at[(~large,)].set(0)
s = jnp.where(large, 1 / s, 0)
Ainv = vhk.T @ jnp.diag(s) @ uk.T
Z = vh[num:, :].T.conj()
return Ainv, Z
Expand Down

0 comments on commit b116a8a

Please sign in to comment.