Linear solvers for sparse matrices #24
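Right! So this is doable like so:

```python
import equinox as eqx
import lineax as lx
import jax.experimental.sparse as js
import jax.numpy as jnp


class SparseMatrixLinearOperator(lx.MatrixLinearOperator):
    # Route the matrix-vector product through jax.experimental.sparse,
    # so that `self.matrix` can be a BCOO sparse matrix.
    def mv(self, vector):
        return js.sparsify(lambda m, v: m @ v)(self.matrix, vector)


x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
x = js.BCOO.fromdense(x)
op = SparseMatrixLinearOperator(x)
vec = jnp.array([1., 2.])


@eqx.filter_jit
def f(op, vec):
    # GMRES is iterative: it only accesses the operator through `mv`.
    sol = lx.linear_solve(op, vec, solver=lx.GMRES(rtol=1e-5, atol=1e-5))
    return sol


print(f(op, vec).value)
```

That's basically just (a) overriding the `mv` method so the matrix-vector product goes through `jax.experimental.sparse`, and (b) using an iterative solver such as `lx.GMRES`, which only needs matrix-vector products.

This isn't built into Lineax by default as JAX's sparse support is still pretty experimental, and we want to see how that's going to pan out before we design an API around it.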
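For comparison, here is a minimal sketch of the same idea using only core JAX rather than Lineax, assuming `jax.scipy.sparse.linalg.gmres`, which accepts the operator as a function. The matrix is assembled directly in BCOO format from (row, column, value) triplets, so it is never stored densely, and the solver only ever calls the matvec. The 5×5 nonsymmetric tridiagonal matrix is an arbitrary illustrative choice; it is diagonally dominant, so it is nonsingular.

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as js
from jax.scipy.sparse.linalg import gmres

n = 5
# Triplets for a tridiagonal matrix: 4 on the diagonal,
# -1 on the superdiagonal, -2 on the subdiagonal.
rows = jnp.concatenate([jnp.arange(n), jnp.arange(n - 1), jnp.arange(1, n)])
cols = jnp.concatenate([jnp.arange(n), jnp.arange(1, n), jnp.arange(n - 1)])
data = jnp.concatenate(
    [jnp.full(n, 4.0), jnp.full(n - 1, -1.0), jnp.full(n - 1, -2.0)]
)
indices = jnp.stack([rows, cols], axis=1)  # shape (nnz, 2)
A = js.BCOO((data, indices), shape=(n, n))  # only the nonzeros are stored

b = jnp.arange(1.0, n + 1.0)


@jax.jit
def solve(b):
    # `A @ v` with a BCOO matrix and a dense vector is a sparse matvec;
    # gmres never needs anything else from A.
    x, _ = gmres(lambda v: A @ v, b, tol=1e-6)
    return x


x = solve(b)
```

If $A$ is symmetric positive definite, `jax.scipy.sparse.linalg.cg` can be used in the same way.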
Hello @patrick-kidger. I'm a new user of Lineax, and I'd like to ask: does Lineax now support automatic differentiation (AD) with sparse matrices?
Only to whatever extent JAX itself does.
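As an illustration of what "whatever extent JAX itself does" can cover, here is a minimal sketch (core JAX only, not a Lineax API) of reverse-mode differentiation of a sparse solve with respect to the dense right-hand side $b$, using `jax.scipy.sparse.linalg.cg` with a BCOO matvec. The loss $L(b) = \lVert A^{-1} b \rVert^2$ is an arbitrary example; its gradient is $2A^{-\top}A^{-1}b$, which gives a check. Differentiating with respect to the nonzero entries of $A$ is not shown, since that depends on JAX's own sparse AD support.

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as js
from jax.scipy.sparse.linalg import cg

# Small symmetric positive-definite matrix (CG requires SPD),
# converted to BCOO sparse storage.
dense = jnp.array([[4.0, -1.0, 0.0],
                   [-1.0, 4.0, -1.0],
                   [0.0, -1.0, 4.0]])
A = js.BCOO.fromdense(dense)


def loss(b):
    # x solves A x = b using only sparse matvecs; cg is differentiable
    # with respect to b.
    x, _ = cg(lambda v: A @ v, b, tol=1e-6)
    return jnp.sum(x ** 2)


b = jnp.array([1.0, 2.0, 3.0])
g = jax.grad(loss)(b)  # should match 2 * A^{-T} A^{-1} b
```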
Thank you very much for your reply!
This is just a question: are there any plans to support JAX-compatible linear solvers for sparse matrices? I'm thinking of linear systems of the form $Ax = b$, where $A$ is sparse.