From ef529a19b4ec313d728d76a898b5a048649ccb88 Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Tue, 19 Nov 2024 09:19:48 +0000 Subject: [PATCH 1/4] Reversible Solvers --- diffrax/__init__.py | 2 + diffrax/_adjoint.py | 210 +++++++++++++++++++++++++++++++++- diffrax/_integrate.py | 37 +++++- diffrax/_solver/__init__.py | 1 + diffrax/_solver/reversible.py | 114 ++++++++++++++++++ test/test_reversible.py | 132 +++++++++++++++++++++ 6 files changed, 494 insertions(+), 2 deletions(-) create mode 100644 diffrax/_solver/reversible.py create mode 100644 test/test_reversible.py diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 67b4ca50..313ec131 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -6,6 +6,7 @@ DirectAdjoint as DirectAdjoint, ImplicitAdjoint as ImplicitAdjoint, RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint, + ReversibleAdjoint as ReversibleAdjoint, ) from ._autocitation import citation as citation, citation_rules as citation_rules from ._brownian import ( @@ -101,6 +102,7 @@ Midpoint as Midpoint, MultiButcherTableau as MultiButcherTableau, Ralston as Ralston, + Reversible as Reversible, ReversibleHeun as ReversibleHeun, SEA as SEA, SemiImplicitEuler as SemiImplicitEuler, diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 4ff2dd2c..836a73a4 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -16,7 +16,12 @@ from ._heuristics import is_sde, is_unsafe_sde from ._saveat import save_y, SaveAt, SubSaveAt -from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver +from ._solver import ( + AbstractItoSolver, + AbstractRungeKutta, + AbstractStratonovichSolver, + Reversible, +) from ._term import AbstractTerm, AdjointTerm @@ -852,3 +857,206 @@ def loop( ) final_state = _only_transpose_ys(final_state) return final_state, aux_stats + + +# Reversible Adjoint custom vjp computes gradients w.r.t. +# - y, corresponding to the initial state; +# - args, corresponding to explicit parameters; +# - terms, corresponding to implicit parameters as part of the vector field. + + +@eqx.filter_custom_vjp +def _loop_reversible(y__args__terms, *, self, throw, init_state, **kwargs): + del throw + y, args, terms = y__args__terms + init_state = eqx.tree_at(lambda s: s.y, init_state, y) + del y + return self._loop( + args=args, + terms=terms, + init_state=init_state, + inner_while_loop=ft.partial(_inner_loop, kind="lax"), + outer_while_loop=ft.partial(_outer_loop, kind="lax"), + **kwargs, + ) + + +@_loop_reversible.def_fwd +def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs): + del perturbed + final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs) + ts = final_state.reversible_ts + ts_final_index = final_state.reversible_save_index + y1 = final_state.save_state.ys[-1] + solver_state1 = final_state.solver_state + return (final_state, aux_stats), (ts, ts_final_index, y1, solver_state1) + + +@_loop_reversible.def_bwd +def _loop_reversible_bwd( + residuals, + grad_final_state__aux_stats, + perturbed, + y__args__terms, + *, + self, + solver, + event, + t0, + t1, + dt0, + init_state, + progress_meter, + **kwargs, +): + assert event is None + + del perturbed, init_state, t1, progress_meter, self, kwargs + ts, ts_final_index, y1, solver_state1 = residuals + original_solver_state, z1 = solver_state1 + del residuals, solver_state1 + + grad_final_state, _ = grad_final_state__aux_stats + # ReversibleAdjoint currently only allows SaveAt(t1=True) so grad_y1 should have + # the same structure as y1. 
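+    # (`None`s in the incoming cotangent are symbolic zeros; `_materialise_none`
+    # swaps them for actual zeros so the pytree arithmetic below is well-defined.)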
+ grad_y1 = grad_final_state.save_state.ys[-1] + grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1) + del grad_final_state, grad_final_state__aux_stats + + y, args, terms = y__args__terms + del y__args__terms + + diff_args = eqx.filter(args, eqx.is_inexact_array) + diff_terms = eqx.filter(terms, eqx.is_inexact_array) + diff_z1 = eqx.filter(z1, eqx.is_inexact_array) + grad_args = jtu.tree_map(jnp.zeros_like, diff_args) + grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms) + grad_z1 = jtu.tree_map(jnp.zeros_like, diff_z1) + del diff_args, diff_terms, diff_z1 + + def grad_step(state): + def solver_step(terms, t0, t1, y1, args): + step, _, _, _, _ = solver.solver.step( + terms, t0, t1, y1, args, (first_step, f0), False + ) + return step + + ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state + (first_step, f0), z1 = solver_state + + t1 = ts[ts_index] + t0 = ts[ts_index - 1] + ts_index = ts_index - 1 + + # TODO The solver steps switch between evaluating from z0 + # and y1. Therefore, we re-evaluate f0 outside of the base + # solver to ensure the vf is correct. + # Can we avoid this re-evaluation? + + f0 = solver.func(terms, t1, y1, args) + step_y1, vjp_fun_y1 = eqx.filter_vjp(solver_step, terms, t1, t0, y1, args) + z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω + + f0 = solver.func(terms, t0, z0, args) + step_z0, vjp_fun_z0 = eqx.filter_vjp(solver_step, terms, t0, t1, z0, args) + + y0 = ((1 / solver.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω + + grad_step_y1 = vjp_fun_y1(grad_z1) + grad_y1 = (ω(grad_y1) + ω(grad_z1) - ω(grad_step_y1[3])).ω + + grad_step_z0 = vjp_fun_z0(grad_y1) + grad_y0 = (solver.l * ω(grad_y1)).ω + grad_z0 = (ω(grad_z1) - solver.l * ω(grad_y1) + ω(grad_step_z0[3])).ω + + grad_terms = (ω(grad_terms) - ω(grad_step_y1[0]) + ω(grad_step_z0[0])).ω + grad_args = (ω(grad_args) - ω(grad_step_y1[4]) + ω(grad_step_z0[4])).ω + + return ( + ts_index, + y0, + ((first_step, f0), z0), + grad_y0, + grad_z0, + grad_args, + grad_terms, + ) + + def cond_fun(state): + ts_index = state[0] + return ts_index > 0 + + state = ( + ts_final_index, + y1, + (original_solver_state, z1), + grad_y1, + grad_z1, + grad_args, + grad_terms, + ) + + state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax") + _, _, _, grad_y0, grad_z0, grad_args, grad_terms = state + return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms + + +class ReversibleAdjoint(AbstractAdjoint): + """ + Backpropagate through [`diffrax.diffeqsolve`][] when using the + [`diffrax.Reversible`][] solver. + + This method implies very low memory usage and exact gradient calculation (up to + floating point errors). + + This will compute gradients with respect to the `terms`, `y0` and `args` arguments + passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with + respect to anything else (for example `t0`, or arguments passed via closure), then + a `CustomVJPException` will be raised. See also + [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception) + entry. + """ + + def loop( + self, + *, + args, + terms, + solver, + saveat, + init_state, + passed_solver_state, + passed_controller_state, + event, + **kwargs, + ): + # `is` check because this may return a Tracer from SaveAt(ts=) + if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True: + raise ValueError( + "Can only use `adjoint=ReversibleAdjoint()` with " + "`saveat=SaveAt(t1=True)`." + ) + + if not isinstance(solver, Reversible): + raise ValueError( + "Can only use `adjoint=ReversibleAdjoint()` with " + "`Reversible()` solver." 
+ ) + + y = init_state.y + init_state = eqx.tree_at(lambda s: s.y, init_state, object()) + init_state = _nondiff_solver_controller_state( + self, init_state, passed_solver_state, passed_controller_state + ) + + final_state, aux_stats = _loop_reversible( + (y, args, terms), + self=self, + saveat=saveat, + init_state=init_state, + solver=solver, + event=event, + **kwargs, + ) + final_state = _only_transpose_ys(final_state) + return final_state, aux_stats diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 938eee37..80e6c2fe 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -23,7 +23,7 @@ import optimistix as optx from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real -from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint +from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint, ReversibleAdjoint from ._custom_types import ( BoolScalarLike, BufferDenseInfos, @@ -110,6 +110,11 @@ class State(eqx.Module): event_dense_info: Optional[DenseInfo] event_values: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]] event_mask: Optional[PyTree[BoolScalarLike]] + # + # Information for reversible adjoint (save ts) + # + reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]] + reversible_save_index: Optional[IntScalarLike] def _is_none(x: Any) -> bool: @@ -293,6 +298,11 @@ def loop( dense_ts = dense_ts.at[0].set(t0) init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) + if init_state.reversible_ts is not None: + reversible_ts = init_state.reversible_ts + reversible_ts = reversible_ts.at[0].set(t0) + init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts) + def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: if subsaveat.t0: save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state) @@ -574,6 +584,15 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): result, ) + reversible_ts = state.reversible_ts + reversible_save_index = state.reversible_save_index + + if state.reversible_ts is not None: + reversible_ts = maybe_inplace( + reversible_save_index + 1, tprev, reversible_ts + ) + reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0) + new_state = State( y=y, tprev=tprev, @@ -595,6 +614,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType] + reversible_save_index=reversible_save_index, ) return ( @@ -1320,6 +1341,18 @@ def _outer_cond_fn(cond_fn_i): ) del had_event, event_structure, event_mask_leaves, event_values__mask + # Reversible info + if isinstance(adjoint, ReversibleAdjoint): + if max_steps is None: + raise ValueError( + "`max_steps=None` is incompatible with `ReversibleAdjoint`" + ) + reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype) + reversible_save_index = 0 + else: + reversible_ts = None + reversible_save_index = None + # Initialise state init_state = State( y=y0, @@ -1342,6 +1375,8 @@ def _outer_cond_fn(cond_fn_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_ts=reversible_ts, + reversible_save_index=reversible_save_index, ) # diff --git a/diffrax/_solver/__init__.py b/diffrax/_solver/__init__.py index da6fe6c9..1a30cd07 100644 --- a/diffrax/_solver/__init__.py +++ b/diffrax/_solver/__init__.py @@ -27,6 +27,7 @@ StratonovichMilstein as StratonovichMilstein, ) from .ralston import Ralston as Ralston 
+from .reversible import Reversible as Reversible from .reversible_heun import ReversibleHeun as ReversibleHeun from .runge_kutta import ( AbstractDIRK as AbstractDIRK, diff --git a/diffrax/_solver/reversible.py b/diffrax/_solver/reversible.py new file mode 100644 index 00000000..8ab53cc5 --- /dev/null +++ b/diffrax/_solver/reversible.py @@ -0,0 +1,114 @@ +from collections.abc import Callable +from typing import cast, Optional, TypeAlias, TypeVar + +from equinox.internal import ω +from jaxtyping import PyTree + +from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y +from .._solution import RESULTS, update_result +from .._term import AbstractTerm +from .base import AbstractAdaptiveSolver, AbstractWrappedSolver +from .runge_kutta import AbstractRungeKutta + + +ω = cast(Callable, ω) +_BaseSolverState = TypeVar("_BaseSolverState") +_SolverState: TypeAlias = tuple[_BaseSolverState, Y] + + +class Reversible( + AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState] +): + """ + Reversible solver method. + + Allows any Runge-Kutta method ([`diffrax.AbstractRungeKutta`][]) to be made + algebraically reversible. + + The convergence order of the reversible solver is inherited from the wrapped + Runge-Kutta method. + + Backpropagation through the reversible solver implies very low memory usage and + exact gradient calculation (up to floating point errors). This is implemented in + [`diffrax.ReversibleAdjoint`][] and passed to [`diffrax.diffeqsolve`][] as + `adjoint=diffrax.ReversibleAdjoint()`. + """ + + solver: AbstractRungeKutta + l: RealScalarLike = 0.999 + + @property + def term_structure(self): + return self.solver.term_structure + + @property + def interpolation_cls(self): # pyright: ignore + return self.solver.interpolation_cls + + @property + def term_compatible_contr_kwargs(self): + return self.solver.term_compatible_contr_kwargs + + @property + def root_finder(self): + return self.solver.root_finder # pyright: ignore + + @property + def root_find_max_steps(self): + return self.solver.root_find_max_steps # pyright: ignore + + def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]: + return self.solver.order(terms) + + def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]: + return self.solver.strong_order(terms) + + def init( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: Args, + ) -> _SolverState: + original_solver_init = self.solver.init(terms, t0, t1, y0, args) + return (original_solver_init, y0) + + def step( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]: + (first_step, f0), z0 = solver_state + + # TODO The solver steps switch between evaluating from z0 + # and y1. Therefore, we re-evaluate f0 outside of the base + # solver to ensure the vf is correct. + # Can we avoid this re-evaluation? 
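+        #
+        # Writing Φ_{s→t} for one step of the base solver, the update below is
+        #     y1 = l * (y0 - z0) + Φ_{t0→t1}(z0)
+        #     z1 = y1 + z0 - Φ_{t1→t0}(y1)
+        # Both lines invert in closed form, which is what makes the scheme
+        # algebraically reversible (and is exploited by ReversibleAdjoint).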
+ + f0 = self.func(terms, t0, z0, args) + step_z0, z_error, dense_info, _, result1 = self.solver.step( + terms, t0, t1, z0, args, (first_step, f0), made_jump + ) + y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω + + f0 = self.func(terms, t1, y1, args) + step_y1, y_error, _, _, result2 = self.solver.step( + terms, t1, t0, y1, args, (first_step, f0), made_jump + ) + z1 = (ω(y1) + ω(z0) - ω(step_y1)).ω + + solver_state = ((first_step, f0), z1) + result = update_result(result1, result2) + + return y1, z_error + y_error, dense_info, solver_state, result + + def func( + self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Y, args: Args + ) -> VF: + return self.solver.func(terms, t0, y0, args) diff --git a/test/test_reversible.py b/test/test_reversible.py new file mode 100644 index 00000000..2983dc6d --- /dev/null +++ b/test/test_reversible.py @@ -0,0 +1,132 @@ +from typing import cast + +import diffrax +import equinox as eqx +import jax.numpy as jnp +from jaxtyping import Array + +from .helpers import tree_allclose + + +class _VectorField(eqx.Module): + nondiff_arg: int + diff_arg: float + + def __call__(self, t, y, args): + assert y.shape == (2,) + diff_arg, nondiff_arg = args + dya = diff_arg * y[0] + nondiff_arg * y[1] + dyb = self.nondiff_arg * y[0] + self.diff_arg * y[1] + return jnp.stack([dya, dyb]) + + +@eqx.filter_value_and_grad +def _loss(y0__args__term, solver, adjoint): + y0, args, term = y0__args__term + + sol = diffrax.diffeqsolve( + term, + solver, + t0=0, + t1=5, + dt0=0.01, + y0=y0, + args=args, + adjoint=adjoint, + stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8), + ) + y1 = sol.ys + return jnp.sum(cast(Array, y1)) + + +def test_constant_stepsizes(): + y0 = jnp.array([0.9, 5.4]) + args = (0.1, -1) + term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + + base_solver = diffrax.Tsit5() + reversible_solver = diffrax.Reversible(base_solver, l=0.999) + + # Base + sol = diffrax.diffeqsolve( + term, + base_solver, + t0=0, + t1=5, + dt0=0.01, + y0=y0, + args=args, + ) + y1_base = sol.ys + + # Reversible + sol = diffrax.diffeqsolve( + term, + reversible_solver, + t0=0, + t1=5, + dt0=0.01, + y0=y0, + args=args, + ) + y1_rev = sol.ys + + assert tree_allclose(y1_base, y1_rev, atol=1e-5) + + +def test_adaptive_stepsizes(): + y0 = jnp.array([0.9, 5.4]) + args = (0.1, -1) + term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + + base_solver = diffrax.Tsit5() + reversible_solver = diffrax.Reversible(base_solver, l=0.999) + stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8) + + # Base + sol = diffrax.diffeqsolve( + term, + base_solver, + t0=0, + t1=5, + dt0=0.01, + y0=y0, + args=args, + stepsize_controller=stepsize_controller, + ) + y1_base = sol.ys + + # Reversible + sol = diffrax.diffeqsolve( + term, + reversible_solver, + t0=0, + t1=5, + dt0=0.01, + y0=y0, + args=args, + stepsize_controller=stepsize_controller, + ) + y1_rev = sol.ys + + assert tree_allclose(y1_base, y1_rev, atol=1e-5) + + +def test_reversible_adjoint(): + y0 = jnp.array([0.9, 5.4]) + args = (0.1, -1) + term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + y0__args__term = (y0, args, term) + del y0, args, term + + base_solver = diffrax.Tsit5() + reversible_solver = diffrax.Reversible(base_solver, l=0.999) + + loss, grads_base = _loss( + y0__args__term, base_solver, adjoint=diffrax.RecursiveCheckpointAdjoint() + ) + loss, grads_reversible = _loss( + y0__args__term, reversible_solver, adjoint=diffrax.ReversibleAdjoint() + ) + + assert 
tree_allclose(grads_base, grads_reversible, atol=1e-5) From a81f8839f6305186f0dc96b8a8046ca18f27c041 Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Wed, 27 Nov 2024 08:56:44 +0000 Subject: [PATCH 2/4] Squashed commit of the following: commit ec1ebacb2c9ddc6241b9617079e06f88869aae1a Author: Sam McCallum Date: Wed Nov 27 08:46:55 2024 +0000 tidy up function arguments commit 7b66f46509e3ee0651b3d652a672d54e0cd92e60 Author: Sam McCallum Date: Tue Nov 26 18:13:11 2024 +0000 beefy tests commit e713b5dc642ad776a67731a9e03b4384407fefe9 Author: Sam McCallum Date: Tue Nov 26 13:29:26 2024 +0000 update references commit 9acf6e0b513229f449aa1196e2bc471a45a46e58 Author: Sam McCallum Date: Tue Nov 26 13:12:26 2024 +0000 test incorrect solver commit 861aa9700a29fdcb452d5e5c029e078210cb2035 Author: Sam McCallum Date: Tue Nov 26 13:05:05 2024 +0000 catch already reversible solvers commit 4b8b4c04af8335284050321959cdebf95f5e9213 Author: Sam McCallum Date: Tue Nov 26 12:37:03 2024 +0000 error estimate may be pytree commit 0b01210e7a3aebf9c7c20ff34506484c7a7d4f2a Author: Sam McCallum Date: Tue Nov 26 12:36:24 2024 +0000 tests commit 5435ab2d68050aeda4f72fccb77148cc0aba6fed Author: Sam McCallum Date: Tue Nov 26 11:17:09 2024 +0000 Revert "leapfrog not compatible" This reverts commit d88e732f1ab278e8be3039888bef28f68dd6d8b8. commit d88e732f1ab278e8be3039888bef28f68dd6d8b8 Author: Sam McCallum Date: Tue Nov 26 11:15:32 2024 +0000 leapfrog not compatible commit 6e3f2dee9f4ce9ecf3fd0370899058d631102371 Author: Sam McCallum Date: Tue Nov 26 11:13:30 2024 +0000 pytree state commit 3fa64325be0dccae4198f1a565935a07d01867bd Author: Sam McCallum Date: Tue Nov 26 10:28:26 2024 +0000 docs commit 2bfe8200fee8824e397a059b4a3870613ac424e6 Author: Sam McCallum Date: Tue Nov 26 09:34:36 2024 +0000 remove reversible.py solver file commit e7856d37e5780672b08c1582021bb51f8b88e9cb Author: Sam McCallum Date: Tue Nov 26 09:33:52 2024 +0000 fix tests for relative import commit 24d19359d99cc69bd930c25545547e94e7534da5 Author: Sam McCallum Date: Tue Nov 26 09:18:05 2024 +0000 private reversible commit 8a7448e53e48318587587063403151bfc1d05660 Author: Sam McCallum Date: Tue Nov 26 08:56:40 2024 +0000 remove debug print commit 0391bc12e3947960f7235fc061ec7d76e25144d6 Author: Sam McCallum Date: Tue Nov 26 08:28:41 2024 +0000 tests commit 81a9a57333bba7171427de6d0d5d7768f5fb3f36 Author: Sam McCallum Date: Tue Nov 26 08:23:41 2024 +0000 more tests commit 89f57313de55e508d5d9de3907189fa4b9d4bcb9 Author: Sam McCallum Date: Mon Nov 25 20:52:51 2024 +0000 test implicit solvers + SDEs commit f30f47eab24076d4de5057801992947d9afe4b15 Author: Sam McCallum Date: Mon Nov 25 20:44:54 2024 +0000 remove t0, t1, solver_state tangents commit b90317637a70ffe6262a48c2b2850cdd2dea9d0d Author: Sam McCallum Date: Mon Nov 25 16:56:01 2024 +0000 docs commit acaa35f1dcc47b319558a3513c94211b845359cb Author: Sam McCallum Date: Mon Nov 25 12:56:50 2024 +0000 better steps=True commit 621e6f4a39b0ed92ec0d2f8e37431db7124a9fe9 Author: Sam McCallum Date: Mon Nov 25 10:28:19 2024 +0000 remove ifs in grad_step loop commit 7dfb8e3257df0f2058d3a737fab4144d57144304 Author: Sam McCallum Date: Mon Nov 25 09:15:18 2024 +0000 Disable fsal, ssal properties to allow any solver to be made reversible commit f1602956d24f0131fbdcc57140745513388bb4f5 Author: Sam McCallum Date: Fri Nov 22 15:09:57 2024 +0000 tests commit f327f668cda7d93d983e9a78a7dfc6b9e5230e45 Author: Sam McCallum Date: Fri Nov 22 13:53:56 2024 +0000 ReversibleAdjoint compatible with SaveAt(steps=True) Reversible 
Solvers (v2) Changes: - `Reversible` solver is hidden from public API and automatically used with `ReversibleAdjoint` - compatible with any `AbstractSolver`, except methods that are already algebraically reversible - can now use `SaveAt(steps=True)` - works with ODEs/CDEs/SDEs - improved docs - improved tests --- diffrax/__init__.py | 1 - diffrax/_adjoint.py | 331 ++++++++++++++++++++++++++++------ diffrax/_solver/__init__.py | 1 - diffrax/_solver/reversible.py | 114 ------------ docs/api/adjoints.md | 11 +- test/test_reversible.py | 307 +++++++++++++++++++++++++------ 6 files changed, 536 insertions(+), 229 deletions(-) delete mode 100644 diffrax/_solver/reversible.py diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 313ec131..f55c82f6 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -102,7 +102,6 @@ Midpoint as Midpoint, MultiButcherTableau as MultiButcherTableau, Ralston as Ralston, - Reversible as Reversible, ReversibleHeun as ReversibleHeun, SEA as SEA, SemiImplicitEuler as SemiImplicitEuler, diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 836a73a4..7bcd0174 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -2,7 +2,7 @@ import functools as ft import warnings from collections.abc import Callable, Iterable -from typing import Any, cast, Optional, Union +from typing import Any, cast, Optional, TypeAlias, TypeVar, Union import equinox as eqx import equinox.internal as eqxi @@ -13,14 +13,22 @@ import lineax as lx import optimistix.internal as optxi from equinox.internal import ω +from jaxtyping import PyTree +from ._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y from ._heuristics import is_sde, is_unsafe_sde from ._saveat import save_y, SaveAt, SubSaveAt +from ._solution import RESULTS, update_result from ._solver import ( + AbstractAdaptiveSolver, AbstractItoSolver, AbstractRungeKutta, + AbstractSolver, AbstractStratonovichSolver, - Reversible, + AbstractWrappedSolver, + LeapfrogMidpoint, + ReversibleHeun, + SemiImplicitEuler, ) from ._term import AbstractTerm, AdjointTerm @@ -887,9 +895,9 @@ def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs): final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs) ts = final_state.reversible_ts ts_final_index = final_state.reversible_save_index - y1 = final_state.save_state.ys[-1] + ys = final_state.save_state.ys solver_state1 = final_state.solver_state - return (final_state, aux_stats), (ts, ts_final_index, y1, solver_state1) + return (final_state, aux_stats), (ts, ts_final_index, ys, solver_state1) @_loop_reversible.def_bwd @@ -900,27 +908,44 @@ def _loop_reversible_bwd( y__args__terms, *, self, + saveat, + init_state, solver, event, - t0, - t1, - dt0, - init_state, - progress_meter, **kwargs, ): assert event is None - del perturbed, init_state, t1, progress_meter, self, kwargs - ts, ts_final_index, y1, solver_state1 = residuals + del perturbed, self, init_state, kwargs + ts, ts_final_index, ys, solver_state1 = residuals original_solver_state, z1 = solver_state1 del residuals, solver_state1 grad_final_state, _ = grad_final_state__aux_stats - # ReversibleAdjoint currently only allows SaveAt(t1=True) so grad_y1 should have - # the same structure as y1. - grad_y1 = grad_final_state.save_state.ys[-1] - grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1) + # If true we must be using SaveAt(t1=True). 
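+    # (`ReversibleAdjoint.loop` only admits `SaveAt(t1=True)`, `SaveAt(steps=True)`
+    # and `SaveAt(t0=True, steps=True)`, so branching on `subs.t1` suffices.)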
+ t1_only = saveat.subs.t1 + if t1_only: + y1 = (ω(ys)[-1]).ω + grad_ys = (ω(grad_final_state.save_state.ys)[-1]).ω + grad_ys = jtu.tree_map(_materialise_none, y1, grad_ys) + grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_ys) + + # Otherwise we must be using SaveAt(..., steps=True) due to the guard in + # ReversibleAdjoint. If y0 is not saved (t0=False) then we prepend grad_y0 (zeros). + else: + if saveat.subs.t0: + y1 = (ω(ys)[ts_final_index]).ω + grad_ys = grad_final_state.save_state.ys + else: + y1 = (ω(ys)[ts_final_index - 1]).ω + grad_ys = grad_final_state.save_state.ys + grad_y0 = jtu.tree_map(lambda x: jnp.zeros_like(x[0]), grad_ys) + grad_ys = jtu.tree_map( + lambda x, y: jnp.concatenate([x[None], y]), grad_y0, grad_ys + ) + + grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys) + del grad_final_state, grad_final_state__aux_stats y, args, terms = y__args__terms @@ -935,48 +960,62 @@ def _loop_reversible_bwd( del diff_args, diff_terms, diff_z1 def grad_step(state): - def solver_step(terms, t0, t1, y1, args): - step, _, _, _, _ = solver.solver.step( - terms, t0, t1, y1, args, (first_step, f0), False + def solver_step(t0, t1, original_solver_state, y0, args, terms): + step, _, _, original_solver_state, _ = solver.solver.step( + terms, t0, t1, y0, args, original_solver_state, False ) - return step + return step, original_solver_state - ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state - (first_step, f0), z1 = solver_state + ts_index, y1, solver_state, grad_ys, grad_z1, grad_args, grad_terms = state + original_solver_state, z1 = solver_state t1 = ts[ts_index] t0 = ts[ts_index - 1] - ts_index = ts_index - 1 - # TODO The solver steps switch between evaluating from z0 - # and y1. Therefore, we re-evaluate f0 outside of the base - # solver to ensure the vf is correct. - # Can we avoid this re-evaluation? 
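+        # Select the cotangents flowing into this step: a single running
+        # cotangent when only t1 is saved, one per saved step otherwise.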
+ if t1_only: + grad_y1 = grad_ys + grad_y0 = grad_y0_zeros # pyright: ignore + + else: + grad_y1 = (ω(grad_ys)[ts_index]).ω + grad_y0 = (ω(grad_ys)[ts_index - 1]).ω - f0 = solver.func(terms, t1, y1, args) - step_y1, vjp_fun_y1 = eqx.filter_vjp(solver_step, terms, t1, t0, y1, args) + solver_step_fn = ft.partial(solver_step, t1, t0, original_solver_state) + step_y1, vjp_fun_y1, original_solver_state = eqx.filter_vjp( + solver_step_fn, y1, args, terms, has_aux=True + ) z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω - f0 = solver.func(terms, t0, z0, args) - step_z0, vjp_fun_z0 = eqx.filter_vjp(solver_step, terms, t0, t1, z0, args) + solver_step_fn = ft.partial(solver_step, t0, t1, original_solver_state) + step_z0, vjp_fun_z0, _ = eqx.filter_vjp( + solver_step_fn, z0, args, terms, has_aux=True + ) y0 = ((1 / solver.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω grad_step_y1 = vjp_fun_y1(grad_z1) - grad_y1 = (ω(grad_y1) + ω(grad_z1) - ω(grad_step_y1[3])).ω + grad_y1 = (ω(grad_y1) + ω(grad_z1) - ω(grad_step_y1[0])).ω grad_step_z0 = vjp_fun_z0(grad_y1) - grad_y0 = (solver.l * ω(grad_y1)).ω - grad_z0 = (ω(grad_z1) - solver.l * ω(grad_y1) + ω(grad_step_z0[3])).ω + grad_y0 = (solver.l * ω(grad_y1) + ω(grad_y0)).ω + grad_z0 = (ω(grad_z1) - solver.l * ω(grad_y1) + ω(grad_step_z0[0])).ω + + grad_terms = (ω(grad_terms) - ω(grad_step_y1[2]) + ω(grad_step_z0[2])).ω + grad_args = (ω(grad_args) - ω(grad_step_y1[1]) + ω(grad_step_z0[1])).ω - grad_terms = (ω(grad_terms) - ω(grad_step_y1[0]) + ω(grad_step_z0[0])).ω - grad_args = (ω(grad_args) - ω(grad_step_y1[4]) + ω(grad_step_z0[4])).ω + if t1_only: + grad_ys = grad_y0 + else: + grad_ys = (ω(grad_ys).at[ts_index].set(ω(grad_y1))).ω + grad_ys = (ω(grad_ys).at[ts_index - 1].set(ω(grad_y0))).ω + + ts_index = ts_index - 1 return ( ts_index, y0, - ((first_step, f0), z0), - grad_y0, + (original_solver_state, z0), + grad_ys, grad_z0, grad_args, grad_terms, @@ -990,33 +1029,84 @@ def cond_fun(state): ts_final_index, y1, (original_solver_state, z1), - grad_y1, + grad_ys, grad_z1, grad_args, grad_terms, ) state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax") - _, _, _, grad_y0, grad_z0, grad_args, grad_terms = state + _, _, _, grad_ys, grad_z0, grad_args, grad_terms = state + if t1_only: + grad_y0 = grad_ys + else: + grad_y0 = (ω(grad_ys)[0]).ω + return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms class ReversibleAdjoint(AbstractAdjoint): """ - Backpropagate through [`diffrax.diffeqsolve`][] when using the - [`diffrax.Reversible`][] solver. + Backpropagate through [`diffrax.diffeqsolve`][] using the reversible solver + method. - This method implies very low memory usage and exact gradient calculation (up to - floating point errors). + This method automatically wraps the passed solver to create an algebraically + reversible version of that solver. In doing so, gradient calculation is exact + (up to floating point errors) and backpropagation becomes a linear in time $O(n)$ + and constant in memory $O(1)$ algorithm in the number of steps $n$. - This will compute gradients with respect to the `terms`, `y0` and `args` arguments - passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with - respect to anything else (for example `t0`, or arguments passed via closure), then - a `CustomVJPException` will be raised. See also - [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception) - entry. + The reversible adjoint can be used when solving ODEs/CDEs/SDEs and is + compatible with any [`diffrax.AbstractSolver`][]. 
Adaptive step sizes are also + supported. + + !!! note + + This adjoint can be less numerically stable than + [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.DirectAdjoint`][]. + Stability can be largely improved by using [double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + and [smaller/adaptive step sizes](https://docs.kidger.site/diffrax/api/stepsize_controller/). + + ??? cite "References" + + This algorithm was developed in: + + ```bibtex + @article{mccallum2024efficient, + title={Efficient, Accurate and Stable Gradients for Neural ODEs}, + author={McCallum, Sam and Foster, James}, + journal={arXiv preprint arXiv:2410.11648}, + year={2024} + } + ``` + + And built on previous work by: + + ```bibtex + @article{kidger2021efficient, + title={Efficient and accurate gradients for neural sdes}, + author={Kidger, Patrick and Foster, James and Li, Xuechen Chen and Lyons, + Terry}, + journal={Advances in Neural Information Processing Systems}, + volume={34}, + pages={18747--18761}, + year={2021} + } + ``` + + ```bibtex + @article{zhuang2021mali, + title={Mali: A memory efficient and reverse accurate integrator for neural + odes}, + author={Zhuang, Juntang and Dvornek, Nicha C and Tatikonda, Sekhar and + Duncan, James S}, + journal={arXiv preprint arXiv:2102.04668}, + year={2021} + } + ``` """ + l: float = 0.999 + def loop( self, *, @@ -1031,19 +1121,39 @@ def loop( **kwargs, ): # `is` check because this may return a Tracer from SaveAt(ts=) - if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True: + if ( + eqx.tree_equal(saveat, SaveAt(t1=True)) is not True + and eqx.tree_equal(saveat, SaveAt(steps=True)) is not True + and eqx.tree_equal(saveat, SaveAt(t0=True, steps=True)) is not True + ): raise ValueError( - "Can only use `adjoint=ReversibleAdjoint()` with " - "`saveat=SaveAt(t1=True)`." + "Can only use `diffrax.ReversibleAdjoint` with " + "`saveat=SaveAt(t1=True)` or `saveat=SaveAt(steps=True)`." + ) + + if event is not None: + raise NotImplementedError( + "`diffrax.ReversibleAdjoint` is not compatible with events." ) - if not isinstance(solver, Reversible): + if isinstance(solver, (SemiImplicitEuler, ReversibleHeun, LeapfrogMidpoint)): raise ValueError( - "Can only use `adjoint=ReversibleAdjoint()` with " - "`Reversible()` solver." + "`diffrax.ReversibleAdjoint` is not compatible with solvers that are " + f"intrinsically algebraically reversible, such as {solver}." ) + solver = _Reversible(solver, self.l) + tprev = init_state.tprev + tnext = init_state.tnext y = init_state.y + + init_state = eqx.tree_at( + lambda s: s.solver_state, + init_state, + solver.init(terms, tprev, tnext, y, args), + is_leaf=_is_none, + ) + init_state = eqx.tree_at(lambda s: s.y, init_state, object()) init_state = _nondiff_solver_controller_state( self, init_state, passed_solver_state, passed_controller_state @@ -1060,3 +1170,114 @@ def loop( ) final_state = _only_transpose_ys(final_state) return final_state, aux_stats + + +ReversibleAdjoint.__init__.__doc__ = r""" +**Arguments:** + +- `l` - coupling parameter, defaults to `l=0.999`. + +The reversible solver introduces the coupled state $\{y_n, z_n\}_{n\geq 0}$ and the +coupling parameter $l\in (0, 1)$ mixes the states via $ly_n + (1-l)z_n$. This parameter +effects the stability of the reversible solver; decreasing it's value leads to greater +forward stability and increasing it's value leads to greater backward stability. + +In most cases the default value is sufficient. 
However, if you find yourself needing +greater control over stability it can be passed as an argument. +""" + +_BaseSolverState = TypeVar("_BaseSolverState") +_SolverState: TypeAlias = tuple[_BaseSolverState, Y] + + +def _add_maybe_none(x, y): + if x is None: + return None + else: + return (ω(x) + ω(y)).ω + + +class _Reversible( + AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState] +): + """ + Reversible solver method. + + Allows any solver ([`diffrax.AbstractSolver`][]) to be made algebraically + reversible. This is a private API, exclusively for [`diffrax.ReversibleAdjoint`][]. + """ + + solver: AbstractSolver + l: float = 0.999 + + @property + def term_structure(self): + return self.solver.term_structure + + @property + def interpolation_cls(self): # pyright: ignore + return self.solver.interpolation_cls + + @property + def term_compatible_contr_kwargs(self): + return self.solver.term_compatible_contr_kwargs + + @property + def root_finder(self): + return self.solver.root_finder # pyright: ignore + + @property + def root_find_max_steps(self): + return self.solver.root_find_max_steps # pyright: ignore + + def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]: + return self.solver.order(terms) + + def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]: + return self.solver.strong_order(terms) + + def init( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: Args, + ) -> _SolverState: + if isinstance(self.solver, AbstractRungeKutta): + object.__setattr__(self.solver.tableau, "fsal", False) + object.__setattr__(self.solver.tableau, "ssal", False) + original_solver_init = self.solver.init(terms, t0, t1, y0, args) + return (original_solver_init, y0) + + def step( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]: + original_solver_state, z0 = solver_state + + step_z0, z_error, dense_info, original_solver_state, result1 = self.solver.step( + terms, t0, t1, z0, args, original_solver_state, made_jump + ) + y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω + + step_y1, y_error, _, _, result2 = self.solver.step( + terms, t1, t0, y1, args, original_solver_state, made_jump + ) + z1 = (ω(y1) + ω(z0) - ω(step_y1)).ω + + solver_state = (original_solver_state, z1) + result = update_result(result1, result2) + + return y1, _add_maybe_none(z_error, y_error), dense_info, solver_state, result + + def func( + self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Y, args: Args + ) -> VF: + return self.solver.func(terms, t0, y0, args) diff --git a/diffrax/_solver/__init__.py b/diffrax/_solver/__init__.py index 1a30cd07..da6fe6c9 100644 --- a/diffrax/_solver/__init__.py +++ b/diffrax/_solver/__init__.py @@ -27,7 +27,6 @@ StratonovichMilstein as StratonovichMilstein, ) from .ralston import Ralston as Ralston -from .reversible import Reversible as Reversible from .reversible_heun import ReversibleHeun as ReversibleHeun from .runge_kutta import ( AbstractDIRK as AbstractDIRK, diff --git a/diffrax/_solver/reversible.py b/diffrax/_solver/reversible.py deleted file mode 100644 index 8ab53cc5..00000000 --- a/diffrax/_solver/reversible.py +++ /dev/null @@ -1,114 +0,0 @@ -from collections.abc import Callable -from typing import cast, Optional, TypeAlias, TypeVar - -from equinox.internal import ω -from jaxtyping import PyTree - -from 
.._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y -from .._solution import RESULTS, update_result -from .._term import AbstractTerm -from .base import AbstractAdaptiveSolver, AbstractWrappedSolver -from .runge_kutta import AbstractRungeKutta - - -ω = cast(Callable, ω) -_BaseSolverState = TypeVar("_BaseSolverState") -_SolverState: TypeAlias = tuple[_BaseSolverState, Y] - - -class Reversible( - AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState] -): - """ - Reversible solver method. - - Allows any Runge-Kutta method ([`diffrax.AbstractRungeKutta`][]) to be made - algebraically reversible. - - The convergence order of the reversible solver is inherited from the wrapped - Runge-Kutta method. - - Backpropagation through the reversible solver implies very low memory usage and - exact gradient calculation (up to floating point errors). This is implemented in - [`diffrax.ReversibleAdjoint`][] and passed to [`diffrax.diffeqsolve`][] as - `adjoint=diffrax.ReversibleAdjoint()`. - """ - - solver: AbstractRungeKutta - l: RealScalarLike = 0.999 - - @property - def term_structure(self): - return self.solver.term_structure - - @property - def interpolation_cls(self): # pyright: ignore - return self.solver.interpolation_cls - - @property - def term_compatible_contr_kwargs(self): - return self.solver.term_compatible_contr_kwargs - - @property - def root_finder(self): - return self.solver.root_finder # pyright: ignore - - @property - def root_find_max_steps(self): - return self.solver.root_find_max_steps # pyright: ignore - - def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]: - return self.solver.order(terms) - - def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]: - return self.solver.strong_order(terms) - - def init( - self, - terms: PyTree[AbstractTerm], - t0: RealScalarLike, - t1: RealScalarLike, - y0: Y, - args: Args, - ) -> _SolverState: - original_solver_init = self.solver.init(terms, t0, t1, y0, args) - return (original_solver_init, y0) - - def step( - self, - terms: PyTree[AbstractTerm], - t0: RealScalarLike, - t1: RealScalarLike, - y0: Y, - args: Args, - solver_state: _SolverState, - made_jump: BoolScalarLike, - ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]: - (first_step, f0), z0 = solver_state - - # TODO The solver steps switch between evaluating from z0 - # and y1. Therefore, we re-evaluate f0 outside of the base - # solver to ensure the vf is correct. - # Can we avoid this re-evaluation? 
- - f0 = self.func(terms, t0, z0, args) - step_z0, z_error, dense_info, _, result1 = self.solver.step( - terms, t0, t1, z0, args, (first_step, f0), made_jump - ) - y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω - - f0 = self.func(terms, t1, y1, args) - step_y1, y_error, _, _, result2 = self.solver.step( - terms, t1, t0, y1, args, (first_step, f0), made_jump - ) - z1 = (ω(y1) + ω(z0) - ω(step_y1)).ω - - solver_state = ((first_step, f0), z1) - result = update_result(result1, result2) - - return y1, z_error + y_error, dense_info, solver_state, result - - def func( - self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Y, args: Args - ) -> VF: - return self.solver.func(terms, t0, y0, args) diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index 992d57ab..a432d25e 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -14,6 +14,10 @@ There are multiple ways to backpropagate through a differential equation (to com Alternatively we may compute $\frac{\mathrm{d}y(t_1)}{\mathrm{d}y_0}$ analytically. In doing so we obtain a backwards-in-time ODE that we must numerically solve to obtain the desired gradients. This is known as "optimise then discretise", and corresponds to [`diffrax.BacksolveAdjoint`][] below. +!!! note + + It is possible to augment the structure of the forward ODE solve to ensure algebraic reversibility of the numerical operations. As a result, backpropagation can be performed with very low memory usage and retain exact gradient calculation. This corresponds to [`diffrax.ReversibleAdjoint`][] below. + ??? abstract "`diffrax.AbstractAdjoint`" ::: diffrax.AbstractAdjoint @@ -21,7 +25,7 @@ There are multiple ways to backpropagate through a differential equation (to com members: - loop -Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.BacksolveAdjoint`][] can only be reverse-mode autodifferentiated. [`diffrax.ImplicitAdjoint`][] and [`diffrax.DirectAdjoint`][] support both forward and reverse-mode autodifferentiation. +Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][], [`diffrax.BacksolveAdjoint`][] and [`diffrax.ReversibleAdjoint`][] can only be reverse-mode autodifferentiated. [`diffrax.ImplicitAdjoint`][] and [`diffrax.DirectAdjoint`][] support both forward and reverse-mode autodifferentiation. 
--- @@ -44,6 +48,11 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax selection: members: false +::: diffrax.ReversibleAdjoint + selection: + members: + - __init__ + --- ::: diffrax.adjoint_rms_seminorm diff --git a/test/test_reversible.py b/test/test_reversible.py index 2983dc6d..94acd7ce 100644 --- a/test/test_reversible.py +++ b/test/test_reversible.py @@ -3,6 +3,10 @@ import diffrax import equinox as eqx import jax.numpy as jnp +import jax.random as jr +import optimistix as optx +import pytest +from diffrax._adjoint import _Reversible from jaxtyping import Array from .helpers import tree_allclose @@ -20,113 +24,302 @@ def __call__(self, t, y, args): return jnp.stack([dya, dyb]) -@eqx.filter_value_and_grad -def _loss(y0__args__term, solver, adjoint): - y0, args, term = y0__args__term +class _PyTreeVectorField(eqx.Module): + nondiff_arg: int + diff_arg: float - sol = diffrax.diffeqsolve( - term, - solver, - t0=0, - t1=5, - dt0=0.01, - y0=y0, - args=args, - adjoint=adjoint, - stepsize_controller=diffrax.PIDController(rtol=1e-8, atol=1e-8), - ) - y1 = sol.ys - return jnp.sum(cast(Array, y1)) + def __call__(self, t, y, args): + diff_arg, nondiff_arg = args + dya = diff_arg * y[0] + nondiff_arg * y[1][0] + dyb = self.nondiff_arg * y[0] + self.diff_arg * y[1][0] + dyc = diff_arg * y[1][1] + nondiff_arg * y[1][0] + return (dya, (dyb, dyc)) -def test_constant_stepsizes(): - y0 = jnp.array([0.9, 5.4]) - args = (0.1, -1) - term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) +class QuadraticPath(diffrax.AbstractPath): + @property + def t0(self): + return 0 + + @property + def t1(self): + return 3 + + def evaluate(self, t0, t1=None, left=True): + del left + if t1 is not None: + return self.evaluate(t1) - self.evaluate(t0) + return t0**2 - base_solver = diffrax.Tsit5() - reversible_solver = diffrax.Reversible(base_solver, l=0.999) - # Base +def _compare_solve(y0__args__term, solver, stepsize_controller): + y0, args, term = y0__args__term sol = diffrax.diffeqsolve( term, - base_solver, + solver, t0=0, t1=5, dt0=0.01, y0=y0, args=args, + stepsize_controller=stepsize_controller, ) y1_base = sol.ys # Reversible sol = diffrax.diffeqsolve( term, - reversible_solver, + solver, t0=0, t1=5, dt0=0.01, y0=y0, args=args, + adjoint=diffrax.ReversibleAdjoint(), + stepsize_controller=stepsize_controller, ) y1_rev = sol.ys assert tree_allclose(y1_base, y1_rev, atol=1e-5) -def test_adaptive_stepsizes(): - y0 = jnp.array([0.9, 5.4]) - args = (0.1, -1) - term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) - - base_solver = diffrax.Tsit5() - reversible_solver = diffrax.Reversible(base_solver, l=0.999) - stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8) +@eqx.filter_value_and_grad +def _loss(y0__args__term, solver, saveat, adjoint, stepsize_controller, pytree_state): + y0, args, term = y0__args__term - # Base sol = diffrax.diffeqsolve( term, - base_solver, + solver, t0=0, t1=5, dt0=0.01, y0=y0, args=args, + saveat=saveat, + adjoint=adjoint, stepsize_controller=stepsize_controller, ) - y1_base = sol.ys + if pytree_state: + y1, (y2, y3) = sol.ys # type: ignore + y1 = y1 + y2 + y3 + else: + y1 = sol.ys + return jnp.sum(cast(Array, y1)) - # Reversible - sol = diffrax.diffeqsolve( - term, + +# The adjoint comparison looks wrong at first glance so here's an explanation: +# We want to check that the gradients calculated by ReversibleAdjoint +# are the same as those calculated by RecursiveCheckpointAdjoint, for a fixed +# solver. 
+# +# ReversibleAdjoint auto-wraps the solver to create a reversible version. So when +# calculating gradients we use base_solver + ReversibleAdjoint and reversible_solver + +# RecursiveCheckpointAdjoint, to ensure that the reversible solver is fixed across both +# adjoints. +def _compare_grads( + y0__args__term, base_solver, saveat, stepsize_controller, pytree_state +): + reversible_solver = _Reversible(base_solver) + + loss, grads_base = _loss( + y0__args__term, reversible_solver, - t0=0, - t1=5, - dt0=0.01, - y0=y0, - args=args, + saveat, + adjoint=diffrax.RecursiveCheckpointAdjoint(), stepsize_controller=stepsize_controller, + pytree_state=pytree_state, + ) + loss, grads_reversible = _loss( + y0__args__term, + base_solver, + saveat, + adjoint=diffrax.ReversibleAdjoint(), + stepsize_controller=stepsize_controller, + pytree_state=pytree_state, ) - y1_rev = sol.ys - assert tree_allclose(y1_base, y1_rev, atol=1e-5) + assert tree_allclose(grads_base, grads_reversible, atol=1e-5) -def test_reversible_adjoint(): - y0 = jnp.array([0.9, 5.4]) +@pytest.mark.parametrize( + "solver", + [diffrax.Tsit5(), diffrax.Kvaerno5(), diffrax.KenCarp5()], +) +@pytest.mark.parametrize( + "stepsize_controller", + [diffrax.ConstantStepSize(), diffrax.PIDController(rtol=1e-8, atol=1e-8)], +) +@pytest.mark.parametrize("pytree_state", [True, False]) +def test_forward_solve(solver, stepsize_controller, pytree_state): + if pytree_state: + y0 = (jnp.array(0.9), (jnp.array(5.4), jnp.array(1.2))) + term = diffrax.ODETerm(_PyTreeVectorField(nondiff_arg=1, diff_arg=-0.1)) + else: + y0 = jnp.array([0.9, 5.4]) + term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + + if isinstance(stepsize_controller, diffrax.ConstantStepSize) and isinstance( + solver, diffrax.AbstractImplicitSolver + ): + return + + if isinstance(solver, diffrax.KenCarp5): + term = diffrax.MultiTerm(term, term) + + args = (0.1, -1) + y0__args__term = (y0, args, term) + _compare_solve(y0__args__term, solver, stepsize_controller) + + +@pytest.mark.parametrize( + "solver", + [diffrax.Tsit5(), diffrax.Kvaerno5(), diffrax.KenCarp5()], +) +@pytest.mark.parametrize( + "stepsize_controller", + [diffrax.ConstantStepSize(), diffrax.PIDController(rtol=1e-8, atol=1e-8)], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t1=True), + diffrax.SaveAt(steps=True), + diffrax.SaveAt(t0=True, steps=True), + ], +) +@pytest.mark.parametrize("pytree_state", [True, False]) +def test_reversible_adjoint(solver, stepsize_controller, saveat, pytree_state): + if pytree_state: + y0 = (jnp.array(0.9), (jnp.array(5.4), jnp.array(1.2))) + term = diffrax.ODETerm(_PyTreeVectorField(nondiff_arg=1, diff_arg=-0.1)) + else: + y0 = jnp.array([0.9, 5.4]) + term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + + if isinstance(stepsize_controller, diffrax.ConstantStepSize) and isinstance( + solver, diffrax.AbstractImplicitSolver + ): + return + + if isinstance(solver, diffrax.KenCarp5): + term = diffrax.MultiTerm(term, term) + args = (0.1, -1) - term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) y0__args__term = (y0, args, term) del y0, args, term - base_solver = diffrax.Tsit5() - reversible_solver = diffrax.Reversible(base_solver, l=0.999) + _compare_grads(y0__args__term, solver, saveat, stepsize_controller, pytree_state) - loss, grads_base = _loss( - y0__args__term, base_solver, adjoint=diffrax.RecursiveCheckpointAdjoint() - ) - loss, grads_reversible = _loss( - y0__args__term, reversible_solver, adjoint=diffrax.ReversibleAdjoint() + 
+@pytest.mark.parametrize( + "solver, diffusion", + [ + (diffrax.ShARK(), lambda t, y, args: 1.0), + (diffrax.SlowRK(), lambda t, y, args: 0.1 * y), + ], +) +@pytest.mark.parametrize("adjoint", [False, True]) +def test_sde(solver, diffusion, adjoint): + y0 = jnp.array([0.9, 5.4]) + args = (0.1, -1) + drift = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + brownian_motion = diffrax.VirtualBrownianTree( + 0, 5, tol=1e-3, shape=(), levy_area=diffrax.SpaceTimeLevyArea, key=jr.PRNGKey(0) ) + terms = diffrax.MultiTerm(drift, diffrax.ControlTerm(diffusion, brownian_motion)) + y0__args__term = (y0, args, terms) + stepsize_controller = diffrax.ConstantStepSize() - assert tree_allclose(grads_base, grads_reversible, atol=1e-5) + if adjoint: + saveat = diffrax.SaveAt(t1=True) + _compare_grads(y0__args__term, solver, saveat, stepsize_controller, False) + + else: + _compare_solve(y0__args__term, solver, stepsize_controller) + + +@pytest.mark.parametrize( + "solver", + [diffrax.Tsit5(), diffrax.Kvaerno5(), diffrax.KenCarp5()], +) +@pytest.mark.parametrize( + "stepsize_controller", + [diffrax.ConstantStepSize(), diffrax.PIDController(rtol=1e-8, atol=1e-8)], +) +def test_cde(solver, stepsize_controller): + if isinstance(stepsize_controller, diffrax.ConstantStepSize) and isinstance( + solver, diffrax.AbstractImplicitSolver + ): + return + + vf = _VectorField(nondiff_arg=1, diff_arg=-0.1) + control = diffrax.ControlTerm(vf, QuadraticPath()) + terms = diffrax.MultiTerm(control, diffrax.ODETerm(vf)) + y0 = jnp.array([0.9, 5.4]) + args = (0.1, -1) + y0__args__term = (y0, args, terms) + _compare_solve(y0__args__term, solver, stepsize_controller) + + +def test_events(): + def vector_field(t, y, args): + _, v = y + return jnp.array([v, -8.0]) + + def cond_fn(t, y, args, **kwargs): + x, _ = y + return x + + @eqx.filter_value_and_grad + def _event_loss(y0, adjoint): + sol = diffrax.diffeqsolve( + term, solver, t0, t1, dt0, y0, adjoint=adjoint, event=event + ) + return cast(Array, sol.ys)[0, 1] + + y0 = jnp.array([10.0, 0.0]) + t0 = 0 + t1 = jnp.inf + dt0 = 0.1 + term = diffrax.ODETerm(vector_field) + root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm) + event = diffrax.Event(cond_fn, root_finder) + solver = diffrax.Tsit5() + + msg = "`diffrax.ReversibleAdjoint` is not compatible with events." 
+ with pytest.raises(NotImplementedError, match=msg): + _event_loss(y0, adjoint=diffrax.ReversibleAdjoint()) + + +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(ts=jnp.linspace(0, 5)), + diffrax.SaveAt(dense=True), + diffrax.SaveAt(t0=True), + diffrax.SaveAt(ts=jnp.linspace(0, 5), fn=lambda t, y, args: t), + ], +) +@pytest.mark.parametrize( + "solver", + [diffrax.SemiImplicitEuler(), diffrax.ReversibleHeun(), diffrax.LeapfrogMidpoint()], +) +def test_incompatible_arguments(solver, saveat): + y0 = jnp.array([0.9, 5.4]) + args = (0.1, -1) + term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + y0__args__term = (y0, args, term) + + if isinstance(solver, diffrax.SemiImplicitEuler): + y0 = (y0, y0) + term = (term, term) + + with pytest.raises(ValueError): + loss, grads_reversible = _loss( + y0__args__term, + solver, + saveat, + adjoint=diffrax.ReversibleAdjoint(), + stepsize_controller=diffrax.ConstantStepSize(), + pytree_state=False, + ) From eb79c3334588acaa8e2735c6bceb2e1c701a615b Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Wed, 27 Nov 2024 16:57:48 +0000 Subject: [PATCH 3/4] catch and test unsafe sde arguments --- diffrax/_adjoint.py | 20 ++++++++++++++++++++ test/test_reversible.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 7bcd0174..bafb5f20 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -1142,6 +1142,26 @@ def loop( f"intrinsically algebraically reversible, such as {solver}." ) + if is_unsafe_sde(terms): + raise ValueError( + "`adjoint=ReversibleAdjoint()` does not support `UnsafeBrownianPath`. " + "Consider using `VirtualBrownianTree` instead." + ) + if is_sde(terms): + if isinstance(solver, AbstractItoSolver): + raise NotImplementedError( + f"`{solver.__class__.__name__}` converges to the Itô solution. " + "However `ReversibleAdjoint` currently only supports Stratonovich " + "SDEs." + ) + elif not isinstance(solver, AbstractStratonovichSolver): + warnings.warn( + f"{solver.__class__.__name__} is not marked as converging to " + "either the Itô or the Stratonovich solution. Note that " + "`ReversibleAdjoint` will only produce the correct solution for " + "Stratonovich SDEs." 
+ ) + solver = _Reversible(solver, self.l) tprev = init_state.tprev tnext = init_state.tnext diff --git a/test/test_reversible.py b/test/test_reversible.py index 94acd7ce..c5b0cf38 100644 --- a/test/test_reversible.py +++ b/test/test_reversible.py @@ -323,3 +323,37 @@ def test_incompatible_arguments(solver, saveat): stepsize_controller=diffrax.ConstantStepSize(), pytree_state=False, ) + + +@pytest.mark.parametrize( + "solver, unsafe_brownian", [(diffrax.EulerHeun(), True), (diffrax.Euler(), False)] +) +def test_unsafe_sde(solver, unsafe_brownian): + diffusion = lambda t, y, args: 1.0 + y0 = jnp.array([0.9, 5.4]) + args = (0.1, -1) + drift = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) + + if unsafe_brownian: + brownian_path = diffrax.UnsafeBrownianPath(shape=(), key=jr.PRNGKey(0)) + else: + brownian_path = diffrax.VirtualBrownianTree( + 0, + 5, + tol=1e-3, + shape=(), + key=jr.PRNGKey(1), + ) + + terms = diffrax.MultiTerm(drift, diffrax.ControlTerm(diffusion, brownian_path)) + y0__args__term = (y0, args, terms) + + with pytest.raises((ValueError, NotImplementedError)): + loss, grads_reversible = _loss( + y0__args__term, + solver, + saveat=diffrax.SaveAt(t1=True), + adjoint=diffrax.ReversibleAdjoint(), + stepsize_controller=diffrax.ConstantStepSize(), + pytree_state=False, + ) From 8f65c2c46a77b080018953ace5b0e9ca0aa144f5 Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Thu, 28 Nov 2024 18:20:15 +0000 Subject: [PATCH 4/4] remove inplace updates --- diffrax/_adjoint.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index bafb5f20..1d0cdfc5 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -926,9 +926,9 @@ def _loop_reversible_bwd( t1_only = saveat.subs.t1 if t1_only: y1 = (ω(ys)[-1]).ω - grad_ys = (ω(grad_final_state.save_state.ys)[-1]).ω - grad_ys = jtu.tree_map(_materialise_none, y1, grad_ys) - grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_ys) + grad_y1 = (ω(grad_final_state.save_state.ys)[-1]).ω + grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1) + grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_y1) # Otherwise we must be using SaveAt(..., steps=True) due to the guard in # ReversibleAdjoint. If y0 is not saved (t0=False) then we prepend grad_y0 (zeros). 
@@ -945,6 +945,7 @@ def _loop_reversible_bwd( ) grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys) + grad_y1 = (ω(grad_ys)[ts_final_index]).ω del grad_final_state, grad_final_state__aux_stats @@ -966,19 +967,16 @@ def solver_step(t0, t1, original_solver_state, y0, args, terms): ) return step, original_solver_state - ts_index, y1, solver_state, grad_ys, grad_z1, grad_args, grad_terms = state + ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state original_solver_state, z1 = solver_state t1 = ts[ts_index] t0 = ts[ts_index - 1] if t1_only: - grad_y1 = grad_ys grad_y0 = grad_y0_zeros # pyright: ignore - else: - grad_y1 = (ω(grad_ys)[ts_index]).ω - grad_y0 = (ω(grad_ys)[ts_index - 1]).ω + grad_y0 = (ω(grad_ys)[ts_index - 1]).ω # pyright: ignore solver_step_fn = ft.partial(solver_step, t1, t0, original_solver_state) step_y1, vjp_fun_y1, original_solver_state = eqx.filter_vjp( @@ -1003,19 +1001,13 @@ def solver_step(t0, t1, original_solver_state, y0, args, terms): grad_terms = (ω(grad_terms) - ω(grad_step_y1[2]) + ω(grad_step_z0[2])).ω grad_args = (ω(grad_args) - ω(grad_step_y1[1]) + ω(grad_step_z0[1])).ω - if t1_only: - grad_ys = grad_y0 - else: - grad_ys = (ω(grad_ys).at[ts_index].set(ω(grad_y1))).ω - grad_ys = (ω(grad_ys).at[ts_index - 1].set(ω(grad_y0))).ω - ts_index = ts_index - 1 return ( ts_index, y0, (original_solver_state, z0), - grad_ys, + grad_y0, grad_z0, grad_args, grad_terms, @@ -1029,18 +1021,14 @@ def cond_fun(state): ts_final_index, y1, (original_solver_state, z1), - grad_ys, + grad_y1, grad_z1, grad_args, grad_terms, ) state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax") - _, _, _, grad_ys, grad_z0, grad_args, grad_terms = state - if t1_only: - grad_y0 = grad_ys - else: - grad_y0 = (ω(grad_ys)[0]).ω + _, _, _, grad_y0, grad_z0, grad_args, grad_terms = state return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms
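
As a quick end-to-end check of the API these patches introduce, a sketch along
the following lines should work against this branch (illustrative only: the
vector field, parameter values and loss mirror the tests above):

```python
import diffrax
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array


class VectorField(eqx.Module):
    # A differentiable parameter carried inside `terms`.
    diff_arg: Array

    def __call__(self, t, y, args):
        a, b = args
        return jnp.stack([a * y[0] + b * y[1], self.diff_arg * y[0] + a * y[1]])


@eqx.filter_value_and_grad
def loss(y0__args__term):
    y0, args, term = y0__args__term
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),  # any solver; ReversibleAdjoint wraps it in _Reversible
        t0=0,
        t1=5,
        dt0=0.01,
        y0=y0,
        args=args,
        saveat=diffrax.SaveAt(t1=True),
        adjoint=diffrax.ReversibleAdjoint(),
    )
    return jnp.sum(sol.ys)


term = diffrax.ODETerm(VectorField(diff_arg=jnp.asarray(-0.1)))
y0 = jnp.array([0.9, 5.4])
args = (jnp.asarray(0.1), jnp.asarray(-1.0))
# Gradients are returned for `y0`, `args` and the parameters inside `term`.
value, grads = loss((y0, args, term))
```

On the backward pass this routes through the `_loop_reversible` custom vjp in
`diffrax/_adjoint.py`: the forward solve additionally records only the accepted
step times (`reversible_ts`), and the coupled `(y, z)` state is reconstructed
step by step while cotangents are pulled back through the two base-solver steps.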