Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward mode "adjoint" #537

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

johannahaffner
Copy link

@johannahaffner johannahaffner commented Dec 9, 2024

Here you go! This is the pragmatic solution, without support or test coverage for integer inputs and only a small comment explicating that forward mode is not really an adjoint, even though its diffrax interface is that of an adjoint.

Changes with respect to the last PR:

  • renamed to ForwardMode everywhere
  • add a sentence to the docstring that explains that this is not really an adjoint, but keep inheriting from AbstractAdjoint
  • remove stub for "forward gradient with int" from test_adjoint.py and explain that since JAX does not offer this option, we're not writing our own workaround to test it either

On the last point: if I understood this correctly, then supporting this would entail writing a gradient-computation directly from a JVP with custom "unit pytrees". This is somewhat annoying for mixed array and non-array types.
I'm happy to try again if computing gradients with respect to integer elements of a PyTree is an expected use case (maybe arising from composed/layered transformations of a solve) that requires test coverage.

Earlier comments here.

(This is now rebased on main.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant