Skip to content

Commit

Permalink
cast Ainv_full and Z to jnp to avoid host to device transfers when re… (
Browse files Browse the repository at this point in the history
#1051)

…cover/project is called
  • Loading branch information
f0uriest authored Jun 13, 2024
2 parents 0567d51 + c8a523f commit c0ade01
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from desc.backend import cond, jnp, logsumexp, put
from desc.backend import cond, jit, jnp, logsumexp, put
from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif


Expand Down Expand Up @@ -148,13 +148,19 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
else:
Ainv_full = A.T
Z = np.eye(A.shape[1])
Ainv_full = jnp.asarray(Ainv_full)
Z = jnp.asarray(Z)
b = jnp.asarray(b)
xp = put(xp, unfixed_idx, Ainv_full @ b)
xp = jnp.asarray(xp)

@jit
def project(x):
"""Project a full state vector into the reduced optimization vector."""
x_reduced = Z.T @ ((x - xp)[unfixed_idx])
return jnp.atleast_1d(jnp.squeeze(x_reduced))

@jit
def recover(x_reduced):
"""Recover the full state vector from the reduced optimization vector."""
dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced)
Expand Down

0 comments on commit c0ade01

Please sign in to comment.