Reversible Solvers #528
base: main
Changes from 1 commit
@@ -16,7 +16,12 @@
 from ._heuristics import is_sde, is_unsafe_sde
 from ._saveat import save_y, SaveAt, SubSaveAt
-from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
+from ._solver import (
+    AbstractItoSolver,
+    AbstractRungeKutta,
+    AbstractStratonovichSolver,
+    Reversible,
+)
 from ._term import AbstractTerm, AdjointTerm
@@ -852,3 +857,206 @@ def loop(
        )
        final_state = _only_transpose_ys(final_state)
        return final_state, aux_stats


# Reversible Adjoint custom vjp computes gradients w.r.t.
# - y, corresponding to the initial state;
# - args, corresponding to explicit parameters;
# - terms, corresponding to implicit parameters as part of the vector field.
@eqx.filter_custom_vjp
def _loop_reversible(y__args__terms, *, self, throw, init_state, **kwargs):
    del throw
    y, args, terms = y__args__terms
    init_state = eqx.tree_at(lambda s: s.y, init_state, y)
    del y
    return self._loop(
        args=args,
        terms=terms,
        init_state=init_state,
        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
        **kwargs,
    )
@_loop_reversible.def_fwd
def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
    del perturbed
    final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
    ts = final_state.reversible_ts
    ts_final_index = final_state.reversible_save_index
    y1 = final_state.save_state.ys[-1]
    solver_state1 = final_state.solver_state
    return (final_state, aux_stats), (ts, ts_final_index, y1, solver_state1)
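# Note: the residuals saved for the backward pass are just the accepted step
# times `ts` and the final `(y1, solver_state)`; the trajectory itself is
# reconstructed in reverse inside `_loop_reversible_bwd` below, which is what
# gives this adjoint its low memory footprint.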
@_loop_reversible.def_bwd
def _loop_reversible_bwd(
    residuals,
    grad_final_state__aux_stats,
    perturbed,
    y__args__terms,
    *,
    self,
    solver,
    event,
    t0,
    t1,
    dt0,
    init_state,
    progress_meter,
    **kwargs,
):
    assert event is None

    del perturbed, init_state, t1, progress_meter, self, kwargs
    ts, ts_final_index, y1, solver_state1 = residuals
    original_solver_state, z1 = solver_state1
    del residuals, solver_state1

    grad_final_state, _ = grad_final_state__aux_stats
    # ReversibleAdjoint currently only allows SaveAt(t1=True), so grad_y1 should
    # have the same structure as y1.
    grad_y1 = grad_final_state.save_state.ys[-1]
    grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1)
    del grad_final_state, grad_final_state__aux_stats

    y, args, terms = y__args__terms
    del y__args__terms

    diff_args = eqx.filter(args, eqx.is_inexact_array)
    diff_terms = eqx.filter(terms, eqx.is_inexact_array)
    diff_z1 = eqx.filter(z1, eqx.is_inexact_array)
    grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
    grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
    grad_z1 = jtu.tree_map(jnp.zeros_like, diff_z1)
    del diff_args, diff_terms, diff_z1

    def grad_step(state):
        def solver_step(terms, t0, t1, y1, args):
            step, _, _, _, _ = solver.solver.step(
                terms, t0, t1, y1, args, (first_step, f0), False
            )
            return step

        ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state
        (first_step, f0), z1 = solver_state

        t1 = ts[ts_index]
        t0 = ts[ts_index - 1]
        ts_index = ts_index - 1

        # TODO: The solver steps switch between evaluating from z0
        # and y1. Therefore, we re-evaluate f0 outside of the base
        # solver to ensure the vf is correct.
        # Can we avoid this re-evaluation?

        f0 = solver.func(terms, t1, y1, args)
        step_y1, vjp_fun_y1 = eqx.filter_vjp(solver_step, terms, t1, t0, y1, args)
        z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω

        f0 = solver.func(terms, t0, z0, args)
        step_z0, vjp_fun_z0 = eqx.filter_vjp(solver_step, terms, t0, t1, z0, args)

        y0 = ((1 / solver.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω

        grad_step_y1 = vjp_fun_y1(grad_z1)
        grad_y1 = (ω(grad_y1) + ω(grad_z1) - ω(grad_step_y1[3])).ω

        grad_step_z0 = vjp_fun_z0(grad_y1)
        grad_y0 = (solver.l * ω(grad_y1)).ω
        grad_z0 = (ω(grad_z1) - solver.l * ω(grad_y1) + ω(grad_step_z0[3])).ω

        grad_terms = (ω(grad_terms) - ω(grad_step_y1[0]) + ω(grad_step_z0[0])).ω
        grad_args = (ω(grad_args) - ω(grad_step_y1[4]) + ω(grad_step_z0[4])).ω

        return (
            ts_index,
            y0,
            ((first_step, f0), z0),
            grad_y0,
            grad_z0,
            grad_args,
            grad_terms,
        )

    def cond_fun(state):
        ts_index = state[0]
        return ts_index > 0

    state = (
        ts_final_index,
        y1,
        (original_solver_state, z1),
        grad_y1,
        grad_z1,
        grad_args,
        grad_terms,
    )

    state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax")
    _, _, _, grad_y0, grad_z0, grad_args, grad_terms = state
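    # The sum below assumes (consistent with the forward scheme sketched above)
    # that the forward pass initialises the auxiliary state as z0 = y0, so the
    # cotangents of both y0 and z0 accumulate into the gradient of the initial
    # condition.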
    return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms
class ReversibleAdjoint(AbstractAdjoint):
    """
    Backpropagate through [`diffrax.diffeqsolve`][] when using the
    [`diffrax.Reversible`][] solver.

    This method offers very low memory usage and exact gradient calculation (up to
    floating point errors).

    This will compute gradients with respect to the `terms`, `y0` and `args` arguments
    passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with
    respect to anything else (for example `t0`, or arguments passed via closure), then
    a `CustomVJPException` will be raised. See also
    [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
    entry.
    """
    def loop(
        self,
        *,
        args,
        terms,
        solver,
        saveat,
        init_state,
        passed_solver_state,
        passed_controller_state,
        event,
        **kwargs,
    ):
        # `is` check because this may return a Tracer from SaveAt(ts=<array>)
        if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
            raise ValueError(
                "Can only use `adjoint=ReversibleAdjoint()` with "
                "`saveat=SaveAt(t1=True)`."
            )

        if not isinstance(solver, Reversible):
            raise ValueError(
                "Can only use `adjoint=ReversibleAdjoint()` with the "
                "`Reversible()` solver."
            )
Reviewer: Could we perhaps remove …

Author: I really like this idea :D I've removed …
        y = init_state.y
        init_state = eqx.tree_at(lambda s: s.y, init_state, object())
        init_state = _nondiff_solver_controller_state(
            self, init_state, passed_solver_state, passed_controller_state
        )

        final_state, aux_stats = _loop_reversible(
            (y, args, terms),
            self=self,
            saveat=saveat,
            init_state=init_state,
            solver=solver,
            event=event,
            **kwargs,
        )
        final_state = _only_transpose_ys(final_state)
        return final_state, aux_stats
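A usage sketch (not part of this diff; the `diffrax.Reversible(diffrax.Tsit5())`
construction is an assumption based on the `solver.solver` and `solver.l`
attributes used above):

    import diffrax
    import jax
    import jax.numpy as jnp

    def loss(y0):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(lambda t, y, args: -y),  # dy/dt = -y
            solver=diffrax.Reversible(diffrax.Tsit5()),  # assumed constructor
            t0=0.0,
            t1=1.0,
            dt0=0.01,
            y0=y0,
            adjoint=diffrax.ReversibleAdjoint(),
            saveat=diffrax.SaveAt(t1=True),  # required by ReversibleAdjoint
        )
        return jnp.sum(sol.ys[-1] ** 2)

    # Gradients w.r.t. the initial condition flow through the reversible solve.
    grads = jax.grad(loss)(jnp.array([1.0, 2.0]))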
Reviewer: It will probably not take long until someone asks to use this alongside
SaveAt(ts=...)! I can see that this is probably trickier to handle because of the
way we do interpolation to get outputs at ts. Do you have any ideas for this?
(Either way, getting it working for that definitely isn't a prerequisite for
merging, it's just a really solid nice-to-have.)
Reviewer: FWIW I imagine SaveAt(steps=True) is probably much easier.
Author: I've added functionality for SaveAt(steps=True), but SaveAt(ts=...) is a
tricky one. Not a solution, but some thoughts:

The ReversibleAdjoint computes gradients accurate to the numerical operations
taken, rather than an approximation to the 'idealised' continuous-time adjoint
ODE. This is then tricky when the numerical operations include interpolation and
not just ODE solving.

In principle, the interpolated ys are just a function of the stepped-to ys. We
can therefore calculate gradients for the stepped-to ys and let AD handle the
rest. This would require the interpolation routine to be separate from the solve
routine, but I understand the memory drawbacks of this setup.

I imagine there isn't a huge demand to decouple the solve from the interpolation,
but if it turns out this is relevant for other cases I'd be happy to give it a go!