diff --git a/benchmarks/brownian_tree_times.py b/benchmarks/brownian_tree_times.py index 66190f45..832e8502 100644 --- a/benchmarks/brownian_tree_times.py +++ b/benchmarks/brownian_tree_times.py @@ -1,6 +1,6 @@ """ v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is -additionally capable of computing Levy area. +additionally capable of computing Lévy area. Here we check the speed of the new implementation against the old implementation, to be sure that it is still fast. diff --git a/diffrax/_autocitation.py b/diffrax/_autocitation.py index c41fcb72..547177ce 100644 --- a/diffrax/_autocitation.py +++ b/diffrax/_autocitation.py @@ -368,28 +368,17 @@ def _space_time_levy_area(terms): ) leaves = jtu.tree_leaves(terms, is_leaf=has_levy_area) if any(has_levy_area(leaf) for leaf in leaves): - return r""" -% You are simulating Brownian motion using Levy area, the formulae for which -% are due to: -@misc{jelinčič2024singleseed, - title={Single-seed generation of Brownian paths and integrals - for adaptive and high order SDE solvers}, - author={Andraž Jelinčič and James Foster and Patrick Kidger}, - year={2024}, - eprint={2405.06464}, - archivePrefix={arXiv}, - primaryClass={math.NA} -} - -% and Theorem 6.1.6 of -@phdthesis{foster2020a, - publisher = {University of Oxford}, - school = {University of Oxford}, - title = {Numerical approximations for stochastic differential equations}, - author = {Foster, James M.}, - year = {2020} -} -""" + _, single_seed_ref, foster_ref = _parse_reference_multi(VirtualBrownianTree) + return ( + r""" + % You are simulating Brownian motion using Lévy area, + % the formulae for which are due to: + """ + + single_seed_ref + + "\n\n" + + r"""% and Theorem 6.1.6 of:""" + + foster_ref + ) @citation_rules.append diff --git a/diffrax/_brownian/base.py b/diffrax/_brownian/base.py index 1642d315..21618b76 100644 --- a/diffrax/_brownian/base.py +++ b/diffrax/_brownian/base.py @@ -41,7 +41,7 @@ def evaluate( left-continuous or right-continuous at any jump points, but Brownian motion has no jump points.) - `use_levy`: If True, the return type will be a `LevyVal`, which contains - PyTrees of Brownian increments and their Levy areas. + PyTrees of Brownian increments and their Lévy areas. **Returns:** diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 09c9ba37..0333caa5 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -46,7 +46,7 @@ class UnsafeBrownianPath(AbstractBrownianPath): motion. Hence the restrictions above. (They describe the general case for which the correlation structure isn't needed.) - !!! info "Levy Area" + !!! info "Lévy Area" Can be initialised with `levy_area` set to `diffrax.BrownianIncrement`, or `diffrax.SpaceTimeLevyArea`. If `levy_area=diffrax.SpaceTimeLevyArea`, then it @@ -147,11 +147,11 @@ def _evaluate_leaf( use_levy: bool, ): w_std = jnp.sqrt(t1 - t0).astype(shape.dtype) - key_w, key_hh, key_kk = jr.split(key, 3) - w = jr.normal(key, shape.shape, shape.dtype) * w_std dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype)) if levy_area is SpaceTimeTimeLevyArea: + key_w, key_hh, key_kk = jr.split(key, 3) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std kk_std = w_std / math.sqrt(720) @@ -159,10 +159,13 @@ def _evaluate_leaf( levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) elif levy_area is SpaceTimeLevyArea: + key_w, key_hh = jr.split(key, 2) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh) elif levy_area is BrownianIncrement: + w = jr.normal(key, shape.shape, shape.dtype) * w_std levy_val = BrownianIncrement(dt=dt, W=w) else: assert False @@ -180,6 +183,6 @@ def _evaluate_leaf( be a tuple of integers, describing the shape of a single JAX array. In that case the dtype is chosen to be the default floating-point dtype. - `key`: A random key. -- `levy_area`: Whether to additionally generate Levy area. This is required by some SDE +- `levy_area`: Whether to additionally generate Lévy area. This is required by some SDE solvers. """ diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index 91d3b8fa..83259567 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -52,9 +52,9 @@ # author = {Foster, James M.}, # year = {2020} # } -# For more about space-time Levy area see Definition 4.2.1. -# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6. -# For the general interpolation rule for space-time Levy area see Theorem 6.1.4. +# For more about space-time Lévy area see Definition 4.2.1. +# For the midpoint rule for generating space-time Lévy area see Theorem 6.1.6. +# For the general interpolation rule for space-time Lévy area see Theorem 6.1.4. FloatDouble: TypeAlias = tuple[Inexact[Array, " *shape"], Inexact[Array, " *shape"]] FloatTriple: TypeAlias = tuple[ @@ -64,9 +64,9 @@ _BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement) -# An internal dataclass that holds the rescaled Levy areas +# An internal dataclass that holds the rescaled Lévy areas # in addition to the non-rescaled ones, for the purposes of -# taking the difference between two Levy areas. +# taking the difference between two Lévy areas. class _LevyVal(eqx.Module): dt: RealScalarLike W: Array @@ -75,7 +75,7 @@ class _LevyVal(eqx.Module): K: Optional[Array] bar_K: Optional[Array] - def __post_init__(self): + def __check_init__(self): if self.H is None: assert self.bar_H is None assert self.K is None @@ -99,14 +99,14 @@ def _levy_diff(_, x0: _LevyVal, x1: _LevyVal) -> AbstractBrownianIncrement: **Arguments:** - `_`: unused, for the purposes of aligning the `jtu.tree_map`. - - `x0`: `_AbstractLevyVal` at time `s`. - - `x1`: `_AbstractLevyVal` at time `u`. + - `x0`: `_LevyVal` at time `s`. + - `x1`: `_LevyVal` at time `u`. **Returns:** `AbstractBrownianIncrement(W_su, H_su, K_su)` """ - dtype = jnp.dtype(x0.W) + dtype = jnp.result_type(x0.W) tdtype = complex_to_real_dtype(dtype) su = jnp.asarray(x1.dt - x0.dt, dtype=tdtype) if x0.H is None: # BM only case @@ -114,7 +114,7 @@ def _levy_diff(_, x0: _LevyVal, x1: _LevyVal) -> AbstractBrownianIncrement: return BrownianIncrement(dt=su, W=x1.W - x0.W) # the following computation is common to the space-time - # and the space-time-time Levy area case + # and the space-time-time Lévy area case assert x0.H is not None assert x1.H is not None assert x0.bar_H is not None @@ -126,10 +126,10 @@ def _levy_diff(_, x0: _LevyVal, x1: _LevyVal) -> AbstractBrownianIncrement: bhh_su = x1.bar_H - x0.bar_H - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s) hh_su = inverse_su * bhh_su - if x0.K is None: # space-time Levy area case + if x0.K is None: # space-time Lévy area case return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su) - elif x0.K is not None: # space-time-time Levy area case + elif x0.K is not None: # space-time-time Lévy area case assert x1.K is not None assert x0.bar_K is not None assert x1.bar_K is not None @@ -137,9 +137,9 @@ def _levy_diff(_, x0: _LevyVal, x1: _LevyVal) -> AbstractBrownianIncrement: bkk_su = ( x1.bar_K - x0.bar_K - - su / 2 * x0.bar_H - + x0.dt / 2 * bhh_su - - (x1.dt - 2 * x0.dt) / 12 * u_bb_s + - (su / 2) * x0.bar_H + + (x0.dt / 2) * bhh_su + - ((x1.dt - 2 * x0.dt) / 12) * u_bb_s ) su2 = jnp.square(su) inverse_su2 = 1 / jnp.where(jnp.abs(su2) < jnp.finfo(su2).eps, jnp.inf, su2) @@ -161,8 +161,12 @@ def _make_levy_val(_, x: _LevyVal) -> AbstractBrownianIncrement: def _split_interval( - pred: BoolScalarLike, x_stu: FloatTriple, x_st_tu: FloatDouble -) -> FloatTriple: + pred: BoolScalarLike, x_stu: Optional[FloatTriple], x_st_tu: Optional[FloatDouble] +) -> Optional[FloatTriple]: + if x_stu is None: + assert x_st_tu is None + return None + assert x_st_tu is not None x_s, x_t, x_u = x_stu x_st, x_tu = x_st_tu x_s = jnp.where(pred, x_t, x_s) @@ -174,12 +178,14 @@ def _split_interval( class VirtualBrownianTree(AbstractBrownianPath): """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`. - !!! info "Levy Area" + !!! info "Lévy Area" The parameter `levy_area` can be set to one of: + - [`diffrax.BrownianIncrement`][] (default, generates the increment of W) - - [`diffrax.SpaceTimeLevyArea`][] (generates W and the space-time Levy area H) + - [`diffrax.SpaceTimeLevyArea`][] (generates W and the space-time Lévy area H) - [`diffrax.SpaceTimeTimeLevyArea`][] (generates W, H and the space-time-time - Levy area K) + Lévy area K) + The choice of `levy_area` will impact the Brownian path, so even with the same key, the trajectory will be different depending on the value of `levy_area`. @@ -197,8 +203,8 @@ class VirtualBrownianTree(AbstractBrownianPath): ``` The implementation here is an improvement on the above, in that it additionally - simulates space-time and space-time-time Levy areas, and exactly matches the - distribution of the Brownian motion and its Levy areas at all query times. + simulates space-time and space-time-time Lévy areas, and exactly matches the + distribution of the Brownian motion and its Lévy areas at all query times. This is due to the paper ```bibtex @@ -386,7 +392,7 @@ def _body_fun(_state: _State): ( _t, _w_stu, - _w_inc, + _w_st_tu, _keys, _bhh_stu, _bhh_st_tu, @@ -400,22 +406,10 @@ def _body_fun(_state: _State): _key_st, _key_tu = _keys _key = jnp.where(_cond, _key_st, _key_tu) - _w = _split_interval(_cond, _w_stu, _w_inc) - - if self.levy_area is SpaceTimeTimeLevyArea: - assert _bkk_stu is not None and _bkk_st_tu is not None - _bkk = _split_interval(_cond, _bkk_stu, _bkk_st_tu) - else: - _bkk = None - - if ( - self.levy_area is SpaceTimeLevyArea - or self.levy_area is SpaceTimeTimeLevyArea - ): - assert _bhh_stu is not None and _bhh_st_tu is not None - _bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu) - else: - _bhh = None + _w = _split_interval(_cond, _w_stu, _w_st_tu) + assert _w is not None + _bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu) + _bkk = _split_interval(_cond, _bkk_stu, _bkk_st_tu) return _State( level=_level, @@ -474,8 +468,8 @@ def _body_fun(_state: _State): 2 * su ) wk_cov = (sr_by_su**4) * ru_by_su * (sr - ru) / 12 - hh_cov = ( - sr / 12 * (1 - sr_by_su_3 * (sr2 + 2 * sr * ru + 16 * ru2) / su2) + hh_cov = (sr / 12) * ( + 1 - sr_by_su_3 * (sr2 + 2 * sr * ru + 16 * ru2) / su2 ) hk_cov = -(ru / 24) * sr_by_su_5 kk_cov = (sr / 720) * (1.0 - sr_by_su_5) @@ -489,15 +483,20 @@ def _body_fun(_state: _State): ) if self._spline == "sqrt": - # NOTE: not compatible with jnp.float16 + # NOTE: jr.multivariate_normal is not compatible with jnp.float16, + # so we need to cast to jnp.float32 before calling it. + with jax.numpy_dtype_promotion("standard"): + dtype_atleast32 = jnp.result_type(dtype, jnp.float32) + cov = jnp.asarray(cov, dtype_atleast32) hat_y = jr.multivariate_normal( final_state.key, - jnp.zeros((3,), dtype), + jnp.zeros((3,), dtype_atleast32), cov, shape=shape, - dtype=dtype, + dtype=dtype_atleast32, method="svd", ) + hat_y = jnp.asarray(hat_y, dtype) elif self._spline == "zero": hat_y = jnp.zeros(shape=shape + (3,), dtype=dtype) @@ -525,9 +524,9 @@ def _body_fun(_state: _State): bkk_r = ( bkk_s + bkk_sr - + sr / 2 * bhh_s - - s / 2 * bhh_sr - + (r - 2 * s) / 12 * r_bb_s + + (sr / 2) * bhh_s + - (s / 2) * bhh_sr + + ((r - 2 * s) / 12) * r_bb_s ) inverse_r = 1 / jnp.where(jnp.square(r) < jnp.finfo(r).eps, jnp.inf, r) @@ -618,7 +617,7 @@ def _brownian_arch( and also returns `w_st` and `w_tu` in addition to just `w_t`. Same for `bhh` if it is not None. Note that the inputs and outputs already contain `bkk`. These values are - there for the sake of a future extension with "space-time-time" Levy area + there for the sake of a future extension with "space-time-time" Lévy area and should be None for now. **Arguments:** @@ -637,6 +636,7 @@ def _brownian_arch( - `bhh_st_tu`: (optional) $(\bar{H}_{s,t}, \bar{H}_{t,u})$ - `bkk_stu`: (optional) $(\bar{K}_s, \bar{K}_t, \bar{K}_u)$ - `bkk_st_tu`: (optional) $(\bar{K}_{s,t}, \bar{K}_{t,u})$ + """ key_st, midpoint_key, key_tu = jr.split(_state.key, 3) keys = (key_st, key_tu) @@ -669,15 +669,15 @@ def _brownian_arch( su2 = su**2 w_term1 = w_su / 2 - w_term2 = 3 / (2 * su) * bhh_su + z + w_term2 = (3 / (2 * su)) * bhh_su + z w_st = w_term1 + w_term2 w_tu = w_term1 - w_term2 - bhh_term1 = bhh_su / 8 - st / 2 * z - bhh_term2 = 15 / (8 * su) * bkk_su + st * x1 + bhh_term1 = bhh_su / 8 - (st / 2) * z + bhh_term2 = (15 / (8 * su)) * bkk_su + st * x1 bhh_st = bhh_term1 + bhh_term2 bhh_tu = bhh_term1 - bhh_term2 bkk_term1 = bkk_su / 32 - (su2 / 8) * x1 - bkk_term2 = su2 / 4 * x2 + bkk_term2 = (su2 / 4) * x2 bkk_st = bkk_term1 + bkk_term2 bkk_tu = bkk_term1 - bkk_term2 w_st_tu = (w_st, w_tu) @@ -690,9 +690,9 @@ def _brownian_arch( bkk_t = ( bkk_s + bkk_st - + st / 2 * bhh_s - - s / 2 * bhh_st - + (t - 2 * s) / 12 * t_bb_s + + (st / 2) * bhh_s + - (s / 2) * bhh_st + + ((t - 2 * s) / 12) * t_bb_s ) w_stu = (w_s, w_t, w_u) @@ -711,13 +711,13 @@ def _brownian_arch( n = z2 * jnp.sqrt(su / 12) w_term1 = w_su / 2 - w_term2 = 3 / (2 * su) * bhh_su + z + w_term2 = (3 / (2 * su)) * bhh_su + z w_st = w_term1 + w_term2 w_tu = w_term1 - w_term2 w_st_tu = (w_st, w_tu) bhh_term1 = bhh_su / 8 - su / 4 * z - bhh_term2 = su / 4 * n + bhh_term2 = (su / 4) * n bhh_st = bhh_term1 + bhh_term2 bhh_tu = bhh_term1 - bhh_term2 bhh_st_tu = (bhh_st, bhh_tu) @@ -734,7 +734,7 @@ def _brownian_arch( assert _state.bhh_s_u_su is None assert _state.bkk_s_u_su is None mean = 0.5 * w_su - w_term2 = root_su / 2 * jr.normal(midpoint_key, shape, dtype) + w_term2 = (root_su / 2) * jr.normal(midpoint_key, shape, dtype) w_st = mean + w_term2 w_tu = mean - w_term2 w_st_tu = (w_st, w_tu) diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index 494ceacd..a74f278c 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -48,8 +48,7 @@ class AbstractStochasticCoeffs(eqx.Module): b_error: eqx.AbstractVar[Optional[Float[np.ndarray, " s"]]] @abc.abstractmethod - def check(self) -> int: - ... + def check(self) -> int: ... class AdditiveCoeffs(AbstractStochasticCoeffs): @@ -143,8 +142,8 @@ def __post_init__(self): assert self.coeffs_hh.check() == num_stages if self.coeffs_kk is not None: assert self.coeffs_hh is not None, ( - "If space-time-time Levy area (K) is used," - " space-time Levy area (H) must also be used." + "If space-time-time Lévy area (K) is used," + " space-time Lévy area (H) must also be used." ) assert type(self.coeffs_kk) is type(self.coeffs_w) assert self.coeffs_kk.check() == num_stages @@ -189,9 +188,9 @@ def __post_init__(self): - `coeffs_w`: An instance of `AdditiveCoeffs` or `GeneralCoeffs`, providing the coefficients of the Brownian motion increments. - `coeffs_hh`: An instance of `AdditiveCoeffs` or `GeneralCoeffs`, providing the - coefficients of the space-time Levy area. + coefficients of the space-time Lévy area. - `coeffs_kk`: An instance of `AdditiveCoeffs` or `GeneralCoeffs`, providing the - coefficients of the space-time-time Levy area. + coefficients of the space-time-time Lévy area. - `ignore_stage_f`: Optional. A NumPy array of length `s` of booleans. If `True` at stage `j`, the vector field of the drift term will not be evaluated at stage `j`. - `ignore_stage_g`: Optional. A NumPy array of length `s` of booleans. If `True` at @@ -206,9 +205,9 @@ class AbstractSRK(AbstractSolver[_SolverState]): `MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))` or `MultiTerm(ODETerm(drift), WeaklyDiagonalControlTerm(diffusion, brownian_motion))`. Depending on the solver, the Brownian motion might need to generate - different types of Levy areas, specified by the `minimal_levy_area` attribute. + different types of Lévy areas, specified by the `minimal_levy_area` attribute. - For example, the [`diffrax.ShARK`][] solver requires space-time Levy area, so + For example, the [`diffrax.ShARK`][] solver requires space-time Lévy area, so it will have `minimal_levy_area = AbstractSpaceTimeLevyArea` and the Brownian motion must be initialised with `levy_area=SpaceTimeLevyArea`. @@ -260,10 +259,10 @@ class AbstractSRK(AbstractSolver[_SolverState]): term_compatible_contr_kwargs = (dict(), dict(use_levy=True)) tableau: AbstractClassVar[StochasticButcherTableau] - # Indicates the type of Levy area used by the solver. - # The BM must generate at least this type of Levy area, but can generate - # more. E.g. if the solver uses space-time Levy area, then the BM generates - # space-time-time Levy area as well that is fine. The other way around would + # Indicates the type of Lévy area used by the solver. + # The BM must generate at least this type of Lévy area, but can generate + # more. E.g. if the solver uses space-time Lévy area, then the BM generates + # space-time-time Lévy area as well that is fine. The other way around would # not work. This is mostly an easily readable indicator so that methods know # what kind of BM to use. @property @@ -288,7 +287,7 @@ def init( args: PyTree, ) -> _SolverState: del t1 - # Check that the diffusion has the correct Levy area + # Check that the diffusion has the correct Lévy area _, diffusion = terms.terms if self.tableau.is_additive_noise(): @@ -372,12 +371,12 @@ def make_zeros_aux(leaf): b_levy_list = [] levy_areas = [] - if self.tableau.coeffs_hh is not None: # space-time Levy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) levy_areas.append(bm_inc.H) b_levy_list.append(jnp.asarray(self.tableau.coeffs_hh.b_sol, dtype=dtype)) - if self.tableau.coeffs_kk is not None: # space-time-time Levy area + if self.tableau.coeffs_kk is not None: # space-time-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeTimeLevyArea) levy_areas.append(bm_inc.K) b_levy_list.append( @@ -397,7 +396,7 @@ def aux_add_levy(w_leaf, *levy_leaves): levylist_kgs = [] # will contain levy * g(t0 + c_j * h, z_j) for each stage j # where levy is either H or K (if those entries exist) - # this is similar to h_kfs or w_kgs, but for the Levy area(s) + # this is similar to h_kfs or w_kgs, but for the Lévy area(s) if self.tableau.is_additive_noise(): # additive noise # compute g once since it is constant @@ -413,12 +412,12 @@ def _comp_g(_t): w_kgs = diffusion.prod(g0, w) a_w = jnp.asarray(self.tableau.coeffs_w.a, dtype=dtype) - if self.tableau.coeffs_hh is not None: # space-time Levy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) levylist_kgs.append(diffusion.prod(g0, bm_inc.H)) a_levy.append(jnp.asarray(self.tableau.coeffs_hh.a, dtype=dtype)) - if self.tableau.coeffs_kk is not None: # space-time-time Levy area + if self.tableau.coeffs_kk is not None: # space-time-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeTimeLevyArea) levylist_kgs.append(diffusion.prod(g0, bm_inc.K)) a_levy.append(jnp.asarray(self.tableau.coeffs_kk.a, dtype=dtype)) @@ -436,11 +435,11 @@ def _comp_g(_t): w_kgs = make_zeros() a_w = self._embed_a_lower(self.tableau.coeffs_w.a, dtype) - # do the same for each type of Levy area - if self.tableau.coeffs_hh is not None: # space-time Levy area + # do the same for each type of Lévy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area levylist_kgs.append(make_zeros()) a_levy.append(self._embed_a_lower(self.tableau.coeffs_hh.a, dtype)) - if self.tableau.coeffs_kk is not None: # space-time-time Levy area + if self.tableau.coeffs_kk is not None: # space-time-time Lévy area levylist_kgs.append(make_zeros()) a_levy.append(self._embed_a_lower(self.tableau.coeffs_kk.a, dtype)) @@ -581,7 +580,7 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in): # In the additive noise case (i.e. when g is independent of y), # we still need a correction term in case the diffusion vector field # g depends on t. This term is of the form $(g1 - g0) * (0.5*W_n - H_n)$. - if self.tableau.coeffs_hh is not None: # space-time Levy area + if self.tableau.coeffs_hh is not None: # space-time Lévy area assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) time_var_contr = (bm_inc.W**ω - 2.0 * bm_inc.H**ω).ω time_var_term = diffusion.prod(g_delta, time_var_contr) diff --git a/docs/api/brownian.md b/docs/api/brownian.md index c5eb52b4..6d9719bf 100644 --- a/docs/api/brownian.md +++ b/docs/api/brownian.md @@ -25,12 +25,12 @@ SDEs are simulated using a Brownian motion as a control. (See the neural SDE exa --- -## Levy areas +## Lévy areas -Brownian controls can return certain types of Levy areas. These are iterated integrals +Brownian controls can return certain types of Lévy areas. These are iterated integrals of the Brownian motion, and are used by some SDE solvers. When a solver requires a -Levy area, it will have a `minimal_levy_area` attribute, which will always return an -abstract Levy area type, and it can accept any subclass of that type. +Lévy area, it will have a `minimal_levy_area` attribute, which will always return an +abstract Lévy area type, and it can accept any subclass of that type. The inheritance hierarchy is as follows: ``` AbstractBrownianIncrement @@ -43,8 +43,8 @@ AbstractBrownianIncrement For example if `solver.minimal_levy_area` returns an `AbstractSpaceTimeLevyArea`, then the Brownian motion (which is either an `UnsafeBrownianPath` or a `VirtualBrownianTree`) should be initialized with `levy_area=SpaceTimeLevyArea` or -`levy_area=SpaceTimeTimeLevyArea`. Note that for the BM, a concrete class must be -used, not its abstract parent. +`levy_area=SpaceTimeTimeLevyArea`. Note that for the Brownian motion, +a concrete class must be used, not its abstract parent. ::: diffrax.AbstractBrownianIncrement diff --git a/docs/devdocs/srk_example.ipynb b/docs/devdocs/srk_example.ipynb index f1905822..aa2a38cf 100644 --- a/docs/devdocs/srk_example.ipynb +++ b/docs/devdocs/srk_example.ipynb @@ -19,11 +19,11 @@ "\n", "To account for time-dependent noise, the SRK adds a term to the output of each step, which allows it to still maintain its usual strong order of convergence.\n", "\n", - "The SRK is capable of utilising various types of time Levy area, depending on the tableau provided. It can use:\n", + "The SRK is capable of utilising various types of time Lévy area, depending on the tableau provided. It can use:\n", "\n", - "- just the Brownian motion $W$, without any Levy area\n", - "- $W$ and the space-time Levy area $H$\n", - "- $W$, $H$ and the space-time-time Levy area $K$.\n", + "- just the Brownian motion $W$, without any Lévy area\n", + "- $W$ and the space-time Lévy area $H$\n", + "- $W$, $H$ and the space-time-time Lévy area $K$.\n", "For more information see the documentation of the `StochasticButcherTableau` class.\n", "\n", "First we will demonstrate an additive-noise-only SRK method, the ShARK method, on an SDE with additive, time-dependent noise.\n", @@ -263,7 +263,7 @@ " \n", "\n", "### Shifted Additive-noise Euler (SEA)\n", - "This variant of the Euler-Maruyama makes use of the space-time Levy area, which improves its local error to $O(h^2)$ compared to $O(h^{1.5})$ of the standard Euler-Maruyama. Nevertheless, it has a strong order of only 1 for additive-noise SDEs.\n", + "This variant of the Euler-Maruyama makes use of the space-time Lévy area, which improves its local error to $O(h^2)$ compared to $O(h^{1.5})$ of the standard Euler-Maruyama. Nevertheless, it has a strong order of only 1 for additive-noise SDEs.\n", "\n", "\n", " ### The \"Splitting Path Runge-Kutta\" (SPaRK) method\n", @@ -277,7 +277,7 @@ "When the noise is commutative it has order 1.\n", "When the noise is additive it has order 1.5.\n", "For the Langevin SDE it has order 2.\n", - "Requires the space-time Levy area H.\n", + "Requires the space-time Lévy area H.\n", "It also natively supports adaptive time-stepping.\n", "\n", "\n", diff --git a/test/helpers.py b/test/helpers.py index 66970db7..fd4097dd 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -183,12 +183,12 @@ def _resulting_levy_area( levy_area1: type[diffrax.AbstractBrownianIncrement], levy_area2: type[diffrax.AbstractBrownianIncrement], ) -> type[diffrax.AbstractBrownianIncrement]: - """A helper that returns the stricter Levy area. + """A helper that returns the stricter Lévy area. **Arguments:** - - `levy_area1`: The first Levy area type. - - `levy_area2`: The second Levy area type. + - `levy_area1`: The first Lévy area type. + - `levy_area2`: The second Lévy area type. **Returns:** diff --git a/test/test_brownian.py b/test/test_brownian.py index 55fac6f5..0c91ad57 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -49,12 +49,12 @@ def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): ((3, 4), jnp.float32), ((1, 2, 3, 4), jnp.float64), ({"a": (1,), "b": (2, 3)}, {"a": None, "b": jnp.float64}), + ((2,), jnp.float16), + (((1, 2), ((3, 4), (5, 6))), (jnp.float16, (jnp.float32, jnp.float64))), ) shapes_dtypes2 = ( - ((2,), jnp.float16), ((1, 2, 3, 4), jnp.complex128), - (((1, 2), ((3, 4), (5, 6))), (jnp.float16, (jnp.float32, jnp.float64))), ({"a": (1,), "b": (2, 3)}, {"a": jnp.float64, "b": jnp.complex128}), ) @@ -62,7 +62,7 @@ def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): ctr is diffrax.VirtualBrownianTree and levy_area is diffrax.SpaceTimeTimeLevyArea ): - # VBT with STTLA does not support float16 or complex dtypes + # VBT with STTLA does not support complex dtypes # because it uses jax.random.multivariate_normal shapes_dtypes = shapes_dtypes1 @@ -503,7 +503,7 @@ def _levy_area_spline(): and spline == "quad" ): # The quad spline is not defined for space-time and - # space-time-time Levy area + # space-time-time Lévy area continue yield levy_area, spline @@ -561,10 +561,10 @@ def pred_sttla(_mean_err, _cov_err): assert pred(pvals_w1) elif levy_area == diffrax.SpaceTimeLevyArea and spline == "zero": # We need a milder upper bound on jnp.mean(pvals_w1) because - # the presence of space-time Levy area gives W_r (i.e. the output + # the presence of space-time Lévy area gives W_r (i.e. the output # of the Brownian path) a variance very close to the correct one, # even when the spline is wrong. In pvals_w2 the influence of the - # Levy area is subtracted in the mean, so we can use a stricter test. + # Lévy area is subtracted in the mean, so we can use a stricter test. n = pvals_w1.shape[0] assert jnp.min(pvals_w1) < 0.03 / n and jnp.mean(pvals_w1) < 0.2 else: