Skip to content

Commit

Permalink
fix for JAX issue #22011, adding transpose rule for stop_gradient (#100)
Browse files Browse the repository at this point in the history
* fix for JAX issue #22011, adding transpose rule for stop_gradient

* fixing pre commit issues
  • Loading branch information
dkweiss31 authored Jul 31, 2024
1 parent 1909d19 commit 2272e63
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions lineax/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jax._src.ad_util import stop_gradient_p
from jaxtyping import Array, ArrayLike, PyTree

from ._custom_types import sentinel
Expand Down Expand Up @@ -812,3 +813,12 @@ def linear_solve(
# TODO: prevent forward-mode autodiff through stats
stats = eqxi.nondifferentiable_backward(stats)
return Solution(value=solution, result=result, state=state, stats=stats)


# Work around JAX issue #22011,
# as well as https://github.com/patrick-kidger/diffrax/pull/387#issuecomment-2174488365
def stop_gradient_transpose(ct, x):
return (ct,)


ad.primitive_transposes[stop_gradient_p] = stop_gradient_transpose

0 comments on commit 2272e63

Please sign in to comment.