diff --git a/benchmarks/brownian_tree_times.py b/benchmarks/brownian_tree_times.py index 66190f45..832e8502 100644 --- a/benchmarks/brownian_tree_times.py +++ b/benchmarks/brownian_tree_times.py @@ -1,6 +1,6 @@ """ v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is -additionally capable of computing Levy area. +additionally capable of computing Lévy area. Here we check the speed of the new implementation against the old implementation, to be sure that it is still fast. diff --git a/diffrax/_autocitation.py b/diffrax/_autocitation.py index 484d1ebb..547177ce 100644 --- a/diffrax/_autocitation.py +++ b/diffrax/_autocitation.py @@ -10,6 +10,7 @@ from ._adjoint import BacksolveAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint from ._brownian import AbstractBrownianPath, VirtualBrownianTree +from ._custom_types import BrownianIncrement from ._heuristics import is_cde, is_sde from ._integrate import diffeqsolve from ._misc import adjoint_rms_seminorm @@ -347,28 +348,36 @@ def _virtual_brownian_tree(terms): is_vbt = lambda x: isinstance(x, VirtualBrownianTree) leaves = jtu.tree_leaves(terms, is_leaf=is_vbt) if any(is_vbt(leaf) for leaf in leaves): - vbt_ref, _ = _parse_reference_multi(VirtualBrownianTree) + vbt_ref, single_seed_ref, _ = _parse_reference_multi(VirtualBrownianTree) return ( r""" % You are simulating Brownian motion using a virtual Brownian tree, which was introduced % in: """ + vbt_ref + + "\n\n" + + single_seed_ref ) @citation_rules.append def _space_time_levy_area(terms): - has_levy_area = lambda x: isinstance(x, AbstractBrownianPath) and x.levy_area != "" + has_levy_area = ( + lambda x: isinstance(x, AbstractBrownianPath) + and x.levy_area != BrownianIncrement + ) leaves = jtu.tree_leaves(terms, is_leaf=has_levy_area) if any(has_levy_area(leaf) for leaf in leaves): - _, levy_area_ref = _parse_reference_multi(VirtualBrownianTree) + _, single_seed_ref, foster_ref = _parse_reference_multi(VirtualBrownianTree) return ( r""" -% You are simulating Brownian motion using space-time Levy area, the formulae for which -% were developed in: -""" - + levy_area_ref + % You are simulating Brownian motion using Lévy area, + % the formulae for which are due to: + """ + + single_seed_ref + + "\n\n" + + r"""% and Theorem 6.1.6 of:""" + + foster_ref ) diff --git a/diffrax/_brownian/base.py b/diffrax/_brownian/base.py index 1642d315..21618b76 100644 --- a/diffrax/_brownian/base.py +++ b/diffrax/_brownian/base.py @@ -41,7 +41,7 @@ def evaluate( left-continuous or right-continuous at any jump points, but Brownian motion has no jump points.) - `use_levy`: If True, the return type will be a `LevyVal`, which contains - PyTrees of Brownian increments and their Levy areas. + PyTrees of Brownian increments and their Lévy areas. **Returns:** diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 09c9ba37..0333caa5 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -46,7 +46,7 @@ class UnsafeBrownianPath(AbstractBrownianPath): motion. Hence the restrictions above. (They describe the general case for which the correlation structure isn't needed.) - !!! info "Levy Area" + !!! info "Lévy Area" Can be initialised with `levy_area` set to `diffrax.BrownianIncrement`, or `diffrax.SpaceTimeLevyArea`. If `levy_area=diffrax.SpaceTimeLevyArea`, then it @@ -147,11 +147,11 @@ def _evaluate_leaf( use_levy: bool, ): w_std = jnp.sqrt(t1 - t0).astype(shape.dtype) - key_w, key_hh, key_kk = jr.split(key, 3) - w = jr.normal(key, shape.shape, shape.dtype) * w_std dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype)) if levy_area is SpaceTimeTimeLevyArea: + key_w, key_hh, key_kk = jr.split(key, 3) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std kk_std = w_std / math.sqrt(720) @@ -159,10 +159,13 @@ def _evaluate_leaf( levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) elif levy_area is SpaceTimeLevyArea: + key_w, key_hh = jr.split(key, 2) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh) elif levy_area is BrownianIncrement: + w = jr.normal(key, shape.shape, shape.dtype) * w_std levy_val = BrownianIncrement(dt=dt, W=w) else: assert False @@ -180,6 +183,6 @@ def _evaluate_leaf( be a tuple of integers, describing the shape of a single JAX array. In that case the dtype is chosen to be the default floating-point dtype. - `key`: A random key. -- `levy_area`: Whether to additionally generate Levy area. This is required by some SDE +- `levy_area`: Whether to additionally generate Lévy area. This is required by some SDE solvers. """ diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index 2eb7f535..83259567 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -52,9 +52,9 @@ # author = {Foster, James M.}, # year = {2020} # } -# For more about space-time Levy area see Definition 4.2.1. -# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6. -# For the general interpolation rule for space-time Levy area see Theorem 6.1.4. +# For more about space-time Lévy area see Definition 4.2.1. +# For the midpoint rule for generating space-time Lévy area see Theorem 6.1.6. +# For the general interpolation rule for space-time Lévy area see Theorem 6.1.4. FloatDouble: TypeAlias = tuple[Inexact[Array, " *shape"], Inexact[Array, " *shape"]] FloatTriple: TypeAlias = tuple[ @@ -64,120 +64,109 @@ _BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement) -# An internal dataclass that holds the rescaled Levy areas +# An internal dataclass that holds the rescaled Lévy areas # in addition to the non-rescaled ones, for the purposes of -# taking the difference between two Levy areas. -class _AbstractLevyVal(eqx.Module): - dt: eqx.AbstractVar[Inexact[Array, ""]] - W: eqx.AbstractVar[Array] - - -class _BMLevyVal(_AbstractLevyVal): - dt: Inexact[Array, ""] +# taking the difference between two Lévy areas. +class _LevyVal(eqx.Module): + dt: RealScalarLike W: Array + H: Optional[Array] + bar_H: Optional[Array] + K: Optional[Array] + bar_K: Optional[Array] - -class _AbstractSpaceTimeLevyVal(_AbstractLevyVal): - H: eqx.AbstractVar[Array] - bar_H: eqx.AbstractVar[Array] - - -class _SpaceTimeLevyVal(_AbstractSpaceTimeLevyVal): - dt: Inexact[Array, ""] - W: Array - H: Array - bar_H: Array - - -class _SpaceTimeTimeLevyVal(_AbstractSpaceTimeLevyVal): - dt: Inexact[Array, ""] - W: Array - H: Array - bar_H: Array - K: Array - bar_K: Array + def __check_init__(self): + if self.H is None: + assert self.bar_H is None + assert self.K is None + if self.K is None: + assert self.bar_K is None class _State(eqx.Module): level: IntScalarLike # level of the tree - s: Inexact[Array, ""] # starting time of the interval + s: RealScalarLike # starting time of the interval w_s_u_su: FloatTriple # W_s, W_u, W_{s,u} key: PRNGKeyArray bhh_s_u_su: Optional[FloatTriple] # \bar{H}_s, _u, _{s,u} bkk_s_u_su: Optional[FloatTriple] # \bar{K}_s, _u, _{s,u} -def _levy_diff( - _, x0: _AbstractLevyVal, x1: _AbstractLevyVal -) -> AbstractBrownianIncrement: +def _levy_diff(_, x0: _LevyVal, x1: _LevyVal) -> AbstractBrownianIncrement: r"""Computes $(W_{s,u}, H_{s,u})$ from $(W_s, \bar{H}_{s,u})$ and $(W_u, \bar{H}_u)$, where $\bar{H}_u = u * H_u$. **Arguments:** - `_`: unused, for the purposes of aligning the `jtu.tree_map`. - - `x0`: `_AbstractLevyVal` at time `s`. - - `x1`: `_AbstractLevyVal` at time `u`. + - `x0`: `_LevyVal` at time `s`. + - `x1`: `_LevyVal` at time `u`. **Returns:** `AbstractBrownianIncrement(W_su, H_su, K_su)` """ - dtype = jnp.dtype(x0.W) + dtype = jnp.result_type(x0.W) tdtype = complex_to_real_dtype(dtype) - if isinstance(x0, _BMLevyVal): # BM only case - assert isinstance(x1, _BMLevyVal) - su_real = jnp.asarray(x1.dt - x0.dt, dtype=tdtype) - return BrownianIncrement(dt=su_real, W=x1.W - x0.W) + su = jnp.asarray(x1.dt - x0.dt, dtype=tdtype) + if x0.H is None: # BM only case + assert x1.H is None + return BrownianIncrement(dt=su, W=x1.W - x0.W) # the following computation is common to the space-time - # and the space-time-time Levy area case - assert isinstance(x0, _AbstractSpaceTimeLevyVal) - assert isinstance(x1, _AbstractSpaceTimeLevyVal) + # and the space-time-time Lévy area case + assert x0.H is not None + assert x1.H is not None + assert x0.bar_H is not None + assert x1.bar_H is not None w_su = x1.W - x0.W - su = jnp.asarray(x1.dt - x0.dt, dtype=dtype) - su_real = jnp.asarray(su, dtype=tdtype) inverse_su = 1 / jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su) - u_bb_s = x1.dt * x0.W - x0.dt * x1.W - bhh_su = x1.bar_H - x0.bar_H - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s) - hh_su = inverse_su * bhh_su - - if isinstance(x0, _SpaceTimeLevyVal): # space-time Levy area case - return SpaceTimeLevyArea(dt=su_real, W=w_su, H=hh_su) - - elif isinstance(x0, _SpaceTimeTimeLevyVal): # space-time-time Levy area case - assert isinstance(x1, _SpaceTimeTimeLevyVal) - bkk_su = ( - x1.bar_K - - x0.bar_K - - su / 2 * x0.bar_H - + x0.dt / 2 * bhh_su - - (x1.dt - 2 * x0.dt) / 12 * u_bb_s - ) - su2 = jnp.square(su) - inverse_su2 = 1 / jnp.where(jnp.abs(su2) < jnp.finfo(su2).eps, jnp.inf, su2) - kk_su = inverse_su2 * bkk_su - return SpaceTimeTimeLevyArea(dt=su_real, W=w_su, H=hh_su, K=kk_su) + with jax.numpy_dtype_promotion("standard"): + u_bb_s = x1.dt * x0.W - x0.dt * x1.W + bhh_su = x1.bar_H - x0.bar_H - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s) + hh_su = inverse_su * bhh_su + + if x0.K is None: # space-time Lévy area case + return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su) + + elif x0.K is not None: # space-time-time Lévy area case + assert x1.K is not None + assert x0.bar_K is not None + assert x1.bar_K is not None + with jax.numpy_dtype_promotion("standard"): + bkk_su = ( + x1.bar_K + - x0.bar_K + - (su / 2) * x0.bar_H + + (x0.dt / 2) * bhh_su + - ((x1.dt - 2 * x0.dt) / 12) * u_bb_s + ) + su2 = jnp.square(su) + inverse_su2 = 1 / jnp.where(jnp.abs(su2) < jnp.finfo(su2).eps, jnp.inf, su2) + kk_su = inverse_su2 * bkk_su + return SpaceTimeTimeLevyArea(dt=su, W=w_su, H=hh_su, K=kk_su) else: assert False -def _make_levy_val(_, x: _AbstractLevyVal) -> AbstractBrownianIncrement: +def _make_levy_val(_, x: _LevyVal) -> AbstractBrownianIncrement: tdtype = complex_to_real_dtype(x.W) dt = jnp.asarray(x.dt, dtype=tdtype) - if isinstance(x, _BMLevyVal): + if x.H is None: return BrownianIncrement(dt=dt, W=x.W) - elif isinstance(x, _SpaceTimeLevyVal): + elif x.K is None: return SpaceTimeLevyArea(dt=dt, W=x.W, H=x.H) - elif isinstance(x, _SpaceTimeTimeLevyVal): - return SpaceTimeTimeLevyArea(dt=dt, W=x.W, H=x.H, K=x.K) else: - assert False + return SpaceTimeTimeLevyArea(dt=dt, W=x.W, H=x.H, K=x.K) def _split_interval( - pred: BoolScalarLike, x_stu: FloatTriple, x_st_tu: FloatDouble -) -> FloatTriple: + pred: BoolScalarLike, x_stu: Optional[FloatTriple], x_st_tu: Optional[FloatDouble] +) -> Optional[FloatTriple]: + if x_stu is None: + assert x_st_tu is None + return None + assert x_st_tu is not None x_s, x_t, x_u = x_stu x_st, x_tu = x_st_tu x_s = jnp.where(pred, x_t, x_s) @@ -189,15 +178,16 @@ def _split_interval( class VirtualBrownianTree(AbstractBrownianPath): """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`. - !!! info "Levy Area" + !!! info "Lévy Area" + The parameter `levy_area` can be set to one of: + + - [`diffrax.BrownianIncrement`][] (default, generates the increment of W) + - [`diffrax.SpaceTimeLevyArea`][] (generates W and the space-time Lévy area H) + - [`diffrax.SpaceTimeTimeLevyArea`][] (generates W, H and the space-time-time + Lévy area K) - The parameter `levy_area` can be set to one of: - - [`diffrax.BrownianIncrement`][] (default, generates the increment of W) - - [`diffrax.SpaceTimeLevyArea`][] (generates W and the space-time Levy area H) - - [`diffrax.SpaceTimeTimeLevyArea`][] (generates W, H and the space-time-time - Levy area K) - The choice of `levy_area` will impact the Brownian path, so even with the same - key, the trajectory will be different depending on the value of `levy_area`. + The choice of `levy_area` will impact the Brownian path, so even with the same + key, the trajectory will be different depending on the value of `levy_area`. ??? cite "Reference" @@ -213,11 +203,21 @@ class VirtualBrownianTree(AbstractBrownianPath): ``` The implementation here is an improvement on the above, in that it additionally - simulates space-time and space-time-time Levy areas. This is due to the paper + simulates space-time and space-time-time Lévy areas, and exactly matches the + distribution of the Brownian motion and its Lévy areas at all query times. + This is due to the paper - "Single-seed generation of Brownian paths and integrals - for adaptive and high order SDE solvers" - TODO: add the paper bitex + ```bibtex + @misc{jelinčič2024singleseed, + title={Single-seed generation of Brownian paths and integrals + for adaptive and high order SDE solvers}, + author={Andraž Jelinčič and James Foster and Patrick Kidger}, + year={2024}, + eprint={2405.06464}, + archivePrefix={arXiv}, + primaryClass={math.NA} + } + ``` and Theorem 6.1.6 of @@ -230,9 +230,6 @@ class VirtualBrownianTree(AbstractBrownianPath): year = {2020} } ``` - - In addition, the implementation here is a further improvement on these by using - an interpolation method which ensures the conditional 2nd moments are correct. """ t0: RealScalarLike @@ -346,11 +343,11 @@ def _evaluate_leaf( key, r: RealScalarLike, struct: jax.ShapeDtypeStruct, - ) -> _AbstractLevyVal: + ) -> _LevyVal: shape, dtype = struct.shape, struct.dtype - - t0 = jnp.zeros((), dtype) - r = jnp.asarray(r, dtype) + tdtype = complex_to_real_dtype(dtype) + t0 = jnp.zeros((), tdtype) + r = jnp.asarray(r, tdtype) if self.levy_area is SpaceTimeTimeLevyArea: state_key, init_key_w, init_key_hh, init_key_kk = jr.split(key, 4) @@ -395,7 +392,7 @@ def _body_fun(_state: _State): ( _t, _w_stu, - _w_inc, + _w_st_tu, _keys, _bhh_stu, _bhh_st_tu, @@ -409,22 +406,10 @@ def _body_fun(_state: _State): _key_st, _key_tu = _keys _key = jnp.where(_cond, _key_st, _key_tu) - _w = _split_interval(_cond, _w_stu, _w_inc) - - if self.levy_area is SpaceTimeTimeLevyArea: - assert _bkk_stu is not None and _bkk_st_tu is not None - _bkk = _split_interval(_cond, _bkk_stu, _bkk_st_tu) - else: - _bkk = None - - if ( - self.levy_area is SpaceTimeLevyArea - or self.levy_area is SpaceTimeTimeLevyArea - ): - assert _bhh_stu is not None and _bhh_st_tu is not None - _bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu) - else: - _bhh = None + _w = _split_interval(_cond, _w_stu, _w_st_tu) + assert _w is not None + _bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu) + _bkk = _split_interval(_cond, _bkk_stu, _bkk_st_tu) return _State( level=_level, @@ -438,7 +423,7 @@ def _body_fun(_state: _State): final_state = lax.while_loop(_cond_fun, _body_fun, init_state) s = final_state.s - su = jnp.asarray(2.0, dtype=dtype) ** -final_state.level + su = jnp.asarray(2.0**-final_state.level, dtype=tdtype) sr = jax.nn.relu(r - s) # make sure su = sr + ru regardless of cancellation error @@ -466,21 +451,28 @@ def _body_fun(_state: _State): # compute the mean of (W_sr, H_sr, K_sr) conditioned on # (W_s, H_s, K_s, W_u, H_u, K_u) - bb_mean = (6 * sr_ru_by_su2 / su) * bhh_su + ( - 120 * sr_ru_by_su2 * (0.5 - sr_by_su) / su2 - ) * bkk_su - w_mean = sr_by_su * w_su + bb_mean - h_mean = (sr_by_su**2 / su) * (bhh_su + (30 * ru_by_su / su) * bkk_su) - k_mean = (sr_by_su_3 / su2) * bkk_su - - # compute the covariance matrix of (W_sr, H_sr, K_sr) conditioned on - # (W_s, H_s, K_s, W_u, H_u, K_u) - ww_cov = (sr_by_su * ru_by_su * ((sr - ru) ** 4 + 4 * (sr2 * ru2))) / su3 - wh_cov = -(sr_by_su_3 * ru_by_su * (sr2 - 3 * sr * ru + 6 * ru2)) / (2 * su) - wk_cov = (sr_by_su**4) * ru_by_su * (sr - ru) / 12 - hh_cov = sr / 12 * (1 - sr_by_su_3 * (sr2 + 2 * sr * ru + 16 * ru2) / su2) - hk_cov = -(ru / 24) * sr_by_su_5 - kk_cov = (sr / 720) * (1.0 - sr_by_su_5) + with jax.numpy_dtype_promotion("standard"): + bb_mean = (6 * sr_ru_by_su2 / su) * bhh_su + ( + 120 * sr_ru_by_su2 * (0.5 - sr_by_su) / su2 + ) * bkk_su + w_mean = sr_by_su * w_su + bb_mean + h_mean = (sr_by_su**2 / su) * (bhh_su + (30 * ru_by_su / su) * bkk_su) + k_mean = (sr_by_su_3 / su2) * bkk_su + + # compute the covariance matrix of (W_sr, H_sr, K_sr) conditioned on + # (W_s, H_s, K_s, W_u, H_u, K_u) + ww_cov = ( + sr_by_su * ru_by_su * ((sr - ru) ** 4 + 4 * (sr2 * ru2)) + ) / su3 + wh_cov = -(sr_by_su_3 * ru_by_su * (sr2 - 3 * sr * ru + 6 * ru2)) / ( + 2 * su + ) + wk_cov = (sr_by_su**4) * ru_by_su * (sr - ru) / 12 + hh_cov = (sr / 12) * ( + 1 - sr_by_su_3 * (sr2 + 2 * sr * ru + 16 * ru2) / su2 + ) + hk_cov = -(ru / 24) * sr_by_su_5 + kk_cov = (sr / 720) * (1.0 - sr_by_su_5) cov = jnp.array( [ @@ -491,15 +483,20 @@ def _body_fun(_state: _State): ) if self._spline == "sqrt": - # NOTE: not compatible with jnp.float16 + # NOTE: jr.multivariate_normal is not compatible with jnp.float16, + # so we need to cast to jnp.float32 before calling it. + with jax.numpy_dtype_promotion("standard"): + dtype_atleast32 = jnp.result_type(dtype, jnp.float32) + cov = jnp.asarray(cov, dtype_atleast32) hat_y = jr.multivariate_normal( final_state.key, - jnp.zeros((3,), dtype), + jnp.zeros((3,), dtype_atleast32), cov, shape=shape, - dtype=dtype, + dtype=dtype_atleast32, method="svd", ) + hat_y = jnp.asarray(hat_y, dtype) elif self._spline == "zero": hat_y = jnp.zeros(shape=shape + (3,), dtype=dtype) @@ -517,27 +514,26 @@ def _body_fun(_state: _State): w_sr = w_mean + hat_w_sr w_r = w_s + w_sr - r_bb_s = r * w_s - s * w_r + with jax.numpy_dtype_promotion("standard"): + r_bb_s = r * w_s - s * w_r - bhh_sr = sr * (h_mean + hat_hh_sr) - bhh_r = bhh_s + bhh_sr + 0.5 * r_bb_s + bhh_sr = sr * (h_mean + hat_hh_sr) + bhh_r = bhh_s + bhh_sr + 0.5 * r_bb_s - bkk_sr = sr2 * (k_mean + hat_kk_sr) - bkk_r = ( - bkk_s - + bkk_sr - + sr / 2 * bhh_s - - s / 2 * bhh_sr - + (r - 2 * s) / 12 * r_bb_s - ) + bkk_sr = sr2 * (k_mean + hat_kk_sr) + bkk_r = ( + bkk_s + + bkk_sr + + (sr / 2) * bhh_s + - (s / 2) * bhh_sr + + ((r - 2 * s) / 12) * r_bb_s + ) - inverse_r = 1 / jnp.where(jnp.square(r) < jnp.finfo(r).eps, jnp.inf, r) - hh_r = inverse_r * bhh_r - kk_r = inverse_r**2 * bkk_r + inverse_r = 1 / jnp.where(jnp.square(r) < jnp.finfo(r).eps, jnp.inf, r) + hh_r = inverse_r * bhh_r + kk_r = inverse_r**2 * bkk_r - return _SpaceTimeTimeLevyVal( - dt=r, W=w_r, H=hh_r, bar_H=bhh_r, K=kk_r, bar_K=bkk_r - ) + return _LevyVal(dt=r, W=w_r, H=hh_r, bar_H=bhh_r, K=kk_r, bar_K=bkk_r) elif self.levy_area is SpaceTimeLevyArea: # This is based on Theorem 6.1.4 of Foster's thesis (see above). @@ -569,31 +565,35 @@ def _body_fun(_state: _State): a = d_prime * sr3 * sr_ru_half b = d_prime * ru3 * sr_ru_half - w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1 - w_r = w_s + w_sr - c = jnp.sqrt(3 * sr3 * ru3) / (6 * d) - bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2 - bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r) + with jax.numpy_dtype_promotion("standard"): + w_sr = ( + sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1 + ) + w_r = w_s + w_sr + c = jnp.sqrt(3 * sr3 * ru3) / (6 * d) + bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2 + bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r) - inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r) - hh_r = inverse_r * bhh_r + inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r) + hh_r = inverse_r * bhh_r - return _SpaceTimeLevyVal(dt=r, W=w_r, H=hh_r, bar_H=bhh_r) + return _LevyVal(dt=r, W=w_r, H=hh_r, bar_H=bhh_r, K=None, bar_K=None) elif self.levy_area is BrownianIncrement: - w_mean = w_s + sr / su * w_su - if self._spline == "sqrt": - z = jr.normal(final_state.key, shape, dtype) - bb = jnp.sqrt(sr * ru / su) * z - elif self._spline == "quad": - z = jr.normal(final_state.key, shape, dtype) - bb = (sr * ru / su) * z - elif self._spline == "zero": - bb = jnp.zeros(shape, dtype) - else: - assert False + with jax.numpy_dtype_promotion("standard"): + w_mean = w_s + sr / su * w_su + if self._spline == "sqrt": + z = jr.normal(final_state.key, shape, dtype) + bb = jnp.sqrt(sr * ru / su) * z + elif self._spline == "quad": + z = jr.normal(final_state.key, shape, dtype) + bb = (sr * ru / su) * z + elif self._spline == "zero": + bb = jnp.zeros(shape, dtype) + else: + assert False w_r = w_mean + bb - return _BMLevyVal(dt=r, W=w_r) + return _LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None) else: assert False @@ -601,7 +601,7 @@ def _body_fun(_state: _State): def _brownian_arch( self, _state: _State, shape, dtype ) -> tuple[ - Inexact[Array, ""], + RealScalarLike, FloatTriple, FloatDouble, tuple[PRNGKeyArray, PRNGKeyArray], @@ -617,7 +617,7 @@ def _brownian_arch( and also returns `w_st` and `w_tu` in addition to just `w_t`. Same for `bhh` if it is not None. Note that the inputs and outputs already contain `bkk`. These values are - there for the sake of a future extension with "space-time-time" Levy area + there for the sake of a future extension with "space-time-time" Lévy area and should be None for now. **Arguments:** @@ -636,110 +636,113 @@ def _brownian_arch( - `bhh_st_tu`: (optional) $(\bar{H}_{s,t}, \bar{H}_{t,u})$ - `bkk_stu`: (optional) $(\bar{K}_s, \bar{K}_t, \bar{K}_u)$ - `bkk_st_tu`: (optional) $(\bar{K}_{s,t}, \bar{K}_{t,u})$ + """ key_st, midpoint_key, key_tu = jr.split(_state.key, 3) keys = (key_st, key_tu) - su = 2.0**-_state.level + tdtype = complex_to_real_dtype(dtype) + su = jnp.asarray(2.0**-_state.level, dtype=tdtype) st = su / 2 - s = _state.s + s = jnp.asarray(_state.s, dtype=tdtype) t = s + st root_su = jnp.sqrt(su) w_s, w_u, w_su = _state.w_s_u_su - if self.levy_area is SpaceTimeTimeLevyArea: - assert _state.bhh_s_u_su is not None - assert _state.bkk_s_u_su is not None - - bhh_s, bhh_u, bhh_su = _state.bhh_s_u_su - bkk_s, bkk_u, bkk_su = _state.bkk_s_u_su - - z1_key, z2_key, z3_key = jr.split(midpoint_key, 3) - z1 = jr.normal(z1_key, shape, dtype) - z2 = jr.normal(z2_key, shape, dtype) - z3 = jr.normal(z3_key, shape, dtype) - - z = z1 * jnp.sqrt(su / 16) - x1 = z2 * jnp.sqrt(su / 768) - x2 = z3 * jnp.sqrt(su / 2880) - - su2 = su**2 - - w_term1 = w_su / 2 - w_term2 = 3 / (2 * su) * bhh_su + z - w_st = w_term1 + w_term2 - w_tu = w_term1 - w_term2 - bhh_term1 = bhh_su / 8 - st / 2 * z - bhh_term2 = 15 / (8 * su) * bkk_su + st * x1 - bhh_st = bhh_term1 + bhh_term2 - bhh_tu = bhh_term1 - bhh_term2 - bkk_term1 = bkk_su / 32 - (su2 / 8) * x1 - bkk_term2 = su2 / 4 * x2 - bkk_st = bkk_term1 + bkk_term2 - bkk_tu = bkk_term1 - bkk_term2 - w_st_tu = (w_st, w_tu) - bhh_st_tu = (bhh_st, bhh_tu) - bkk_st_tu = (bkk_st, bkk_tu) - - w_t = w_s + w_st - t_bb_s = t * w_s - s * w_t - bhh_t = bhh_s + bhh_st + t_bb_s / 2 - bkk_t = ( - bkk_s - + bkk_st - + st / 2 * bhh_s - - s / 2 * bhh_st - + (t - 2 * s) / 12 * t_bb_s - ) - - w_stu = (w_s, w_t, w_u) - bhh_stu = (bhh_s, bhh_t, bhh_u) - bkk_stu = (bkk_s, bkk_t, bkk_u) - - elif self.levy_area is SpaceTimeLevyArea: - assert _state.bhh_s_u_su is not None - assert _state.bkk_s_u_su is None - bhh_s, bhh_u, bhh_su = _state.bhh_s_u_su - - z1_key, z2_key = jr.split(midpoint_key, 2) - z1 = jr.normal(z1_key, shape, dtype) - z2 = jr.normal(z2_key, shape, dtype) - z = z1 * (root_su / 4) - n = z2 * jnp.sqrt(su / 12) - - w_term1 = w_su / 2 - w_term2 = 3 / (2 * su) * bhh_su + z - w_st = w_term1 + w_term2 - w_tu = w_term1 - w_term2 - w_st_tu = (w_st, w_tu) - - bhh_term1 = bhh_su / 8 - su / 4 * z - bhh_term2 = su / 4 * n - bhh_st = bhh_term1 + bhh_term2 - bhh_tu = bhh_term1 - bhh_term2 - bhh_st_tu = (bhh_st, bhh_tu) - - w_t = w_s + w_st - w_stu = (w_s, w_t, w_u) - - bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t) - bhh_stu = (bhh_s, bhh_t, bhh_u) - bkk_stu = None - bkk_st_tu = None + with jax.numpy_dtype_promotion("standard"): + if self.levy_area is SpaceTimeTimeLevyArea: + assert _state.bhh_s_u_su is not None + assert _state.bkk_s_u_su is not None + + bhh_s, bhh_u, bhh_su = _state.bhh_s_u_su + bkk_s, bkk_u, bkk_su = _state.bkk_s_u_su + + z1_key, z2_key, z3_key = jr.split(midpoint_key, 3) + z1 = jr.normal(z1_key, shape, dtype) + z2 = jr.normal(z2_key, shape, dtype) + z3 = jr.normal(z3_key, shape, dtype) + + z = z1 * jnp.sqrt(su / 16) + x1 = z2 * jnp.sqrt(su / 768) + x2 = z3 * jnp.sqrt(su / 2880) + + su2 = su**2 + + w_term1 = w_su / 2 + w_term2 = (3 / (2 * su)) * bhh_su + z + w_st = w_term1 + w_term2 + w_tu = w_term1 - w_term2 + bhh_term1 = bhh_su / 8 - (st / 2) * z + bhh_term2 = (15 / (8 * su)) * bkk_su + st * x1 + bhh_st = bhh_term1 + bhh_term2 + bhh_tu = bhh_term1 - bhh_term2 + bkk_term1 = bkk_su / 32 - (su2 / 8) * x1 + bkk_term2 = (su2 / 4) * x2 + bkk_st = bkk_term1 + bkk_term2 + bkk_tu = bkk_term1 - bkk_term2 + w_st_tu = (w_st, w_tu) + bhh_st_tu = (bhh_st, bhh_tu) + bkk_st_tu = (bkk_st, bkk_tu) + + w_t = w_s + w_st + t_bb_s = t * w_s - s * w_t + bhh_t = bhh_s + bhh_st + t_bb_s / 2 + bkk_t = ( + bkk_s + + bkk_st + + (st / 2) * bhh_s + - (s / 2) * bhh_st + + ((t - 2 * s) / 12) * t_bb_s + ) - elif self.levy_area is BrownianIncrement: - assert _state.bhh_s_u_su is None - assert _state.bkk_s_u_su is None - mean = 0.5 * w_su - w_term2 = root_su / 2 * jr.normal(midpoint_key, shape, dtype) - w_st = mean + w_term2 - w_tu = mean - w_term2 - w_st_tu = (w_st, w_tu) - w_t = w_s + w_st - w_stu = (w_s, w_t, w_u) - bhh_stu, bhh_st_tu, bkk_stu, bkk_st_tu = None, None, None, None + w_stu = (w_s, w_t, w_u) + bhh_stu = (bhh_s, bhh_t, bhh_u) + bkk_stu = (bkk_s, bkk_t, bkk_u) + + elif self.levy_area is SpaceTimeLevyArea: + assert _state.bhh_s_u_su is not None + assert _state.bkk_s_u_su is None + bhh_s, bhh_u, bhh_su = _state.bhh_s_u_su + + z1_key, z2_key = jr.split(midpoint_key, 2) + z1 = jr.normal(z1_key, shape, dtype) + z2 = jr.normal(z2_key, shape, dtype) + z = z1 * (root_su / 4) + n = z2 * jnp.sqrt(su / 12) + + w_term1 = w_su / 2 + w_term2 = (3 / (2 * su)) * bhh_su + z + w_st = w_term1 + w_term2 + w_tu = w_term1 - w_term2 + w_st_tu = (w_st, w_tu) + + bhh_term1 = bhh_su / 8 - su / 4 * z + bhh_term2 = (su / 4) * n + bhh_st = bhh_term1 + bhh_term2 + bhh_tu = bhh_term1 - bhh_term2 + bhh_st_tu = (bhh_st, bhh_tu) + + w_t = w_s + w_st + w_stu = (w_s, w_t, w_u) + + bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t) + bhh_stu = (bhh_s, bhh_t, bhh_u) + bkk_stu = None + bkk_st_tu = None + + elif self.levy_area is BrownianIncrement: + assert _state.bhh_s_u_su is None + assert _state.bkk_s_u_su is None + mean = 0.5 * w_su + w_term2 = (root_su / 2) * jr.normal(midpoint_key, shape, dtype) + w_st = mean + w_term2 + w_tu = mean - w_term2 + w_st_tu = (w_st, w_tu) + w_t = w_s + w_st + w_stu = (w_s, w_t, w_u) + bhh_stu, bhh_st_tu, bkk_stu, bkk_st_tu = None, None, None, None - else: - assert False + else: + assert False return t, w_stu, w_st_tu, keys, bhh_stu, bhh_st_tu, bkk_stu, bkk_st_tu diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index 494ceacd..a74f278c 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -48,8 +48,7 @@ class AbstractStochasticCoeffs(eqx.Module): b_error: eqx.AbstractVar[Optional[Float[np.ndarray, " s"]]] @abc.abstractmethod - def check(self) -> int: - ... + def check(self) -> int: ... class AdditiveCoeffs(AbstractStochasticCoeffs): @@ -143,8 +142,8 @@ def __post_init__(self): assert self.coeffs_hh.check() == num_stages if self.coeffs_kk is not None: assert self.coeffs_hh is not None, ( - "If space-time-time Levy area (K) is used," - " space-time Levy area (H) must also be used." + "If space-time-time Lévy area (K) is used," + " space-time Lévy area (H) must also be used." ) assert type(self.coeffs_kk) is type(self.coeffs_w) assert self.coeffs_kk.check() == num_stages @@ -189,9 +188,9 @@ def __post_init__(self): - `coeffs_w`: An instance of `AdditiveCoeffs` or `GeneralCoeffs`, providing the coefficients of the Brownian motion increments. - `coeffs_hh`: An instance of `AdditiveCoeffs` or `GeneralCoeffs`, providing the - coefficients of the space-time Levy area. + coefficients of the space-time Lévy area. - `coeffs_kk`: An instance of `AdditiveCoeffs` or `GeneralCoeffs`, providing the - coefficients of the space-time-time Levy area. + coefficients of the space-time-time Lévy area. - `ignore_stage_f`: Optional. A NumPy array of length `s` of booleans. If `True` at stage `j`, the vector field of the drift term will not be evaluated at stage `j`. - `ignore_stage_g`: Optional. A NumPy array of length `s` of booleans. If `True` at @@ -206,9 +205,9 @@ class AbstractSRK(AbstractSolver[_SolverState]): `MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))` or `MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))`. Depending on the solver, the Brownian motion might need to generate - different types of Levy areas, specified by the `minimal_levy_area` attribute. + different types of Lévy areas, specified by the `minimal_levy_area` attribute. - For example, the [`diffrax.ShARK`][] solver requires space-time Levy area, so + For example, the [`diffrax.ShARK`][] solver requires space-time Lévy area, so it will have `minimal_levy_area = AbstractSpaceTimeLevyArea` and the Brownian motion must be initialised with `levy_area=SpaceTimeLevyArea`. @@ -260,10 +259,10 @@ class AbstractSRK(AbstractSolver[_SolverState]): term_compatible_contr_kwargs = (dict(), dict(use_levy=True)) tableau: AbstractClassVar[StochasticButcherTableau] - # Indicates the type of Levy area used by the solver. - # The BM must generate at least this type of Levy area, but can generate - # more. E.g. if the solver uses space-time Levy area, then the BM generates - # space-time-time Levy area as well that is fine. The other way around would + # Indicates the type of Lévy area used by the solver. + # The BM must generate at least this type of Lévy area, but can generate + # more. E.g. if the solver uses space-time Lévy area, then the BM generates + # space-time-time Lévy area as well that is fine. The other way around would # not work. This is mostly an easily readable indicator so that methods know # what kind of BM to use. @property @@ -288,7 +287,7 @@ def init( args: PyTree, ) -> _SolverState: del t1 - # Check that the diffusion has the correct Levy area + # Check that the diffusion has the correct Lévy area _, diffusion = terms.terms if self.tableau.is_additive_noise(): @@ -372,12 +371,12 @@ def make_zeros_aux(leaf): b_levy_list = [] levy_areas = [] - if self.tableau.coeffs_hh is not None: # space-time Levy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) levy_areas.append(bm_inc.H) b_levy_list.append(jnp.asarray(self.tableau.coeffs_hh.b_sol, dtype=dtype)) - if self.tableau.coeffs_kk is not None: # space-time-time Levy area + if self.tableau.coeffs_kk is not None: # space-time-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeTimeLevyArea) levy_areas.append(bm_inc.K) b_levy_list.append( @@ -397,7 +396,7 @@ def aux_add_levy(w_leaf, *levy_leaves): levylist_kgs = [] # will contain levy * g(t0 + c_j * h, z_j) for each stage j # where levy is either H or K (if those entries exist) - # this is similar to h_kfs or w_kgs, but for the Levy area(s) + # this is similar to h_kfs or w_kgs, but for the Lévy area(s) if self.tableau.is_additive_noise(): # additive noise # compute g once since it is constant @@ -413,12 +412,12 @@ def _comp_g(_t): w_kgs = diffusion.prod(g0, w) a_w = jnp.asarray(self.tableau.coeffs_w.a, dtype=dtype) - if self.tableau.coeffs_hh is not None: # space-time Levy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) levylist_kgs.append(diffusion.prod(g0, bm_inc.H)) a_levy.append(jnp.asarray(self.tableau.coeffs_hh.a, dtype=dtype)) - if self.tableau.coeffs_kk is not None: # space-time-time Levy area + if self.tableau.coeffs_kk is not None: # space-time-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeTimeLevyArea) levylist_kgs.append(diffusion.prod(g0, bm_inc.K)) a_levy.append(jnp.asarray(self.tableau.coeffs_kk.a, dtype=dtype)) @@ -436,11 +435,11 @@ def _comp_g(_t): w_kgs = make_zeros() a_w = self._embed_a_lower(self.tableau.coeffs_w.a, dtype) - # do the same for each type of Levy area - if self.tableau.coeffs_hh is not None: # space-time Levy area + # do the same for each type of Lévy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area levylist_kgs.append(make_zeros()) a_levy.append(self._embed_a_lower(self.tableau.coeffs_hh.a, dtype)) - if self.tableau.coeffs_kk is not None: # space-time-time Levy area + if self.tableau.coeffs_kk is not None: # space-time-time Lévy area levylist_kgs.append(make_zeros()) a_levy.append(self._embed_a_lower(self.tableau.coeffs_kk.a, dtype)) @@ -581,7 +580,7 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in): # In the additive noise case (i.e. when g is independent of y), # we still need a correction term in case the diffusion vector field # g depends on t. This term is of the form $(g1 - g0) * (0.5*W_n - H_n)$. - if self.tableau.coeffs_hh is not None: # space-time Levy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) time_var_contr = (bm_inc.W**ω - 2.0 * bm_inc.H**ω).ω time_var_term = diffusion.prod(g_delta, time_var_contr) diff --git a/docs/api/brownian.md b/docs/api/brownian.md index c843f005..6d9719bf 100644 --- a/docs/api/brownian.md +++ b/docs/api/brownian.md @@ -22,3 +22,38 @@ SDEs are simulated using a Brownian motion as a control. (See the neural SDE exa members: - __init__ - evaluate + +--- + +## Lévy areas + +Brownian controls can return certain types of Lévy areas. These are iterated integrals +of the Brownian motion, and are used by some SDE solvers. When a solver requires a +Lévy area, it will have a `minimal_levy_area` attribute, which will always return an +abstract Lévy area type, and it can accept any subclass of that type. +The inheritance hierarchy is as follows: +``` +AbstractBrownianIncrement +│ └── BrownianIncrement +└── AbstractSpaceTimeLevyArea + │ └── SpaceTimeLevyArea + └── AbstractSpaceTimeTimeLevyArea + └── SpaceTimeTimeLevyArea +``` +For example if `solver.minimal_levy_area` returns an `AbstractSpaceTimeLevyArea`, then +the Brownian motion (which is either an `UnsafeBrownianPath` or +a `VirtualBrownianTree`) should be initialized with `levy_area=SpaceTimeLevyArea` or +`levy_area=SpaceTimeTimeLevyArea`. Note that for the Brownian motion, +a concrete class must be used, not its abstract parent. + +::: diffrax.AbstractBrownianIncrement + +::: diffrax.BrownianIncrement + +::: diffrax.AbstractSpaceTimeLevyArea + +::: diffrax.SpaceTimeLevyArea + +::: diffrax.AbstractSpaceTimeTimeLevyArea + +::: diffrax.SpaceTimeTimeLevyArea \ No newline at end of file diff --git a/docs/devdocs/srk_example.ipynb b/docs/devdocs/srk_example.ipynb index f1905822..aa2a38cf 100644 --- a/docs/devdocs/srk_example.ipynb +++ b/docs/devdocs/srk_example.ipynb @@ -19,11 +19,11 @@ "\n", "To account for time-dependent noise, the SRK adds a term to the output of each step, which allows it to still maintain its usual strong order of convergence.\n", "\n", - "The SRK is capable of utilising various types of time Levy area, depending on the tableau provided. It can use:\n", + "The SRK is capable of utilising various types of time Lévy area, depending on the tableau provided. It can use:\n", "\n", - "- just the Brownian motion $W$, without any Levy area\n", - "- $W$ and the space-time Levy area $H$\n", - "- $W$, $H$ and the space-time-time Levy area $K$.\n", + "- just the Brownian motion $W$, without any Lévy area\n", + "- $W$ and the space-time Lévy area $H$\n", + "- $W$, $H$ and the space-time-time Lévy area $K$.\n", "For more information see the documentation of the `StochasticButcherTableau` class.\n", "\n", "First we will demonstrate an additive-noise-only SRK method, the ShARK method, on an SDE with additive, time-dependent noise.\n", @@ -263,7 +263,7 @@ " \n", "\n", "### Shifted Additive-noise Euler (SEA)\n", - "This variant of the Euler-Maruyama makes use of the space-time Levy area, which improves its local error to $O(h^2)$ compared to $O(h^{1.5})$ of the standard Euler-Maruyama. Nevertheless, it has a strong order of only 1 for additive-noise SDEs.\n", + "This variant of the Euler-Maruyama makes use of the space-time Lévy area, which improves its local error to $O(h^2)$ compared to $O(h^{1.5})$ of the standard Euler-Maruyama. Nevertheless, it has a strong order of only 1 for additive-noise SDEs.\n", "\n", "\n", " ### The \"Splitting Path Runge-Kutta\" (SPaRK) method\n", @@ -277,7 +277,7 @@ "When the noise is commutative it has order 1.\n", "When the noise is additive it has order 1.5.\n", "For the Langevin SDE it has order 2.\n", - "Requires the space-time Levy area H.\n", + "Requires the space-time Lévy area H.\n", "It also natively supports adaptive time-stepping.\n", "\n", "\n", diff --git a/test/helpers.py b/test/helpers.py index b7e33b54..fd4097dd 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -159,7 +159,7 @@ def _batch_sde_solve( shape=struct, tol=bm_tol, key=key, - levy_area=concrete_la, # pyright: ignore + levy_area=concrete_la, ) terms = get_terms(bm) if controller is None: @@ -183,12 +183,12 @@ def _resulting_levy_area( levy_area1: type[diffrax.AbstractBrownianIncrement], levy_area2: type[diffrax.AbstractBrownianIncrement], ) -> type[diffrax.AbstractBrownianIncrement]: - """A helper that returns the stricter Levy area. + """A helper that returns the stricter Lévy area. **Arguments:** - - `levy_area1`: The first Levy area type. - - `levy_area2`: The second Levy area type. + - `levy_area1`: The first Lévy area type. + - `levy_area2`: The second Lévy area type. **Returns:** diff --git a/test/test_brownian.py b/test/test_brownian.py index 55edaaf2..0c91ad57 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -13,7 +13,7 @@ import scipy.stats as stats -levy_areas = ( +_levy_areas = ( diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea, diffrax.SpaceTimeTimeLevyArea, @@ -36,13 +36,13 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -@pytest.mark.parametrize("levy_area", levy_areas) +@pytest.mark.parametrize("levy_area", _levy_areas) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): t0 = 0.0 t1 = 2.0 - shapes_dtypes = ( + shapes_dtypes1 = ( ((), None), ((0,), None), ((1, 0), None), @@ -50,8 +50,11 @@ def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): ((1, 2, 3, 4), jnp.float64), ({"a": (1,), "b": (2, 3)}, {"a": None, "b": jnp.float64}), ((2,), jnp.float16), - ((1, 2, 3, 4), jnp.complex128), (((1, 2), ((3, 4), (5, 6))), (jnp.float16, (jnp.float32, jnp.float64))), + ) + + shapes_dtypes2 = ( + ((1, 2, 3, 4), jnp.complex128), ({"a": (1,), "b": (2, 3)}, {"a": jnp.float64, "b": jnp.complex128}), ) @@ -59,9 +62,12 @@ def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): ctr is diffrax.VirtualBrownianTree and levy_area is diffrax.SpaceTimeTimeLevyArea ): - # VBT with STTLA does not support float16 or complex dtypes + # VBT with STTLA does not support complex dtypes # because it uses jax.random.multivariate_normal - shapes_dtypes = shapes_dtypes[:6] + shapes_dtypes = shapes_dtypes1 + + else: + shapes_dtypes = shapes_dtypes1 + shapes_dtypes2 def is_tuple_of_ints(obj): return isinstance(obj, tuple) and all(isinstance(x, int) for x in obj) @@ -112,7 +118,7 @@ def is_tuple_of_ints(obj): @pytest.mark.parametrize( "ctr", [diffrax.VirtualBrownianTree, diffrax.UnsafeBrownianPath] ) -@pytest.mark.parametrize("levy_area", levy_areas) +@pytest.mark.parametrize("levy_area", _levy_areas) @pytest.mark.parametrize("use_levy", (True, False)) def test_statistics(ctr, levy_area, use_levy): # Deterministic key for this test; not using getkey() @@ -174,7 +180,7 @@ def _eval(key): assert pval > 0.1 -def true_cond_stats_wh(bm_s, bm_u, s, r, u): +def _true_cond_stats_wh(bm_s, bm_u, s, r, u): w_s = bm_s.W w_u = bm_u.W h_s = bm_s.H @@ -205,7 +211,7 @@ def true_cond_stats_wh(bm_s, bm_u, s, r, u): return w_mean, w_std, h_mean, h_std -def true_cond_stats_whk(bm_s, bm_u, s, r, u): +def _true_cond_stats_whk(bm_s, bm_u, s, r, u): su = u - s sr = r - s ru = u - r @@ -270,7 +276,7 @@ def true_cond_stats_whk(bm_s, bm_u, s, r, u): return mean_whk, cov -def conditional_statistics( +def _conditional_statistics( levy_area, use_levy: bool, tol, spacing, spline: _Spline, min_num_points ): key = jr.PRNGKey(5678) @@ -377,7 +383,7 @@ def conditional_statistics( bk_r = r**2 * k_r # Compute the target conditional mean and covariance - true_mean_whk, true_cov = true_cond_stats_whk(bm_s, bm_u, s, r, u) + true_mean_whk, true_cov = _true_cond_stats_whk(bm_s, bm_u, s, r, u) # now compute the values of (W_sr, H_sr, K_sr), which are to be tested # against the normal distribution N(mean, cov) @@ -430,7 +436,7 @@ def conditional_statistics( assert bm_u.H is not None # Compute the true conditional statistics for W and H - w_mean2, w_std2, h_mean, h_std = true_cond_stats_wh(bm_s, bm_u, s, r, u) + w_mean2, w_std2, h_mean, h_std = _true_cond_stats_wh(bm_s, bm_u, s, r, u) # Check w_r|(w_s, w_u, h_s, h_u) normalised_w2 = (w_r - w_mean2) / w_std2 @@ -451,10 +457,10 @@ def conditional_statistics( ) -@pytest.mark.parametrize("levy_area", levy_areas) +@pytest.mark.parametrize("levy_area", _levy_areas) @pytest.mark.parametrize("use_levy", (True, False)) def test_conditional_statistics(levy_area, use_levy): - pvals_w1, pvals_w2, pvals_h, pvals_k, mean_err, cov_err = conditional_statistics( + pvals_w1, pvals_w2, pvals_h, pvals_k, mean_err, cov_err = _conditional_statistics( levy_area, use_levy, tol=2**-10, @@ -497,7 +503,7 @@ def _levy_area_spline(): and spline == "quad" ): # The quad spline is not defined for space-time and - # space-time-time Levy area + # space-time-time Lévy area continue yield levy_area, spline @@ -505,7 +511,7 @@ def _levy_area_spline(): @pytest.mark.parametrize("levy_area,spline", _levy_area_spline()) @pytest.mark.parametrize("use_levy", (True, False)) def test_spline(levy_area, use_levy, spline: _Spline): - pvals_w1, pvals_w2, pvals_h, pvals_k, mean_err, cov_err = conditional_statistics( + pvals_w1, pvals_w2, pvals_h, pvals_k, mean_err, cov_err = _conditional_statistics( levy_area, use_levy=use_levy, tol=2**-3, @@ -555,10 +561,10 @@ def pred_sttla(_mean_err, _cov_err): assert pred(pvals_w1) elif levy_area == diffrax.SpaceTimeLevyArea and spline == "zero": # We need a milder upper bound on jnp.mean(pvals_w1) because - # the presence of space-time Levy area gives W_r (i.e. the output + # the presence of space-time Lévy area gives W_r (i.e. the output # of the Brownian path) a variance very close to the correct one, # even when the spline is wrong. In pvals_w2 the influence of the - # Levy area is subtracted in the mean, so we can use a stricter test. + # Lévy area is subtracted in the mean, so we can use a stricter test. n = pvals_w1.shape[0] assert jnp.min(pvals_w1) < 0.03 / n and jnp.mean(pvals_w1) < 0.2 else: @@ -616,7 +622,7 @@ def eval_paths(t): bk_r = r**2 * k_r # Compute the target conditional mean and covariance - true_mean_whk, true_cov = true_cond_stats_whk(bm_s, bm_u, s, r, u) + true_mean_whk, true_cov = _true_cond_stats_whk(bm_s, bm_u, s, r, u) # now compute the values of (W_sr, H_sr, K_sr), which are to be tested # against the normal distribution N(mean, cov)