diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 85784d61a9..55f4f671c2 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -5,7 +5,7 @@ import numpy as np -from desc.backend import cond, jnp, logsumexp, put +from desc.backend import cond, jit, jnp, logsumexp, put from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif @@ -148,13 +148,19 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901 else: Ainv_full = A.T Z = np.eye(A.shape[1]) + Ainv_full = jnp.asarray(Ainv_full) + Z = jnp.asarray(Z) + b = jnp.asarray(b) xp = put(xp, unfixed_idx, Ainv_full @ b) + xp = jnp.asarray(xp) + @jit def project(x): """Project a full state vector into the reduced optimization vector.""" x_reduced = Z.T @ ((x - xp)[unfixed_idx]) return jnp.atleast_1d(jnp.squeeze(x_reduced)) + @jit def recover(x_reduced): """Recover the full state vector from the reduced optimization vector.""" dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced)