-
-
Notifications
You must be signed in to change notification settings - Fork 134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reversible Solvers #528
base: main
Are you sure you want to change the base?
Reversible Solvers #528
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is really well done. I've left some comments but it's mostly around broader structural/testing/documentation stuff.
I've commented on your point 1 inline, and I think what you've done for point 2 looks good to me!
diffrax/_solver/reversible.py
Outdated
`adjoint=diffrax.ReversibleAdjoint()`. | ||
""" | ||
|
||
solver: AbstractRungeKutta |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are implicit RK methods handled here as well? According to this annotation they are but I don't think I see them in the tests.
What is it about RK methods that privileges them here btw? IIUC I think any single-step method should work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also what about Euler
, which isn't implemented as an AbstractRungeKutta
but which does have the correct properties?
(It's done separately to be able to use as example code for how to write a solver.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, you're right - any single step method should work. The reversible solver now works with any AbstractSolver
.
See the discussion on fsal
for more info.
diffrax/_solver/reversible.py
Outdated
def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]: | ||
return self.solver.strong_order(terms) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you expect this technique to work for SDEs? If so then do call that out explicitly in the docstring, to reassure people! :)
(In particular I'm thinking of the asychronous leapfrog method, which to our surprise did not work for SDEs...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have no theory here (does James' intuition count?), but numerically it works for SDEs! I've added SDEs to the docstring.
There's the detail that the second solver step (that steps backwards in time) should use the same Brownian increment as the first solver step. I believe this is handled by VirtualBrownianTree
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just added a check and test for UnsafeBrownianPath
in light of the above.
And thinking about this further, we require the same conditions as BacksolveAdjoint
; namely that the solver converges to the Stratonovich solution, so I've added a check and test for this.
diffrax/_solver/reversible.py
Outdated
`adjoint=diffrax.ReversibleAdjoint()`. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do go ahead and a couple of references here! (See how we've done it in the other solvers.) At the very least including both your paper, and the various earlier pieces of work. Also make sure whatever you put here works with diffrax.citation
, so that folks have an easy way to cite you :)
What happens if I use just ReversibleAdjoint
with a different solver? What happens if I use Reversible
with a different adjoint? Is this safe to use with adaptive time stepping? The docstring here needs to make clear what a user should expect to happen as this interacts with the other components of Diffrax!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, good point - added.
Removing Reversible
from public API helps with control here.
test/test_reversible.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test checking how this interacts with events? It's not immediately obvious to me that this will actually do the right thing.
Also, it would be good to see some 'negative tests' checking that the appropriate error is raised if Reversible
is used in conjunction with e.g. SemiImplicitEuler
, or any other method that isn't supported.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Events seem to work on the forward reversible solve but raise the same error as BacksolveAdjoint
on the backward solve. I've added a catch to raise an error if you try to use ReversibleAdjoint
with events.
Negative tests for incompatible solvers, events and saveats have been added.
diffrax/_adjoint.py
Outdated
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True: | ||
raise ValueError( | ||
"Can only use `adjoint=ReversibleAdjoint()` with " | ||
"`saveat=SaveAt(t1=True)`." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will probably not take long until someone asks to use this alongside SaveAt(ts=...)
!
I can see that this is probably trickier to handle because of the way we do interpolation to get outputs at ts
. Do you have any ideas for this?
(Either way, getting it working for that definitely isn't a prerequisite for merging, it's just a really solid nice-to-have.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW I imagine SaveAt(steps=True)
is probably much easier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added functionality for SaveAt(steps=True)
, but SaveAt(ts=...)
is a tricky one.
Not a solution, but some thoughts:
The ReversibleAdjoint
computes gradients accurate to the numerical operations taken, rather than an approximation to the 'idealised' continuous-time adjoint ODE. This is then tricky when the numerical operations include interpolation and not just ODE solving.
In principle, the interpolated ys
are just a function of the stepped-to ys
. We can therefore calculate gradients for the stepped-to ys
and let AD handle the rest. This would require the interpolation routine to be separate to the solve routine, but I understand the memory drawbacks of this setup.
I imagine there isn't a huge demand to decouple the solve from the interpolation - but if it turns out this is relevant for other cases I'd be happy to give it a go!
diffrax/_adjoint.py
Outdated
if not isinstance(solver, Reversible): | ||
raise ValueError( | ||
"Can only use `adjoint=ReversibleAdjoint()` with " | ||
"`Reversible()` solver." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we perhaps remove Reversible
from the public API altogether, and just have solver = Reversible(solver)
here? Make the Reversible
solver an implementation detail of the adjoint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really like this idea :D
I've removed Reversible
from the public API and any AbstractSolver
passed to ReversibleAdjoint
is auto-wrapped. There is now a _Reversible
class within the _adjoint
module that is exclusively used by ReversibleAdjoint
. Do you think this is an appropriate home for the _Reversible
class or should I keep it elsewhere?
diffrax/_solver/reversible.py
Outdated
solver_state: _SolverState, | ||
made_jump: BoolScalarLike, | ||
) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]: | ||
(first_step, f0), z0 = solver_state |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will fail for non-FSAL Runge-Kutta solvers.
(Can you add a test for one of those to be sure we get correct behaviour?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment below.
diffrax/_solver/reversible.py
Outdated
# solver to ensure the vf is correct. | ||
# Can we avoid this re-evaluation? | ||
|
||
f0 = self.func(terms, t0, z0, args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is really okay -- AbstractSolver.func
is something we try to use approximately never: it basically exists just to handle initial step size selection and steady state finding, which are both fairly heuristic and pretty far of the beaten path.
If I understand correctly, the issue is that your y1
isn't quite the value that is returned from a single step, so the FSAL property does not hold, and as such you need to reevaluate f0
? If so then I think you should be able to avoid this issue by ensuring that the RK solvers are used in non-FSAL form. This is one of the most complicated corners of the codebase, but take a look at the comment starting here:
diffrax/diffrax/_solver/runge_kutta.py
Line 532 in 0cb19e9
# |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now disable the FSAL and SSAL properties in the _Reversible
init method (if a RK solver is used).
With this we can now make any AbstractSolver
reversible and we pass around the _Reversible
solver state by (original_solver_state, z_n)
. We also never unpack the original_solver_state
, so don't need to assume any structure.
diffrax/_solver/reversible.py
Outdated
step_z0, z_error, dense_info, _, result1 = self.solver.step( | ||
terms, t0, t1, z0, args, (first_step, f0), made_jump | ||
) | ||
y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω | ||
|
||
f0 = self.func(terms, t1, y1, args) | ||
step_y1, y_error, _, _, result2 = self.solver.step( | ||
terms, t1, t0, y1, args, (first_step, f0), made_jump | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On these two evaluations of .step
-- take a look at eqx.internal.scan_trick
, which might allow you to collapse these two callsites into one. That can be used to half compilation time!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CMIIW, I'm not sure we can use the scan trick here as the function return signature is different for each solver step?
That is, we only want to update the original_solver_state
and dense_info
when taking the forward-in-time step. So we don't return these on the backward-in-time step. IIUC, collapsing the two calls into one would require the returned carry to be the same on both calls?
commit ec1ebac Author: Sam McCallum <[email protected]> Date: Wed Nov 27 08:46:55 2024 +0000 tidy up function arguments commit 7b66f46 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 18:13:11 2024 +0000 beefy tests commit e713b5d Author: Sam McCallum <[email protected]> Date: Tue Nov 26 13:29:26 2024 +0000 update references commit 9acf6e0 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 13:12:26 2024 +0000 test incorrect solver commit 861aa97 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 13:05:05 2024 +0000 catch already reversible solvers commit 4b8b4c0 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 12:37:03 2024 +0000 error estimate may be pytree commit 0b01210 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 12:36:24 2024 +0000 tests commit 5435ab2 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 11:17:09 2024 +0000 Revert "leapfrog not compatible" This reverts commit d88e732. commit d88e732 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 11:15:32 2024 +0000 leapfrog not compatible commit 6e3f2de Author: Sam McCallum <[email protected]> Date: Tue Nov 26 11:13:30 2024 +0000 pytree state commit 3fa6432 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 10:28:26 2024 +0000 docs commit 2bfe820 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 09:34:36 2024 +0000 remove reversible.py solver file commit e7856d3 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 09:33:52 2024 +0000 fix tests for relative import commit 24d1935 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 09:18:05 2024 +0000 private reversible commit 8a7448e Author: Sam McCallum <[email protected]> Date: Tue Nov 26 08:56:40 2024 +0000 remove debug print commit 0391bc1 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 08:28:41 2024 +0000 tests commit 81a9a57 Author: Sam McCallum <[email protected]> Date: Tue Nov 26 08:23:41 2024 +0000 more tests commit 89f5731 Author: Sam McCallum <[email protected]> Date: Mon Nov 25 20:52:51 2024 +0000 test implicit solvers + SDEs commit f30f47e Author: Sam McCallum <[email protected]> Date: Mon Nov 25 20:44:54 2024 +0000 remove t0, t1, solver_state tangents commit b903176 Author: Sam McCallum <[email protected]> Date: Mon Nov 25 16:56:01 2024 +0000 docs commit acaa35f Author: Sam McCallum <[email protected]> Date: Mon Nov 25 12:56:50 2024 +0000 better steps=True commit 621e6f4 Author: Sam McCallum <[email protected]> Date: Mon Nov 25 10:28:19 2024 +0000 remove ifs in grad_step loop commit 7dfb8e3 Author: Sam McCallum <[email protected]> Date: Mon Nov 25 09:15:18 2024 +0000 Disable fsal, ssal properties to allow any solver to be made reversible commit f160295 Author: Sam McCallum <[email protected]> Date: Fri Nov 22 15:09:57 2024 +0000 tests commit f327f66 Author: Sam McCallum <[email protected]> Date: Fri Nov 22 13:53:56 2024 +0000 ReversibleAdjoint compatible with SaveAt(steps=True) Reversible Solvers (v2) Changes: - `Reversible` solver is hidden from public API and automatically used with `ReversibleAdjoint` - compatible with any `AbstractSolver`, except methods that are already algebraically reversible - can now use `SaveAt(steps=True)` - works with ODEs/CDEs/SDEs - improved docs - improved tests
Thanks very much for the review and suggestions! I'll reply to individual comments inline but here is an overview:
|
Hey Patrick,
Here's an implementation of Reversible Solvers! This includes:
AbstractRungeKutta
method in diffrax algebraically reversible - seediffrax.Reversible
diffrax.ReversibleAdjoint
Main details I should highlight here:
The current implementation relies on the _SolverState type of
AbstractRungeKutta
methods. Specifically, as the reversible method switches between evaluating the vector field at y and z, we ensure thefsal
is correct by evaluating the vector field outside of the base Runge Kutta step. In principle this is unnecessary but required to fit with the behaviour ofAbstractRungeKutta
solvers; any ideas for how to avoid this?To backpropagate through the reversible solve we require knowledge of the
ts
that the solver visited. As this is not known a priori for adaptive step sizes, I've added a (teeny weeny) bit of infrastructure to the State in_integrate.py
. This allows us to save the ts that the solver stepped to which we make available toReversibleAdjoint
as a residual. The added State follows exactly the implementation of saving dense_ts and is only triggered whenadjoint=ReversibleAdjoint
.Best,
Sam