I am using a JAX-based framework to do finite element analysis.
Currently I use solvers from two sources: either jax.scipy.sparse.linalg.bicgstab or PETSc.
I like that lineax already offers many solvers and improves on what JAX provides, so I will start substituting the JAX solvers with lineax solvers. However, I want to keep the PETSc solver option as well, since PETSc has many more solvers and preconditioners.
Would it be possible to wrap PETSc with lineax, so that users can access those solvers while staying integrated in a JAX workflow (no tracing errors or issues with higher-order autodiff)?
Before coming across lineax, I was thinking of using pure_callback + custom_jvp to do this. But that approach is painful, especially when sparse matrices are involved.
Yup, that should be possible. Subclass lineax.AbstractLinearSolver. Take a look at the implementation of any existing solver for an example, along with the docstring of each method of AbstractLinearSolver for what needs implementing.
In particular, you should be able to call out to PETSc in the .compute method using a jax.pure_callback. Lineax will handle autodiff etc. automatically.
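For reference, a minimal sketch of the pure_callback pattern that such a `.compute` method could use. This is an assumption-laden illustration, not lineax's or PETSc's actual code: `scipy.sparse.linalg.gmres` stands in for a PETSc KSP solve, since the callback body only needs to map concrete NumPy arrays to a NumPy result, and the `external_solve`/`solve` names are made up for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg


def external_solve(A, b):
    # Runs outside JAX tracing, on concrete NumPy arrays.
    # A PETSc KSP solve would go here instead; scipy's GMRES is a stand-in.
    x, _info = scipy.sparse.linalg.gmres(np.asarray(A), np.asarray(b), atol=1e-8)
    # pure_callback checks dtypes strictly, so match the promised dtype.
    return x.astype(b.dtype)


def solve(A, b):
    # pure_callback needs the output shape/dtype declared up front.
    out_shape = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(external_solve, out_shape, A, b)


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = jax.jit(solve)(A, b)  # works under jit: the callback is opaque to tracing
```

Wired into the `.compute` of an AbstractLinearSolver subclass, lineax supplies the autodiff rules for the linear solve itself, so the callback never needs its own custom_jvp.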