Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reversible Solvers #528

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Conversation

sammccallum
Copy link

Hey Patrick,

Here's an implementation of Reversible Solvers! This includes:

  1. make any AbstractRungeKutta method in diffrax algebraically reversible - see diffrax.Reversible
  2. backpropagate through the solve in constant memory and get exact gradients (up to floating point errors) - see diffrax.ReversibleAdjoint

Main details I should highlight here:

  1. The current implementation relies on the _SolverState type of AbstractRungeKutta methods. Specifically, as the reversible method switches between evaluating the vector field at y and z, we ensure the fsal is correct by evaluating the vector field outside of the base Runge Kutta step. In principle this is unnecessary but required to fit with the behaviour of AbstractRungeKutta solvers; any ideas for how to avoid this?

  2. To backpropagate through the reversible solve we require knowledge of the ts that the solver visited. As this is not known a priori for adaptive step sizes, I've added a (teeny weeny) bit of infrastructure to the State in _integrate.py. This allows us to save the ts that the solver stepped to which we make available to ReversibleAdjoint as a residual. The added State follows exactly the implementation of saving dense_ts and is only triggered when adjoint=ReversibleAdjoint.

Best,
Sam

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really well done. I've left some comments but it's mostly around broader structural/testing/documentation stuff.

I've commented on your point 1 inline, and I think what you've done for point 2 looks good to me!

`adjoint=diffrax.ReversibleAdjoint()`.
"""

solver: AbstractRungeKutta
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are implicit RK methods handled here as well? According to this annotation they are but I don't think I see them in the tests.

What is it about RK methods that privileges them here btw? IIUC I think any single-step method should work?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also what about Euler, which isn't implemented as an AbstractRungeKutta but which does have the correct properties?

(It's done separately to be able to use as example code for how to write a solver.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, you're right - any single step method should work. The reversible solver now works with any AbstractSolver.

See the discussion on fsal for more info.

Comment on lines 63 to 64
def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]:
return self.solver.strong_order(terms)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you expect this technique to work for SDEs? If so then do call that out explicitly in the docstring, to reassure people! :)

(In particular I'm thinking of the asychronous leapfrog method, which to our surprise did not work for SDEs...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have no theory here (does James' intuition count?), but numerically it works for SDEs! I've added SDEs to the docstring.

There's the detail that the second solver step (that steps backwards in time) should use the same Brownian increment as the first solver step. I believe this is handled by VirtualBrownianTree.

Copy link
Author

@sammccallum sammccallum Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just added a check and test for UnsafeBrownianPath in light of the above.

And thinking about this further, we require the same conditions as BacksolveAdjoint; namely that the solver converges to the Stratonovich solution, so I've added a check and test for this.

Comment on lines 34 to 35
`adjoint=diffrax.ReversibleAdjoint()`.
"""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do go ahead and a couple of references here! (See how we've done it in the other solvers.) At the very least including both your paper, and the various earlier pieces of work. Also make sure whatever you put here works with diffrax.citation, so that folks have an easy way to cite you :)

What happens if I use just ReversibleAdjoint with a different solver? What happens if I use Reversible with a different adjoint? Is this safe to use with adaptive time stepping? The docstring here needs to make clear what a user should expect to happen as this interacts with the other components of Diffrax!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, good point - added.

Removing Reversible from public API helps with control here.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test checking how this interacts with events? It's not immediately obvious to me that this will actually do the right thing.

Also, it would be good to see some 'negative tests' checking that the appropriate error is raised if Reversible is used in conjunction with e.g. SemiImplicitEuler, or any other method that isn't supported.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Events seem to work on the forward reversible solve but raise the same error as BacksolveAdjoint on the backward solve. I've added a catch to raise an error if you try to use ReversibleAdjoint with events.

Negative tests for incompatible solvers, events and saveats have been added.

Comment on lines 1034 to 1038
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`saveat=SaveAt(t1=True)`."
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will probably not take long until someone asks to use this alongside SaveAt(ts=...)!

I can see that this is probably trickier to handle because of the way we do interpolation to get outputs at ts. Do you have any ideas for this?

(Either way, getting it working for that definitely isn't a prerequisite for merging, it's just a really solid nice-to-have.)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I imagine SaveAt(steps=True) is probably much easier.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added functionality for SaveAt(steps=True), but SaveAt(ts=...) is a tricky one.

Not a solution, but some thoughts:

The ReversibleAdjoint computes gradients accurate to the numerical operations taken, rather than an approximation to the 'idealised' continuous-time adjoint ODE. This is then tricky when the numerical operations include interpolation and not just ODE solving.

In principle, the interpolated ys are just a function of the stepped-to ys. We can therefore calculate gradients for the stepped-to ys and let AD handle the rest. This would require the interpolation routine to be separate to the solve routine, but I understand the memory drawbacks of this setup.

I imagine there isn't a huge demand to decouple the solve from the interpolation - but if it turns out this is relevant for other cases I'd be happy to give it a go!

Comment on lines 1040 to 1044
if not isinstance(solver, Reversible):
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`Reversible()` solver."
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we perhaps remove Reversible from the public API altogether, and just have solver = Reversible(solver) here? Make the Reversible solver an implementation detail of the adjoint.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like this idea :D

I've removed Reversible from the public API and any AbstractSolver passed to ReversibleAdjoint is auto-wrapped. There is now a _Reversible class within the _adjoint module that is exclusively used by ReversibleAdjoint. Do you think this is an appropriate home for the _Reversible class or should I keep it elsewhere?

solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
(first_step, f0), z0 = solver_state
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fail for non-FSAL Runge-Kutta solvers.
(Can you add a test for one of those to be sure we get correct behaviour?)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment below.

# solver to ensure the vf is correct.
# Can we avoid this re-evaluation?

f0 = self.func(terms, t0, z0, args)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is really okay -- AbstractSolver.func is something we try to use approximately never: it basically exists just to handle initial step size selection and steady state finding, which are both fairly heuristic and pretty far of the beaten path.

If I understand correctly, the issue is that your y1 isn't quite the value that is returned from a single step, so the FSAL property does not hold, and as such you need to reevaluate f0? If so then I think you should be able to avoid this issue by ensuring that the RK solvers are used in non-FSAL form. This is one of the most complicated corners of the codebase, but take a look at the comment starting here:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now disable the FSAL and SSAL properties in the _Reversible init method (if a RK solver is used).

With this we can now make any AbstractSolver reversible and we pass around the _Reversible solver state by (original_solver_state, z_n). We also never unpack the original_solver_state, so don't need to assume any structure.

Comment on lines 95 to 103
step_z0, z_error, dense_info, _, result1 = self.solver.step(
terms, t0, t1, z0, args, (first_step, f0), made_jump
)
y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω

f0 = self.func(terms, t1, y1, args)
step_y1, y_error, _, _, result2 = self.solver.step(
terms, t1, t0, y1, args, (first_step, f0), made_jump
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On these two evaluations of .step -- take a look at eqx.internal.scan_trick, which might allow you to collapse these two callsites into one. That can be used to half compilation time!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CMIIW, I'm not sure we can use the scan trick here as the function return signature is different for each solver step?

That is, we only want to update the original_solver_state and dense_info when taking the forward-in-time step. So we don't return these on the backward-in-time step. IIUC, collapsing the two calls into one would require the returned carry to be the same on both calls?

commit ec1ebac
Author: Sam McCallum <[email protected]>
Date:   Wed Nov 27 08:46:55 2024 +0000

    tidy up function arguments

commit 7b66f46
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 18:13:11 2024 +0000

    beefy tests

commit e713b5d
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 13:29:26 2024 +0000

    update references

commit 9acf6e0
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 13:12:26 2024 +0000

    test incorrect solver

commit 861aa97
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 13:05:05 2024 +0000

    catch already reversible solvers

commit 4b8b4c0
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 12:37:03 2024 +0000

    error estimate may be pytree

commit 0b01210
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 12:36:24 2024 +0000

    tests

commit 5435ab2
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 11:17:09 2024 +0000

    Revert "leapfrog not compatible"

    This reverts commit d88e732.

commit d88e732
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 11:15:32 2024 +0000

    leapfrog not compatible

commit 6e3f2de
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 11:13:30 2024 +0000

    pytree state

commit 3fa6432
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 10:28:26 2024 +0000

    docs

commit 2bfe820
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 09:34:36 2024 +0000

    remove reversible.py solver file

commit e7856d3
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 09:33:52 2024 +0000

    fix tests for relative import

commit 24d1935
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 09:18:05 2024 +0000

    private reversible

commit 8a7448e
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 08:56:40 2024 +0000

    remove debug print

commit 0391bc1
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 08:28:41 2024 +0000

    tests

commit 81a9a57
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 08:23:41 2024 +0000

    more tests

commit 89f5731
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 20:52:51 2024 +0000

    test implicit solvers + SDEs

commit f30f47e
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 20:44:54 2024 +0000

    remove t0, t1, solver_state tangents

commit b903176
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 16:56:01 2024 +0000

    docs

commit acaa35f
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 12:56:50 2024 +0000

    better steps=True

commit 621e6f4
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 10:28:19 2024 +0000

    remove ifs in grad_step loop

commit 7dfb8e3
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 09:15:18 2024 +0000

    Disable fsal, ssal properties to allow any solver to be made reversible

commit f160295
Author: Sam McCallum <[email protected]>
Date:   Fri Nov 22 15:09:57 2024 +0000

    tests

commit f327f66
Author: Sam McCallum <[email protected]>
Date:   Fri Nov 22 13:53:56 2024 +0000

    ReversibleAdjoint compatible with SaveAt(steps=True)

Reversible Solvers (v2)

Changes:
- `Reversible` solver is hidden from public API and automatically used with `ReversibleAdjoint`
- compatible with any `AbstractSolver`, except methods that are already algebraically reversible
- can now use `SaveAt(steps=True)`
- works with ODEs/CDEs/SDEs
- improved docs
- improved tests
@sammccallum
Copy link
Author

Thanks very much for the review and suggestions!

I'll reply to individual comments inline but here is an overview:

  1. I really like the idea for removing Reversible from the public API and just making the adjoint auto-wrap the original solver. It feels very JAX-like, as we just augment the forward-trace so that we can backpropagate with improved properties. This has been added.

  2. Can now use any AbstractSolver, apart from solvers that are already algebraically reversible (raises a ValueError if passed).

  3. Add functionality for SaveAt(steps=True).

  4. Improved docs so that people know what they can use ReversibleAdjoint with (any solver, ODEs/CDEs/SDEs, adaptive time steps).

  5. Improved tests

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants