Skip to content

Commit

Permalink
added optx in delays
Browse files Browse the repository at this point in the history
  • Loading branch information
thibmonsel committed Oct 23, 2023
1 parent 2b24da8 commit 3a81a32
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 144 deletions.
111 changes: 53 additions & 58 deletions diffrax/delays.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Callable, Optional, Sequence, Type, Union

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import unvmap_any, ω
from equinox.internal import unvmap_any
from optimistix import fixed_point, FixedPointIteration

from .custom_types import Array, Bool, Int, PyTree, Scalar
from .global_interpolation import DenseInterpolation
from .local_interpolation import AbstractLocalInterpolation
from .misc import rms_norm
from .nonlinear_solver import NewtonNonlinearSolver
from .term import VectorFieldWrapper

Expand Down Expand Up @@ -143,12 +144,18 @@ def history_extrapolation_implicit(
state,
args,
):
def _cond_fun(_val):
_, _, _, _, _, pred, step = _val
return (implicit_step & pred) | (jnp.invert(implicit_step) & (step == 0))

def _body_fun(_val):
y_prev, _, dense_info, _, _, _, step = _val
def fn(dense_info, args):
(
terms,
_,
dense_interp,
solver,
delays,
t0,
y0_history,
state,
vf_args,
) = args
terms_ = bind_history(
terms,
delays,
Expand All @@ -161,66 +168,54 @@ def _body_fun(_val):
state.tnext,
y0_history,
)

(y, y_error, dense_info, solver_state, solver_result) = solver.step(
(y, y_error, new_dense_info, solver_state, solver_result) = solver.step(
terms_,
state.tprev,
state.tnext,
state.y,
args,
vf_args,
state.solver_state,
state.made_jump,
)

_pred = (
rms_norm(
(
(ω(y).call(jnp.abs) - y_prev**ω)
/ (delays.atol + delays.rtol * ω(y).call(jnp.abs))
).ω
)
< 1
)
_pred = _pred & (step < 10)
return (
y,
y_error,
dense_info,
solver_state,
solver_result,
_pred,
step + 1,
)

unwrapped_buffer = jtu.tree_leaves(
eqx.filter(state.dense_infos, eqx.is_inexact_array),
is_leaf=eqx.is_inexact_array,
)
aux_dense_infos = dict(zip(state.dense_infos.keys(), unwrapped_buffer))

_init_val = (
state.y,
state.y,
jtu.tree_map(lambda x: x[state.dense_save_index - 1], aux_dense_infos),
state.solver_state,
0,
True,
0,
return new_dense_info, (y, y_error, solver_state, solver_result)

# unwrapped_buffer = jtu.tree_leaves(
# eqx.filter(state.dense_infos, eqx.is_inexact_array),
# is_leaf=eqx.is_inexact_array,
# )
# aux_dense_infos = dict(zip(state.dense_infos.keys(), unwrapped_buffer))

# def get_dense_info(dense_infos, idx):
# return jtu.tree_map(lambda x: x[idx], dense_infos)

# # dense_info_aux = dict(zip(["k","y0", "y1"],
# [jnp.zeros((4,)), jnp.zeros((1,)), jnp.zeros((1,))]))
# jax.debug.breakpoint()
# struct_dense_info = eqx.filter_eval_shape(get_dense_info, state.dense_infos, 0)
# infos = jtu.tree_map(lambda _, x: x[...], struct_dense_info, state.dense_infos)
# jax.debug.print("infos integrate {} ", infos)
# print("infos,", infos)
# print("struct_dense_info", struct_dense_info)

jax.debug.print("state.dense_infos {}", state.dense_infos)
init_guess = jtu.tree_map(
lambda x: x[state.dense_save_index - 1], state.dense_infos
)
(
y,
y_error,
dense_info,
solver_state,
solver_result,
_,
final_step,
) = lax.while_loop(_cond_fun, _body_fun, _init_val)

y_error = jtu.tree_map(
lambda _y_error: jnp.where(final_step < 10, _y_error, jnp.inf),
y_error,
alg = FixedPointIteration(rtol=delays.rtol, atol=delays.atol)
nonlinear_args = (
terms,
implicit_step,
dense_interp,
solver,
delays,
t0,
y0_history,
state,
args,
)
sol = fixed_point(fn, alg, init_guess, nonlinear_args, has_aux=True)
dense_info, (y, y_error, solver_state, solver_result) = sol.value, sol.aux
return y, y_error, dense_info, solver_state, solver_result


Expand Down
166 changes: 80 additions & 86 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@

from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
from .custom_types import Array, Bool, Int, PyTree, Scalar
from .delays import (
bind_history,
Delays,
history_extrapolation_implicit,
maybe_find_discontinuity,
)
from .delays import bind_history, Delays, history_extrapolation_implicit
from .event import AbstractDiscreteTerminatingEvent
from .global_interpolation import DenseInterpolation
from .heuristics import is_sde, is_unsafe_sde
Expand Down Expand Up @@ -247,15 +242,11 @@ def body_fun(state):
min_delay = jnp.stack(min_delay).min()
implicit_step = min_delay < (state.tnext - state.tprev)

unwrapped_buffer = jtu.tree_leaves(
eqx.filter(state.dense_infos, eqx.is_inexact_array),
is_leaf=eqx.is_inexact_array,
)
dense_interp = DenseInterpolation(
ts=state.dense_ts[...],
ts_size=state.dense_save_index + 1,
interpolation_cls=solver.interpolation_cls,
infos=dict(zip(state.dense_infos.keys(), unwrapped_buffer)),
infos=state.dense_infos,
direction=1,
y0_if_trivial=y0_history(t0),
t0_if_trivial=t0,
Expand All @@ -278,6 +269,7 @@ def body_fun(state):
state,
args,
)

# e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
# we get a negative value for y, and then get a NaN vector field. (And then
# everything breaks.) See #143.
Expand All @@ -303,62 +295,64 @@ def body_fun(state):
)

# Finding all of the potential discontinuity roots
if delays is not None:
_part_maybe_find_discontinuity = ft.partial(
maybe_find_discontinuity,
tprev,
tnext,
dense_info,
state,
delays,
solver,
args,
)

tsearch = jnp.linspace(tprev, tnext, delays.sub_intervals)
batch_tprev, batch_tnext = tsearch[:-1], tsearch[1:]
vmap_maybe_find_discontinuity_wrapper = jax.vmap(
_part_maybe_find_discontinuity, (None, 0, 0)
)
if delays.recurrent_checking:
(
tnext_candidate,
batch_discont_update,
) = vmap_maybe_find_discontinuity_wrapper(
False, batch_tprev, batch_tnext
)
else:
(
tnext_candidate,
batch_discont_update,
) = vmap_maybe_find_discontinuity_wrapper(
keep_step, batch_tprev, batch_tnext
)

proxy_tnext = jnp.where(batch_discont_update, tnext_candidate, jnp.inf)
proxy_tnext = jnp.min(proxy_tnext)

tnext, discont_update = jax.lax.cond(
jnp.isinf(proxy_tnext),
lambda: (tnext, False),
lambda: (proxy_tnext, True),
)

# Count the number of steps in DDEs, just for statistical purposes
num_dde_implicit_step = state.num_dde_implicit_step + (
keep_step & implicit_step
)
num_dde_explicit_step = state.num_dde_explicit_step + (
keep_step & jnp.invert(implicit_step)
)

assert jnp.result_type(discont_update) is jnp.dtype(bool)
assert jnp.result_type(keep_step) is jnp.dtype(bool)
# if delays is not None:
# _part_maybe_find_discontinuity = ft.partial(
# maybe_find_discontinuity,
# tprev,
# tnext,
# dense_info,
# state,
# delays,
# solver,
# args,
# )

# tsearch = jnp.linspace(tprev, tnext, delays.sub_intervals)
# batch_tprev, batch_tnext = tsearch[:-1], tsearch[1:]
# vmap_maybe_find_discontinuity_wrapper = jax.vmap(
# _part_maybe_find_discontinuity, (None, 0, 0)
# )
# if delays.recurrent_checking:
# (
# tnext_candidate,
# batch_discont_update,
# ) = vmap_maybe_find_discontinuity_wrapper(
# False, batch_tprev, batch_tnext
# )
# else:
# (
# tnext_candidate,
# batch_discont_update,
# ) = vmap_maybe_find_discontinuity_wrapper(
# keep_step, batch_tprev, batch_tnext
# )

# proxy_tnext = jnp.where(batch_discont_update, tnext_candidate, jnp.inf)
# proxy_tnext = jnp.min(proxy_tnext)

# tnext, discont_update = jax.lax.cond(
# jnp.isinf(proxy_tnext),
# lambda: (tnext, False),
# lambda: (proxy_tnext, True),
# )

# # Count the number of steps in DDEs, just for statistical purposes
# num_dde_implicit_step = state.num_dde_implicit_step + (
# keep_step & implicit_step
# )
# num_dde_explicit_step = state.num_dde_explicit_step + (
# keep_step & jnp.invert(implicit_step)
# )

# assert jnp.result_type(discont_update) is jnp.dtype(bool)

# assert jnp.result_type(keep_step) is jnp.dtype(bool)

#
# Do some book-keeping.
#

# discont_update = False
num_dde_explicit_step = num_dde_implicit_step = 0
tprev = jnp.minimum(tprev, t1)
tnext = _clip_to_end(tprev, tnext, t1, keep_step)

Expand Down Expand Up @@ -455,16 +449,16 @@ def maybe_inplace(i, u, x):
else:
return x.at[i].set(u, pred=keep_step)

def maybe_inplace_delay(i, u, x):
# Annoying hack. We normally call this with `x` wrapped into a buffer
# (from Equinox's while loops). However we do also first trace through to
# see if we can resolve some values statically, in which case normal JAX
# arrays don't support the extra `pred` argument. We don't then use the
# result of this so we just skip it.
if _filtering:
return x
else:
return x.at[i].set(u, pred=discont_update)
# def maybe_inplace_delay(i, u, x):
# # Annoying hack. We normally call this with `x` wrapped into a buffer
# # (from Equinox's while loops). However we do also first trace through to
# # see if we can resolve some values statically, in which case normal JAX
# # arrays don't support the extra `pred` argument. We don't then use the
# # result of this so we just skip it.
# if _filtering:
# return x
# else:
# return x.at[i].set(u, pred=discont_update)

def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.steps:
Expand Down Expand Up @@ -496,19 +490,19 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
dense_save_index = dense_save_index + keep_step

# Updating discontinuity
if delays is not None:
if delays.recurrent_checking:
eqxi.error_if(
discontinuities_save_index,
discontinuities_save_index == delays.max_discontinuities,
"the number of discontinuities detected reached the number of"
" `max_discontinuities`, please raise its value.",
)

discontinuities = maybe_inplace_delay(
discontinuities_save_index + 1, tnext, discontinuities
)
discontinuities_save_index = discontinuities_save_index + discont_update
# if delays is not None:
# if delays.recurrent_checking:
# eqxi.error_if(
# discontinuities_save_index,
# discontinuities_save_index >= delays.max_discontinuities,
# "the number of discontinuities detected reached the number of"
# " `max_discontinuities`, please raise its value.",
# )

# discontinuities = maybe_inplace_delay(
# discontinuities_save_index + 1, tnext, discontinuities
# )
# discontinuities_save_index = discontinuities_save_index + discont_update

new_state = State(
y=y,
Expand Down

0 comments on commit 3a81a32

Please sign in to comment.