Skip to content

Commit

Permalink
Added single-seed paper to docs and autocite, some minor fixes, diacr…
Browse files Browse the repository at this point in the history
…itic on Levy
  • Loading branch information
andyElking committed Jun 18, 2024
1 parent a84cbfa commit fa83ab4
Show file tree
Hide file tree
Showing 10 changed files with 382 additions and 326 deletions.
2 changes: 1 addition & 1 deletion benchmarks/brownian_tree_times.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is
additionally capable of computing Levy area.
additionally capable of computing Lévy area.
Here we check the speed of the new implementation against the old implementation, to be
sure that it is still fast.
Expand Down
23 changes: 16 additions & 7 deletions diffrax/_autocitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ._adjoint import BacksolveAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
from ._brownian import AbstractBrownianPath, VirtualBrownianTree
from ._custom_types import BrownianIncrement
from ._heuristics import is_cde, is_sde
from ._integrate import diffeqsolve
from ._misc import adjoint_rms_seminorm
Expand Down Expand Up @@ -347,28 +348,36 @@ def _virtual_brownian_tree(terms):
is_vbt = lambda x: isinstance(x, VirtualBrownianTree)
leaves = jtu.tree_leaves(terms, is_leaf=is_vbt)
if any(is_vbt(leaf) for leaf in leaves):
vbt_ref, _ = _parse_reference_multi(VirtualBrownianTree)
vbt_ref, single_seed_ref, _ = _parse_reference_multi(VirtualBrownianTree)
return (
r"""
% You are simulating Brownian motion using a virtual Brownian tree, which was introduced
% in:
"""
+ vbt_ref
+ "\n\n"
+ single_seed_ref
)


@citation_rules.append
def _space_time_levy_area(terms):
has_levy_area = lambda x: isinstance(x, AbstractBrownianPath) and x.levy_area != ""
has_levy_area = (
lambda x: isinstance(x, AbstractBrownianPath)
and x.levy_area != BrownianIncrement
)
leaves = jtu.tree_leaves(terms, is_leaf=has_levy_area)
if any(has_levy_area(leaf) for leaf in leaves):
_, levy_area_ref = _parse_reference_multi(VirtualBrownianTree)
_, single_seed_ref, foster_ref = _parse_reference_multi(VirtualBrownianTree)
return (
r"""
% You are simulating Brownian motion using space-time Levy area, the formulae for which
% were developed in:
"""
+ levy_area_ref
% You are simulating Brownian motion using Lévy area,
% the formulae for which are due to:
"""
+ single_seed_ref
+ "\n\n"
+ r"""% and Theorem 6.1.6 of:"""
+ foster_ref
)


Expand Down
2 changes: 1 addition & 1 deletion diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def evaluate(
left-continuous or right-continuous at any jump points, but Brownian
motion has no jump points.)
- `use_levy`: If True, the return type will be a `LevyVal`, which contains
PyTrees of Brownian increments and their Levy areas.
PyTrees of Brownian increments and their Lévy areas.
**Returns:**
Expand Down
11 changes: 7 additions & 4 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class UnsafeBrownianPath(AbstractBrownianPath):
motion. Hence the restrictions above. (They describe the general case for which the
correlation structure isn't needed.)
!!! info "Levy Area"
!!! info "Lévy Area"
Can be initialised with `levy_area` set to `diffrax.BrownianIncrement`, or
`diffrax.SpaceTimeLevyArea`. If `levy_area=diffrax.SpaceTimeLevyArea`, then it
Expand Down Expand Up @@ -147,22 +147,25 @@ def _evaluate_leaf(
use_levy: bool,
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
key_w, key_hh, key_kk = jr.split(key, 3)
w = jr.normal(key, shape.shape, shape.dtype) * w_std
dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))

if levy_area is SpaceTimeTimeLevyArea:
key_w, key_hh, key_kk = jr.split(key, 3)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
kk_std = w_std / math.sqrt(720)
kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std
levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk)

elif levy_area is SpaceTimeLevyArea:
key_w, key_hh = jr.split(key, 2)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh)
elif levy_area is BrownianIncrement:
w = jr.normal(key, shape.shape, shape.dtype) * w_std
levy_val = BrownianIncrement(dt=dt, W=w)
else:
assert False
Expand All @@ -180,6 +183,6 @@ def _evaluate_leaf(
be a tuple of integers, describing the shape of a single JAX array. In that case
the dtype is chosen to be the default floating-point dtype.
- `key`: A random key.
- `levy_area`: Whether to additionally generate Levy area. This is required by some SDE
- `levy_area`: Whether to additionally generate Lévy area. This is required by some SDE
solvers.
"""
Loading

0 comments on commit fa83ab4

Please sign in to comment.