Skip to content

Commit

Permalink
remove inplace updates
Browse files: browse the repository at this point in the history
  • Loading branch information
sammccallum committed Nov 28, 2024
1 parent eb79c33 commit 8f65c2c
Showing 1 changed file with 9 additions and 21 deletions.
30 changes: 9 additions & 21 deletions diffrax/_adjoint.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -926,9 +926,9 @@ def _loop_reversible_bwd(
t1_only = saveat.subs.t1
if t1_only:
y1 = (ω(ys)[-1]).ω
grad_ys = (ω(grad_final_state.save_state.ys)[-1]).ω
grad_ys = jtu.tree_map(_materialise_none, y1, grad_ys)
grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_ys)
grad_y1 = (ω(grad_final_state.save_state.ys)[-1]).ω
grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1)
grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_y1)

# Otherwise we must be using SaveAt(..., steps=True) due to the guard in
# ReversibleAdjoint. If y0 is not saved (t0=False) then we prepend grad_y0 (zeros).
Expand All @@ -945,6 +945,7 @@ def _loop_reversible_bwd(
)

grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)
grad_y1 = (ω(grad_ys)[ts_final_index]).ω

del grad_final_state, grad_final_state__aux_stats

Expand All @@ -966,19 +967,16 @@ def solver_step(t0, t1, original_solver_state, y0, args, terms):
)
return step, original_solver_state

ts_index, y1, solver_state, grad_ys, grad_z1, grad_args, grad_terms = state
ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state
original_solver_state, z1 = solver_state

t1 = ts[ts_index]
t0 = ts[ts_index - 1]

if t1_only:
grad_y1 = grad_ys
grad_y0 = grad_y0_zeros # pyright: ignore

else:
grad_y1 = (ω(grad_ys)[ts_index]).ω
grad_y0 = (ω(grad_ys)[ts_index - 1]).ω
grad_y0 = (ω(grad_ys)[ts_index - 1]).ω # pyright: ignore

solver_step_fn = ft.partial(solver_step, t1, t0, original_solver_state)
step_y1, vjp_fun_y1, original_solver_state = eqx.filter_vjp(
Expand All @@ -1003,19 +1001,13 @@ def solver_step(t0, t1, original_solver_state, y0, args, terms):
grad_terms = (ω(grad_terms) - ω(grad_step_y1[2]) + ω(grad_step_z0[2])).ω
grad_args = (ω(grad_args) - ω(grad_step_y1[1]) + ω(grad_step_z0[1])).ω

if t1_only:
grad_ys = grad_y0
else:
grad_ys = (ω(grad_ys).at[ts_index].set(ω(grad_y1))).ω
grad_ys = (ω(grad_ys).at[ts_index - 1].set(ω(grad_y0))).ω

ts_index = ts_index - 1

return (
ts_index,
y0,
(original_solver_state, z0),
grad_ys,
grad_y0,
grad_z0,
grad_args,
grad_terms,
Expand All @@ -1029,18 +1021,14 @@ def cond_fun(state):
ts_final_index,
y1,
(original_solver_state, z1),
grad_ys,
grad_y1,
grad_z1,
grad_args,
grad_terms,
)

state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax")
_, _, _, grad_ys, grad_z0, grad_args, grad_terms = state
if t1_only:
grad_y0 = grad_ys
else:
grad_y0 = (ω(grad_ys)[0]).ω
_, _, _, grad_y0, grad_z0, grad_args, grad_terms = state

return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms

Expand Down

0 comments on commit 8f65c2c

Please sign in to comment.