diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index bafb5f20..1d0cdfc5 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -926,9 +926,9 @@ def _loop_reversible_bwd( t1_only = saveat.subs.t1 if t1_only: y1 = (ω(ys)[-1]).ω - grad_ys = (ω(grad_final_state.save_state.ys)[-1]).ω - grad_ys = jtu.tree_map(_materialise_none, y1, grad_ys) - grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_ys) + grad_y1 = (ω(grad_final_state.save_state.ys)[-1]).ω + grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1) + grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_y1) # Otherwise we must be using SaveAt(..., steps=True) due to the guard in # ReversibleAdjoint. If y0 is not saved (t0=False) then we prepend grad_y0 (zeros). @@ -945,6 +945,7 @@ def _loop_reversible_bwd( ) grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys) + grad_y1 = (ω(grad_ys)[ts_final_index]).ω del grad_final_state, grad_final_state__aux_stats @@ -966,19 +967,16 @@ def solver_step(t0, t1, original_solver_state, y0, args, terms): ) return step, original_solver_state - ts_index, y1, solver_state, grad_ys, grad_z1, grad_args, grad_terms = state + ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state original_solver_state, z1 = solver_state t1 = ts[ts_index] t0 = ts[ts_index - 1] if t1_only: - grad_y1 = grad_ys grad_y0 = grad_y0_zeros # pyright: ignore - else: - grad_y1 = (ω(grad_ys)[ts_index]).ω - grad_y0 = (ω(grad_ys)[ts_index - 1]).ω + grad_y0 = (ω(grad_ys)[ts_index - 1]).ω # pyright: ignore solver_step_fn = ft.partial(solver_step, t1, t0, original_solver_state) step_y1, vjp_fun_y1, original_solver_state = eqx.filter_vjp( @@ -1003,19 +1001,13 @@ def solver_step(t0, t1, original_solver_state, y0, args, terms): grad_terms = (ω(grad_terms) - ω(grad_step_y1[2]) + ω(grad_step_z0[2])).ω grad_args = (ω(grad_args) - ω(grad_step_y1[1]) + ω(grad_step_z0[1])).ω - if t1_only: - grad_ys = grad_y0 - else: - grad_ys = (ω(grad_ys).at[ts_index].set(ω(grad_y1))).ω - grad_ys = (ω(grad_ys).at[ts_index - 1].set(ω(grad_y0))).ω - ts_index = ts_index - 1 return ( ts_index, y0, (original_solver_state, z0), - grad_ys, + grad_y0, grad_z0, grad_args, grad_terms, @@ -1029,18 +1021,14 @@ def cond_fun(state): ts_final_index, y1, (original_solver_state, z1), - grad_ys, + grad_y1, grad_z1, grad_args, grad_terms, ) state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax") - _, _, _, grad_ys, grad_z0, grad_args, grad_terms = state - if t1_only: - grad_y0 = grad_ys - else: - grad_y0 = (ω(grad_ys)[0]).ω + _, _, _, grad_y0, grad_z0, grad_args, grad_terms = state return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms