Skip to content

Commit

Permalink
Minor fixes, adding diacritic to Levy
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Jun 12, 2024
1 parent d1f2c0a commit fd946c5
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 129 deletions.
2 changes: 1 addition & 1 deletion benchmarks/brownian_tree_times.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is
additionally capable of computing Levy area.
additionally capable of computing Lévy area.
Here we check the speed of the new implementation against the old implementation, to be
sure that it is still fast.
Expand Down
33 changes: 11 additions & 22 deletions diffrax/_autocitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,28 +368,17 @@ def _space_time_levy_area(terms):
)
leaves = jtu.tree_leaves(terms, is_leaf=has_levy_area)
if any(has_levy_area(leaf) for leaf in leaves):
return r"""
% You are simulating Brownian motion using Levy area, the formulae for which
% are due to:
@misc{jelinčič2024singleseed,
title={Single-seed generation of Brownian paths and integrals
for adaptive and high order SDE solvers},
author={Andraž Jelinčič and James Foster and Patrick Kidger},
year={2024},
eprint={2405.06464},
archivePrefix={arXiv},
primaryClass={math.NA}
}
% and Theorem 6.1.6 of
@phdthesis{foster2020a,
publisher = {University of Oxford},
school = {University of Oxford},
title = {Numerical approximations for stochastic differential equations},
author = {Foster, James M.},
year = {2020}
}
"""
_, single_seed_ref, foster_ref = _parse_reference_multi(VirtualBrownianTree)
return (
r"""
% You are simulating Brownian motion using Lévy area,
% the formulae for which are due to:
"""
+ single_seed_ref
+ "\n\n"
+ r"""% and Theorem 6.1.6 of:"""
+ foster_ref
)


@citation_rules.append
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def evaluate(
left-continuous or right-continuous at any jump points, but Brownian
motion has no jump points.)
- `use_levy`: If True, the return type will be a `LevyVal`, which contains
PyTrees of Brownian increments and their Levy areas.
PyTrees of Brownian increments and their Lévy areas.
**Returns:**
Expand Down
11 changes: 7 additions & 4 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class UnsafeBrownianPath(AbstractBrownianPath):
motion. Hence the restrictions above. (They describe the general case for which the
correlation structure isn't needed.)
!!! info "Levy Area"
!!! info "Lévy Area"
Can be initialised with `levy_area` set to `diffrax.BrownianIncrement`, or
`diffrax.SpaceTimeLevyArea`. If `levy_area=diffrax.SpaceTimeLevyArea`, then it
Expand Down Expand Up @@ -147,22 +147,25 @@ def _evaluate_leaf(
use_levy: bool,
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
key_w, key_hh, key_kk = jr.split(key, 3)
w = jr.normal(key, shape.shape, shape.dtype) * w_std
dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))

if levy_area is SpaceTimeTimeLevyArea:
key_w, key_hh, key_kk = jr.split(key, 3)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
kk_std = w_std / math.sqrt(720)
kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std
levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk)

elif levy_area is SpaceTimeLevyArea:
key_w, key_hh = jr.split(key, 2)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh)
elif levy_area is BrownianIncrement:
w = jr.normal(key, shape.shape, shape.dtype) * w_std
levy_val = BrownianIncrement(dt=dt, W=w)
else:
assert False
Expand All @@ -180,6 +183,6 @@ def _evaluate_leaf(
be a tuple of integers, describing the shape of a single JAX array. In that case
the dtype is chosen to be the default floating-point dtype.
- `key`: A random key.
- `levy_area`: Whether to additionally generate Levy area. This is required by some SDE
- `levy_area`: Whether to additionally generate Lévy area. This is required by some SDE
solvers.
"""
116 changes: 58 additions & 58 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
# author = {Foster, James M.},
# year = {2020}
# }
# For more about space-time Levy area see Definition 4.2.1.
# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6.
# For the general interpolation rule for space-time Levy area see Theorem 6.1.4.
# For more about space-time Lévy area see Definition 4.2.1.
# For the midpoint rule for generating space-time Lévy area see Theorem 6.1.6.
# For the general interpolation rule for space-time Lévy area see Theorem 6.1.4.

FloatDouble: TypeAlias = tuple[Inexact[Array, " *shape"], Inexact[Array, " *shape"]]
FloatTriple: TypeAlias = tuple[
Expand All @@ -64,9 +64,9 @@
_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement)


# An internal dataclass that holds the rescaled Levy areas
# An internal dataclass that holds the rescaled Lévy areas
# in addition to the non-rescaled ones, for the purposes of
# taking the difference between two Levy areas.
# taking the difference between two Lévy areas.
class _LevyVal(eqx.Module):
dt: RealScalarLike
W: Array
Expand All @@ -75,7 +75,7 @@ class _LevyVal(eqx.Module):
K: Optional[Array]
bar_K: Optional[Array]

def __post_init__(self):
def __check_init__(self):
if self.H is None:
assert self.bar_H is None
assert self.K is None
Expand All @@ -99,22 +99,22 @@ def _levy_diff(_, x0: _LevyVal, x1: _LevyVal) -> AbstractBrownianIncrement:
**Arguments:**
- `_`: unused, for the purposes of aligning the `jtu.tree_map`.
- `x0`: `_AbstractLevyVal` at time `s`.
- `x1`: `_AbstractLevyVal` at time `u`.
- `x0`: `_LevyVal` at time `s`.
- `x1`: `_LevyVal` at time `u`.
**Returns:**
`AbstractBrownianIncrement(W_su, H_su, K_su)`
"""
dtype = jnp.dtype(x0.W)
dtype = jnp.result_type(x0.W)
tdtype = complex_to_real_dtype(dtype)
su = jnp.asarray(x1.dt - x0.dt, dtype=tdtype)
if x0.H is None: # BM only case
assert x1.H is None
return BrownianIncrement(dt=su, W=x1.W - x0.W)

# the following computation is common to the space-time
# and the space-time-time Levy area case
# and the space-time-time Lévy area case
assert x0.H is not None
assert x1.H is not None
assert x0.bar_H is not None
Expand All @@ -126,20 +126,20 @@ def _levy_diff(_, x0: _LevyVal, x1: _LevyVal) -> AbstractBrownianIncrement:
bhh_su = x1.bar_H - x0.bar_H - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
hh_su = inverse_su * bhh_su

if x0.K is None: # space-time Levy area case
if x0.K is None: # space-time Lévy area case
return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su)

elif x0.K is not None: # space-time-time Levy area case
elif x0.K is not None: # space-time-time Lévy area case
assert x1.K is not None
assert x0.bar_K is not None
assert x1.bar_K is not None
with jax.numpy_dtype_promotion("standard"):
bkk_su = (
x1.bar_K
- x0.bar_K
- su / 2 * x0.bar_H
+ x0.dt / 2 * bhh_su
- (x1.dt - 2 * x0.dt) / 12 * u_bb_s
- (su / 2) * x0.bar_H
+ (x0.dt / 2) * bhh_su
- ((x1.dt - 2 * x0.dt) / 12) * u_bb_s
)
su2 = jnp.square(su)
inverse_su2 = 1 / jnp.where(jnp.abs(su2) < jnp.finfo(su2).eps, jnp.inf, su2)
Expand All @@ -161,8 +161,12 @@ def _make_levy_val(_, x: _LevyVal) -> AbstractBrownianIncrement:


def _split_interval(
pred: BoolScalarLike, x_stu: FloatTriple, x_st_tu: FloatDouble
) -> FloatTriple:
pred: BoolScalarLike, x_stu: Optional[FloatTriple], x_st_tu: Optional[FloatDouble]
) -> Optional[FloatTriple]:
if x_stu is None:
assert x_st_tu is None
return None
assert x_st_tu is not None
x_s, x_t, x_u = x_stu
x_st, x_tu = x_st_tu
x_s = jnp.where(pred, x_t, x_s)
Expand All @@ -174,12 +178,14 @@ def _split_interval(
class VirtualBrownianTree(AbstractBrownianPath):
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
!!! info "Levy Area"
!!! info "Lévy Area"
The parameter `levy_area` can be set to one of:
- [`diffrax.BrownianIncrement`][] (default, generates the increment of W)
- [`diffrax.SpaceTimeLevyArea`][] (generates W and the space-time Levy area H)
- [`diffrax.SpaceTimeLevyArea`][] (generates W and the space-time Lévy area H)
- [`diffrax.SpaceTimeTimeLevyArea`][] (generates W, H and the space-time-time
Levy area K)
Lévy area K)
The choice of `levy_area` will impact the Brownian path, so even with the same
key, the trajectory will be different depending on the value of `levy_area`.
Expand All @@ -197,8 +203,8 @@ class VirtualBrownianTree(AbstractBrownianPath):
```
The implementation here is an improvement on the above, in that it additionally
simulates space-time and space-time-time Levy areas, and exactly matches the
distribution of the Brownian motion and its Levy areas at all query times.
simulates space-time and space-time-time Lévy areas, and exactly matches the
distribution of the Brownian motion and its Lévy areas at all query times.
This is due to the paper
```bibtex
Expand Down Expand Up @@ -386,7 +392,7 @@ def _body_fun(_state: _State):
(
_t,
_w_stu,
_w_inc,
_w_st_tu,
_keys,
_bhh_stu,
_bhh_st_tu,
Expand All @@ -400,22 +406,10 @@ def _body_fun(_state: _State):
_key_st, _key_tu = _keys
_key = jnp.where(_cond, _key_st, _key_tu)

_w = _split_interval(_cond, _w_stu, _w_inc)

if self.levy_area is SpaceTimeTimeLevyArea:
assert _bkk_stu is not None and _bkk_st_tu is not None
_bkk = _split_interval(_cond, _bkk_stu, _bkk_st_tu)
else:
_bkk = None

if (
self.levy_area is SpaceTimeLevyArea
or self.levy_area is SpaceTimeTimeLevyArea
):
assert _bhh_stu is not None and _bhh_st_tu is not None
_bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu)
else:
_bhh = None
_w = _split_interval(_cond, _w_stu, _w_st_tu)
assert _w is not None
_bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu)
_bkk = _split_interval(_cond, _bkk_stu, _bkk_st_tu)

return _State(
level=_level,
Expand Down Expand Up @@ -474,8 +468,8 @@ def _body_fun(_state: _State):
2 * su
)
wk_cov = (sr_by_su**4) * ru_by_su * (sr - ru) / 12
hh_cov = (
sr / 12 * (1 - sr_by_su_3 * (sr2 + 2 * sr * ru + 16 * ru2) / su2)
hh_cov = (sr / 12) * (
1 - sr_by_su_3 * (sr2 + 2 * sr * ru + 16 * ru2) / su2
)
hk_cov = -(ru / 24) * sr_by_su_5
kk_cov = (sr / 720) * (1.0 - sr_by_su_5)
Expand All @@ -489,15 +483,20 @@ def _body_fun(_state: _State):
)

if self._spline == "sqrt":
# NOTE: not compatible with jnp.float16
# NOTE: jr.multivariate_normal is not compatible with jnp.float16,
# so we need to cast to jnp.float32 before calling it.
with jax.numpy_dtype_promotion("standard"):
dtype_atleast32 = jnp.result_type(dtype, jnp.float32)
cov = jnp.asarray(cov, dtype_atleast32)
hat_y = jr.multivariate_normal(
final_state.key,
jnp.zeros((3,), dtype),
jnp.zeros((3,), dtype_atleast32),
cov,
shape=shape,
dtype=dtype,
dtype=dtype_atleast32,
method="svd",
)
hat_y = jnp.asarray(hat_y, dtype)

elif self._spline == "zero":
hat_y = jnp.zeros(shape=shape + (3,), dtype=dtype)
Expand Down Expand Up @@ -525,9 +524,9 @@ def _body_fun(_state: _State):
bkk_r = (
bkk_s
+ bkk_sr
+ sr / 2 * bhh_s
- s / 2 * bhh_sr
+ (r - 2 * s) / 12 * r_bb_s
+ (sr / 2) * bhh_s
- (s / 2) * bhh_sr
+ ((r - 2 * s) / 12) * r_bb_s
)

inverse_r = 1 / jnp.where(jnp.square(r) < jnp.finfo(r).eps, jnp.inf, r)
Expand Down Expand Up @@ -618,7 +617,7 @@ def _brownian_arch(
and also returns `w_st` and `w_tu` in addition to just `w_t`. Same for `bhh`
if it is not None.
Note that the inputs and outputs already contain `bkk`. These values are
there for the sake of a future extension with "space-time-time" Levy area
there for the sake of a future extension with "space-time-time" Lévy area
and should be None for now.
**Arguments:**
Expand All @@ -637,6 +636,7 @@ def _brownian_arch(
- `bhh_st_tu`: (optional) $(\bar{H}_{s,t}, \bar{H}_{t,u})$
- `bkk_stu`: (optional) $(\bar{K}_s, \bar{K}_t, \bar{K}_u)$
- `bkk_st_tu`: (optional) $(\bar{K}_{s,t}, \bar{K}_{t,u})$
"""
key_st, midpoint_key, key_tu = jr.split(_state.key, 3)
keys = (key_st, key_tu)
Expand Down Expand Up @@ -669,15 +669,15 @@ def _brownian_arch(
su2 = su**2

w_term1 = w_su / 2
w_term2 = 3 / (2 * su) * bhh_su + z
w_term2 = (3 / (2 * su)) * bhh_su + z
w_st = w_term1 + w_term2
w_tu = w_term1 - w_term2
bhh_term1 = bhh_su / 8 - st / 2 * z
bhh_term2 = 15 / (8 * su) * bkk_su + st * x1
bhh_term1 = bhh_su / 8 - (st / 2) * z
bhh_term2 = (15 / (8 * su)) * bkk_su + st * x1
bhh_st = bhh_term1 + bhh_term2
bhh_tu = bhh_term1 - bhh_term2
bkk_term1 = bkk_su / 32 - (su2 / 8) * x1
bkk_term2 = su2 / 4 * x2
bkk_term2 = (su2 / 4) * x2
bkk_st = bkk_term1 + bkk_term2
bkk_tu = bkk_term1 - bkk_term2
w_st_tu = (w_st, w_tu)
Expand All @@ -690,9 +690,9 @@ def _brownian_arch(
bkk_t = (
bkk_s
+ bkk_st
+ st / 2 * bhh_s
- s / 2 * bhh_st
+ (t - 2 * s) / 12 * t_bb_s
+ (st / 2) * bhh_s
- (s / 2) * bhh_st
+ ((t - 2 * s) / 12) * t_bb_s
)

w_stu = (w_s, w_t, w_u)
Expand All @@ -711,13 +711,13 @@ def _brownian_arch(
n = z2 * jnp.sqrt(su / 12)

w_term1 = w_su / 2
w_term2 = 3 / (2 * su) * bhh_su + z
w_term2 = (3 / (2 * su)) * bhh_su + z
w_st = w_term1 + w_term2
w_tu = w_term1 - w_term2
w_st_tu = (w_st, w_tu)

bhh_term1 = bhh_su / 8 - su / 4 * z
bhh_term2 = su / 4 * n
bhh_term2 = (su / 4) * n
bhh_st = bhh_term1 + bhh_term2
bhh_tu = bhh_term1 - bhh_term2
bhh_st_tu = (bhh_st, bhh_tu)
Expand All @@ -734,7 +734,7 @@ def _brownian_arch(
assert _state.bhh_s_u_su is None
assert _state.bkk_s_u_su is None
mean = 0.5 * w_su
w_term2 = root_su / 2 * jr.normal(midpoint_key, shape, dtype)
w_term2 = (root_su / 2) * jr.normal(midpoint_key, shape, dtype)
w_st = mean + w_term2
w_tu = mean - w_term2
w_st_tu = (w_st, w_tu)
Expand Down
Loading

0 comments on commit fd946c5

Please sign in to comment.