Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reversible Solvers #528

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DirectAdjoint as DirectAdjoint,
ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
ReversibleAdjoint as ReversibleAdjoint,
)
from ._autocitation import citation as citation, citation_rules as citation_rules
from ._brownian import (
Expand Down Expand Up @@ -101,6 +102,7 @@
Midpoint as Midpoint,
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
Reversible as Reversible,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
Expand Down
210 changes: 209 additions & 1 deletion diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

from ._heuristics import is_sde, is_unsafe_sde
from ._saveat import save_y, SaveAt, SubSaveAt
from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
from ._solver import (
AbstractItoSolver,
AbstractRungeKutta,
AbstractStratonovichSolver,
Reversible,
)
from ._term import AbstractTerm, AdjointTerm


Expand Down Expand Up @@ -852,3 +857,206 @@ def loop(
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats


# Reversible Adjoint custom vjp computes gradients w.r.t.
# - y, corresponding to the initial state;
# - args, corresponding to explicit parameters;
# - terms, corresponding to implicit parameters as part of the vector field.


@eqx.filter_custom_vjp
def _loop_reversible(y__args__terms, *, self, throw, init_state, **kwargs):
    """Run the forward solve for `ReversibleAdjoint`, with a custom VJP attached.

    The differentiable inputs are bundled into the single positional argument
    `y__args__terms = (y, args, terms)` so that `eqx.filter_custom_vjp` computes
    gradients with respect to exactly those three pytrees; everything passed by
    keyword is treated as non-differentiable.
    """
    # `throw` is accepted (so the call signature matches the other `_loop_*`
    # variants) but unused here.
    del throw
    y, args, terms = y__args__terms
    # Re-insert the (differentiable) initial state `y` into `init_state`; the
    # caller stripped it out so it could be routed through the custom-VJP input.
    init_state = eqx.tree_at(lambda s: s.y, init_state, y)
    del y
    # Both while-loops use `kind="lax"`: no checkpointing is needed because the
    # backward pass reconstructs states by reversing the solver steps.
    return self._loop(
        args=args,
        terms=terms,
        init_state=init_state,
        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
        **kwargs,
    )


@_loop_reversible.def_fwd
def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
    """Forward rule: run the solve and stash the residuals the bwd pass needs.

    Residuals are `(ts, ts_final_index, y1, solver_state1)`:
    - `ts`: the buffer of accepted step times (`final_state.reversible_ts`);
    - `ts_final_index`: index of the last filled entry in `ts`;
    - `y1`: the final saved output (SaveAt(t1=True), so `ys[-1]`);
    - `solver_state1`: the final solver state — for the reversible solver this
      unpacks as `(original_solver_state, z1)` in the bwd pass.

    NOTE: this ordering is relied upon positionally by `_loop_reversible_bwd`.
    """
    del perturbed
    final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
    ts = final_state.reversible_ts
    ts_final_index = final_state.reversible_save_index
    y1 = final_state.save_state.ys[-1]
    solver_state1 = final_state.solver_state
    return (final_state, aux_stats), (ts, ts_final_index, y1, solver_state1)


@_loop_reversible.def_bwd
def _loop_reversible_bwd(
    residuals,
    grad_final_state__aux_stats,
    perturbed,
    y__args__terms,
    *,
    self,
    solver,
    event,
    t0,
    t1,
    dt0,
    init_state,
    progress_meter,
    **kwargs,
):
    """Backward rule for `ReversibleAdjoint`.

    Walks the accepted steps in reverse, algebraically reconstructing the
    forward states `(y, z)` from the final state instead of storing them —
    this is what gives the adjoint its low memory usage. At each reversed
    step it accumulates cotangents for `y0`, `args` and `terms` via VJPs of
    the base solver's step function.

    Returns `(grad_y0, grad_args, grad_terms)`, matching the structure of
    the custom-VJP input `y__args__terms`.
    """
    # Events are not supported by ReversibleAdjoint.
    assert event is None

    del perturbed, init_state, t1, progress_meter, self, kwargs
    # Residuals as packed by `_loop_reversible_fwd` (same order).
    ts, ts_final_index, y1, solver_state1 = residuals
    # The reversible solver's state is a pair: the wrapped solver's own state
    # and the auxiliary variable `z` of the reversible (McLachlan-style)
    # two-stage scheme.  (Structure set by the `Reversible` solver — confirm
    # against its `init`/`step`.)
    original_solver_state, z1 = solver_state1
    del residuals, solver_state1

    grad_final_state, _ = grad_final_state__aux_stats
    # ReversibleAdjoint currently only allows SaveAt(t1=True) so grad_y1 should have
    # the same structure as y1.
    grad_y1 = grad_final_state.save_state.ys[-1]
    # Replace symbolic-zero (None) cotangents with concrete zeros shaped like y1.
    grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1)
    del grad_final_state, grad_final_state__aux_stats

    # `y` (the initial state) itself is not needed here: it is reconstructed by
    # running the solve backwards.
    y, args, terms = y__args__terms
    del y__args__terms

    # Zero-initialise accumulators over the differentiable (inexact-array)
    # leaves of args / terms / z1.
    diff_args = eqx.filter(args, eqx.is_inexact_array)
    diff_terms = eqx.filter(terms, eqx.is_inexact_array)
    diff_z1 = eqx.filter(z1, eqx.is_inexact_array)
    grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
    grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
    grad_z1 = jtu.tree_map(jnp.zeros_like, diff_z1)
    del diff_args, diff_terms, diff_z1

    def grad_step(state):
        # One reversed step: reconstruct (y0, z0) from (y1, z1) and push
        # cotangents backwards through the two sub-steps of the reversible
        # scheme.
        def solver_step(terms, t0, t1, y1, args):
            # Single step of the *wrapped* (base) solver.  `first_step`/`f0`
            # are closed over from the enclosing `grad_step` iteration so that
            # the FSAL-style solver state is the freshly re-evaluated one.
            step, _, _, _, _ = solver.solver.step(
                terms, t0, t1, y1, args, (first_step, f0), False
            )
            return step

        ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state
        (first_step, f0), z1 = solver_state

        # Step being reversed runs forwards from ts[ts_index - 1] to
        # ts[ts_index]; we traverse it backwards.
        t1 = ts[ts_index]
        t0 = ts[ts_index - 1]
        ts_index = ts_index - 1

        # TODO The solver steps switch between evaluating from z0
        # and y1. Therefore, we re-evaluate f0 outside of the base
        # solver to ensure the vf is correct.
        # Can we avoid this re-evaluation?

        # Reconstruct z0 by inverting the forward update
        # z1 = z0 + y1 - step(t1 -> t0 from y1)  (up to floating point error).
        # NOTE(review): exact inverse formulas should be checked against the
        # `Reversible` solver's forward `step`.
        f0 = solver.func(terms, t1, y1, args)
        step_y1, vjp_fun_y1 = eqx.filter_vjp(solver_step, terms, t1, t0, y1, args)
        z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω

        # Reconstruct y0 by inverting the forward coupling with weight
        # `solver.l`.
        f0 = solver.func(terms, t0, z0, args)
        step_z0, vjp_fun_z0 = eqx.filter_vjp(solver_step, terms, t0, t1, z0, args)

        y0 = ((1 / solver.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω

        # VJP cotangent tuples are ordered like `solver_step`'s positional
        # arguments: [0]=terms, [1]=t0, [2]=t1, [3]=y, [4]=args.
        grad_step_y1 = vjp_fun_y1(grad_z1)
        grad_y1 = (ω(grad_y1) + ω(grad_z1) - ω(grad_step_y1[3])).ω

        grad_step_z0 = vjp_fun_z0(grad_y1)
        grad_y0 = (solver.l * ω(grad_y1)).ω
        grad_z0 = (ω(grad_z1) - solver.l * ω(grad_y1) + ω(grad_step_z0[3])).ω

        # Accumulate parameter cotangents from both sub-step VJPs.
        grad_terms = (ω(grad_terms) - ω(grad_step_y1[0]) + ω(grad_step_z0[0])).ω
        grad_args = (ω(grad_args) - ω(grad_step_y1[4]) + ω(grad_step_z0[4])).ω

        return (
            ts_index,
            y0,
            ((first_step, f0), z0),
            grad_y0,
            grad_z0,
            grad_args,
            grad_terms,
        )

    def cond_fun(state):
        # Iterate until we have walked back to the initial time ts[0].
        ts_index = state[0]
        return ts_index > 0

    state = (
        ts_final_index,
        y1,
        (original_solver_state, z1),
        grad_y1,
        grad_z1,
        grad_args,
        grad_terms,
    )

    state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax")
    _, _, _, grad_y0, grad_z0, grad_args, grad_terms = state
    # Total sensitivity of the loss to the initial condition is the sum of the
    # y- and z-channel cotangents (z0 == y0 at initialisation).
    return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms


class ReversibleAdjoint(AbstractAdjoint):
"""
Backpropagate through [`diffrax.diffeqsolve`][] when using the
[`diffrax.Reversible`][] solver.

This method implies very low memory usage and exact gradient calculation (up to
floating point errors).

This will compute gradients with respect to the `terms`, `y0` and `args` arguments
passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with
respect to anything else (for example `t0`, or arguments passed via closure), then
a `CustomVJPException` will be raised. See also
[this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
entry.
"""

def loop(
self,
*,
args,
terms,
solver,
saveat,
init_state,
passed_solver_state,
passed_controller_state,
event,
**kwargs,
):
# `is` check because this may return a Tracer from SaveAt(ts=<array>)
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`saveat=SaveAt(t1=True)`."
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will probably not take long until someone asks to use this alongside SaveAt(ts=...)!

I can see that this is probably trickier to handle because of the way we do interpolation to get outputs at ts. Do you have any ideas for this?

(Either way, getting it working for that definitely isn't a prerequisite for merging, it's just a really solid nice-to-have.)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I imagine SaveAt(steps=True) is probably much easier.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added functionality for SaveAt(steps=True), but SaveAt(ts=...) is a tricky one.

Not a solution, but some thoughts:

The ReversibleAdjoint computes gradients accurate to the numerical operations taken, rather than an approximation to the 'idealised' continuous-time adjoint ODE. This is then tricky when the numerical operations include interpolation and not just ODE solving.

In principle, the interpolated ys are just a function of the stepped-to ys. We can therefore calculate gradients for the stepped-to ys and let AD handle the rest. This would require the interpolation routine to be separate to the solve routine, but I understand the memory drawbacks of this setup.

I imagine there isn't a huge demand to decouple the solve from the interpolation - but if it turns out this is relevant for other cases I'd be happy to give it a go!


if not isinstance(solver, Reversible):
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`Reversible()` solver."
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we perhaps remove Reversible from the public API altogether, and just have solver = Reversible(solver) here? Make the Reversible solver an implementation detail of the adjoint.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like this idea :D

I've removed Reversible from the public API and any AbstractSolver passed to ReversibleAdjoint is auto-wrapped. There is now a _Reversible class within the _adjoint module that is exclusively used by ReversibleAdjoint. Do you think this is an appropriate home for the _Reversible class or should I keep it elsewhere?


y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
init_state = _nondiff_solver_controller_state(
self, init_state, passed_solver_state, passed_controller_state
)

final_state, aux_stats = _loop_reversible(
(y, args, terms),
self=self,
saveat=saveat,
init_state=init_state,
solver=solver,
event=event,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats
37 changes: 36 additions & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import optimistix as optx
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint, ReversibleAdjoint
from ._custom_types import (
BoolScalarLike,
BufferDenseInfos,
Expand Down Expand Up @@ -110,6 +110,11 @@ class State(eqx.Module):
event_dense_info: Optional[DenseInfo]
event_values: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]]
event_mask: Optional[PyTree[BoolScalarLike]]
#
# Information for reversible adjoint (save ts)
#
reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]]
reversible_save_index: Optional[IntScalarLike]


def _is_none(x: Any) -> bool:
Expand Down Expand Up @@ -293,6 +298,11 @@ def loop(
dense_ts = dense_ts.at[0].set(t0)
init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts)

if init_state.reversible_ts is not None:
reversible_ts = init_state.reversible_ts
reversible_ts = reversible_ts.at[0].set(t0)
init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts)

def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.t0:
save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state)
Expand Down Expand Up @@ -574,6 +584,15 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
result,
)

reversible_ts = state.reversible_ts
reversible_save_index = state.reversible_save_index

if state.reversible_ts is not None:
reversible_ts = maybe_inplace(
reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)

new_state = State(
y=y,
tprev=tprev,
Expand All @@ -595,6 +614,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
event_dense_info=event_dense_info,
event_values=event_values,
event_mask=event_mask,
reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType]
reversible_save_index=reversible_save_index,
)

return (
Expand Down Expand Up @@ -1320,6 +1341,18 @@ def _outer_cond_fn(cond_fn_i):
)
del had_event, event_structure, event_mask_leaves, event_values__mask

# Reversible info
if isinstance(adjoint, ReversibleAdjoint):
if max_steps is None:
raise ValueError(
"`max_steps=None` is incompatible with `ReversibleAdjoint`"
)
reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
reversible_save_index = 0
else:
reversible_ts = None
reversible_save_index = None

# Initialise state
init_state = State(
y=y0,
Expand All @@ -1342,6 +1375,8 @@ def _outer_cond_fn(cond_fn_i):
event_dense_info=event_dense_info,
event_values=event_values,
event_mask=event_mask,
reversible_ts=reversible_ts,
reversible_save_index=reversible_save_index,
)

#
Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
StratonovichMilstein as StratonovichMilstein,
)
from .ralston import Ralston as Ralston
from .reversible import Reversible as Reversible
from .reversible_heun import ReversibleHeun as ReversibleHeun
from .runge_kutta import (
AbstractDIRK as AbstractDIRK,
Expand Down
Loading