STTLA PR #414

Merged
merged 4 commits into from
Jun 24, 2024

@andyElking andyElking commented May 8, 2024

This PR adds space-time-time Levy area to VirtualBrownianTree and UnsafeBrownianPath. The changes to the overall logic are minor, I mostly just added the math for the space-time-time case.

The additions to test_brownian.py are more extensive:

  • In test_conditional_statistics we test the pvals of the marginals for W, H, and K separately. This confirms that they are all Gaussian, so their joint distribution is exactly determined by their means and covariances, which we also test for. All of these are conditioned on the sample just before and the one just after the current sample (as it was before). I moved the code for computing the statistics into the conditional_statistics function.
  • Since the code of test_conditional_statistics itself is based on thm 3.7 from the Single-seed paper, I chose to test its correctness from a different angle in test_whk_interpolation. This test tries three things: a) just final interpolation with no halving (brownian bridge) steps, b) only one halving step followed by a final interpolation, c) 15 halving steps (i.e. tol = 2**-15) and no final interpolation (aka. "zero" spline). Since they all have the same distribution as predicted by thm 3.7, I conclude that indeed everything is as it should be. This still relies on the assumption that the halving steps are correct (i.e. thm 3.5), but that was computed by James in his thesis, so I suppose that's a fair assumption :).
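
The "halving (brownian bridge) steps" mentioned above rest on the standard conditional law of a Brownian bridge. As a rough illustration, here is a minimal sketch for the W component only; the corresponding statistics for H and K come from thm 3.7 of the paper and are not reproduced here. The function name `bridge_conditional` is made up for this example and is not part of Diffrax.

```python
def bridge_conditional(s, r, u, w_s, w_u):
    """Mean and variance of W_r given W_s and W_u, for s < r < u."""
    theta = (r - s) / (u - s)
    mean = w_s + theta * (w_u - w_s)  # linear interpolation of the endpoints
    var = (r - s) * (u - r) / (u - s)  # vanishes as r approaches s or u
    return mean, var


# Halving step: r is the midpoint of [s, u].
mean, var = bridge_conditional(0.0, 0.5, 1.0, 0.0, 2.0)
```

A test in the spirit of test_conditional_statistics would compare empirical means and variances of many samples against values like these.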

In the third commit I added shape annotations to arrays in AbstractBrownianIncrement and its subclasses. This is optional, so if you prefer we can remove this commit.

In the fourth commit I just removed two unnecessary # pyright: ignore comments, so pyright doesn't complain. These have nothing to do with the rest of this PR.

I hope this PR is less challenging to review. No hurry though :)

@andyElking andyElking force-pushed the sttla_pr branch 2 times, most recently from 7b9b1e7 to a1f2e2a Compare May 9, 2024 14:07
@patrick-kidger
Owner

Okay, happy to start thinking about reviewing this PR now! If you can rebase on top of main then I'll take a look :)

@andyElking andyElking changed the base branch from dev to main May 26, 2024 20:32
@andyElking
Contributor Author

That's great news! I rebased it on top of main and changed the base branch :)

Owner

@patrick-kidger patrick-kidger left a comment

Awesome stuff! I've just quickly looked over this. I'll let you address my first round of comments, but my initial impression is that this is a fairly small change, so it should be pretty easy to land.

Comment on lines 369 to 389
return r"""
% You are simulating Brownian motion using Levy area, the formulae for which
% are due to:
@misc{jelinčič2024singleseed,
title={Single-seed generation of Brownian paths and integrals
for adaptive and high order SDE solvers},
author={Andraž Jelinčič and James Foster and Patrick Kidger},
year={2024},
eprint={2405.06464},
archivePrefix={arXiv},
primaryClass={math.NA}
}

% and Theorem 6.1.6 of
@phdthesis{foster2020a,
publisher = {University of Oxford},
school = {University of Oxford},
title = {Numerical approximations for stochastic differential equations},
author = {Foster, James M.},
year = {2020}
}
Owner

I think these citations deserve to also go in VirtualBrownianTree.__doc__! :) You did a lot of work making this happen!

Contributor Author

@andyElking andyElking May 30, 2024

Do you mean in the autocitation for the VBT? It is already in the doc of VBT, but I did add it to autocitation as well.

Owner

What I mean is that I think these citations should appear in the docstring, and be parsed out with _parse_reference_multi.

@@ -13,6 +13,11 @@
import scipy.stats as stats


levy_areas = (
Owner

Nit: should start with an underscore, it's private to this file.

):
# VBT with STTLA does not support float16 or complex dtypes
# because it uses jax.random.multivariate_normal
shapes_dtypes = shapes_dtypes[:6]
Owner

This seems a little error prone. Can this be switched into two groups shapes_dtypes1 and shapes_dtypes2, which are then combined or not appropriately?

Contributor Author

Yes, good point.
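
One way the suggested two-group layout could look is sketched below. The group names, the helper `shapes_dtypes_for`, and the individual entries are all hypothetical (the real lists in test_brownian.py may differ); the point is only that each group is named by what it contains, so nothing depends on a positional slice like `[:6]`.

```python
# Hypothetical entries; dtypes are given by name to keep the sketch dependency-free.
_shapes_dtypes1 = (
    ((), "float32"),
    ((2,), "float32"),
    ((3, 4), "float64"),
)
# Cases that VBT with STTLA does not support (float16 and complex dtypes).
_shapes_dtypes2 = (
    ((), "float16"),
    ((2,), "complex64"),
)


def shapes_dtypes_for(supports_all_dtypes: bool):
    """Return the parametrisation list appropriate for the control under test."""
    if supports_all_dtypes:
        return _shapes_dtypes1 + _shapes_dtypes2
    return _shapes_dtypes1
```

Adding a new dtype then means appending to the right group rather than re-counting slice indices.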

Comment on lines 68 to 98
# in addition to the non-rescaled ones, for the purposes of
# taking the difference between two Levy areas.
class _AbstractLevyVal(eqx.Module):
    dt: eqx.AbstractVar[Inexact[Array, ""]]
    W: eqx.AbstractVar[Array]


class _BMLevyVal(_AbstractLevyVal):
    dt: Inexact[Array, ""]
    W: Array


class _AbstractSpaceTimeLevyVal(_AbstractLevyVal):
    H: eqx.AbstractVar[Array]
    bar_H: eqx.AbstractVar[Array]


class _SpaceTimeLevyVal(_AbstractSpaceTimeLevyVal):
    dt: Inexact[Array, ""]
    W: Array
    H: Array
    bar_H: Array


class _SpaceTimeTimeLevyVal(_AbstractSpaceTimeLevyVal):
    dt: Inexact[Array, ""]
    W: Array
    H: Array
    bar_H: Array
    K: Array
    bar_K: Array
Owner

Possibly slightly simpler might be something like:

class _LevyVal(eqx.Module):
    brownian_increment: AbstractBrownianIncrement
    bar_H: Optional[Array]
    bar_K: Optional[Array]

?
Possibly adding some Generic[...] typevars for each of the three arguments if required.

Contributor Author

Fair enough, I suppose I was trying to do it very properly, but since it's never exposed to the user, you're right, I should go for the simpler option.

 class _State(eqx.Module):
     level: IntScalarLike # level of the tree
-    s: RealScalarLike # starting time of the interval
+    s: Inexact[Array, ""] # starting time of the interval
Owner

Times should always be real, not complex.

Contributor Author

I was trying to avoid a myriad of with jax.numpy_dtype_promotion("standard"):, which spaghettify the code a bit, and I thought the times which are only used internally could safely be complex. But I guess it could be a source of bugs, so fair enough.

@andyElking
Contributor Author

Thanks for the quick review! I made the changes you suggested.

I also added a little subsection on Levy areas in docs/api/brownian.md, mainly to clear up how solver.minimal_levy_area interacts with VBT, etc.

    use_levy: bool,
):
    w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
    key_w, key_hh, key_kk = jr.split(key, 3)
Owner

Can we do the split inside the if blocks, to avoid the overhead in the common case of no Levy area?

Contributor Author

I wanted to keep the generated path the same regardless of the Levy area setting, but fair, I guess this isn't really relevant for UBP.

Contributor Author

And I think I was using the wrong key to generate the bm anyway, oops.
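
For reference, the pattern being asked for (splitting the key only inside the branch that needs it) might look roughly like this. It is a simplified sketch, not the PR's code: `sample_increment` is a hypothetical stand-in for the UBP sampling routine, and it handles only W and the space-time area H, using the standard variance `dt / 12` for H.

```python
import jax.numpy as jnp
import jax.random as jr


def sample_increment(key, t0, t1, shape, dtype, use_levy):
    """Hypothetical helper: draw W (and optionally H) over [t0, t1]."""
    dt = jnp.asarray(t1 - t0, dtype=dtype)
    if not use_levy:
        # Common case: no split, the key is consumed directly.
        return jnp.sqrt(dt) * jr.normal(key, shape, dtype), None
    # Only pay for the split when Levy area is requested.
    key_w, key_hh = jr.split(key, 2)
    w = jnp.sqrt(dt) * jr.normal(key_w, shape, dtype)
    hh = jnp.sqrt(dt / 12) * jr.normal(key_hh, shape, dtype)  # H ~ N(0, dt / 12)
    return w, hh


key = jr.PRNGKey(0)
w_only, no_h = sample_increment(key, 0.0, 1.0, (3,), jnp.float32, use_levy=False)
w, hh = sample_increment(key, 0.0, 1.0, (3,), jnp.float32, use_levy=True)
```

The trade-off is the one raised in the reply above: W is drawn from a different key in the two branches, so the sampled path is no longer identical across Levy area settings.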

K: Optional[Array]
bar_K: Optional[Array]

def __post_init__(self):
Owner

Nit: prefer __check_init__ where possible. This is an Equinox-specific extension that (a) runs even when you inherit, and (b) isn't ignored if you define a custom __init__ method, and (c) doesn't allow you to still mutate the class whilst it is running.

Contributor Author

Oh, I understand, will do.

-    - `x0`: `LevyVal` at time `s`.
-    - `x1`: `LevyVal` at time `u`.
+    - `x0`: `_AbstractLevyVal` at time `s`.
+    - `x1`: `_AbstractLevyVal` at time `u`.
Owner

Remove the Abstract.

Contributor Author

Oops, will do.

"""
dtype = jnp.dtype(x0.W)
Owner

I'd prefer using jnp.result_type as the argument is an array. I forget exactly what goes wrong but I recall getting bit by jnp.dtype at some point before.

)

if self._spline == "sqrt":
# NOTE: not compatible with jnp.float16
Owner

I think we should still aim to handle this. Probably this can be done by casting the dtype?

dtype_atleast32 = jnp.result_type(dtype, jnp.float32)
hat_y = jr.multivariate_normal(..., dtype=dtype_atleast32)
hat_y = hat_y.astype(dtype)

Owner

FWIW I think generating a multivariate normal shouldn't be too tricky if we wanted to just do it manually.

I also don't love that this is apparently calling out to SVD. Admittedly, though, I have just done a quick google for an explicit form for the square root of a 3x3 posdef matrix and it looks like there isn't one...

Contributor Author

@andyElking andyElking Jun 12, 2024

I did the Cholesky decomp manually (with the help of some obscure symbolic algebra package), but it's very complicated and in my very basic experiments it took slightly longer to compute it with that formula (so computing the decomposition directly from s, r, u), as opposed to first computing the cov matrix, and then doing SVD on that. The issue with Cholesky is that it is very imprecise when the matrix is close to singular (which happens when r is close to s or to u). SVD is the only good option for near-singular matrices. Alternatively we'd have to do a separate case for when r-s or u-r is small, which is probably not ideal.
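
To make the near-singular point concrete, here is a minimal JAX sketch of building the sampling factor from an SVD instead of a Cholesky decomposition. The matrix `cov` and the helper name `svd_factor` are made up for illustration and are not the covariance used in the PR; the sketch only shows that `U * sqrt(S)` is still a valid square-root factor when the matrix is exactly singular.

```python
import jax.numpy as jnp
import jax.random as jr


def svd_factor(cov):
    """Return A with A @ A.T == cov for a symmetric positive semi-definite cov."""
    u, s, _ = jnp.linalg.svd(cov)
    # Singular values are non-negative, so the square root is always defined,
    # even when some of them are (numerically) zero.
    return u * jnp.sqrt(s)[..., None, :]


# Hypothetical rank-2 (hence singular) 3x3 covariance, built as B @ B.T.
b = jnp.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
cov = b @ b.T

factor = svd_factor(cov)
z = jr.normal(jr.PRNGKey(0), (3,))
sample = factor @ z  # one draw from N(0, cov)
```

A Cholesky factor of the same matrix would need the square root of a pivot that is zero up to rounding, which is where the loss of precision comes from.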

Contributor Author

Yes, the type-casting idea is very good.
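
Spelled out a little further, the casting idea could look like the sketch below. The wrapper name `mvn_any_float` and the identity covariance are assumptions made for the example; `method="svd"` mirrors the SVD-based sampling discussed above.

```python
import jax.numpy as jnp
import jax.random as jr


def mvn_any_float(key, mean, cov, dtype):
    """Hypothetical wrapper: sample in at least float32, then cast back to dtype."""
    dtype_atleast32 = jnp.result_type(dtype, jnp.float32)
    y = jr.multivariate_normal(
        key,
        mean.astype(dtype_atleast32),
        cov.astype(dtype_atleast32),
        dtype=dtype_atleast32,
        method="svd",
    )
    return y.astype(dtype)


mean = jnp.zeros(3, dtype=jnp.float16)
cov = jnp.eye(3, dtype=jnp.float16)
hat_y = mvn_any_float(jr.PRNGKey(0), mean, cov, jnp.float16)
```

For float32 inputs the promotion is a no-op, so only the float16 case pays for the extra casts.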

Comment on lines +510 to +511
hat_w_sr, hat_hh_sr, hat_kk_sr = [
    x.squeeze(axis=-1) for x in jnp.split(hat_y, 3, axis=-1)
]
Owner

I think just hat_w_sr, hat_hh_sr, hat_kk_sr = hat_y will work?

Contributor Author

No, I don't think it will, because we're splitting by the last dimension, whereas tuple unpacking splits along the first.
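
A small example of the difference, using a made-up `hat_y` with a leading batch axis of size 4 and the three components on the last axis. The `jnp.moveaxis` variant is just one possible way to make plain unpacking work and is not what the PR does.

```python
import jax.numpy as jnp

# Hypothetical batch: 4 samples, each with 3 components (W, H, K) on the last axis.
hat_y = jnp.arange(12.0).reshape(4, 3)

# Splitting along the last axis, as in the PR: each result has shape (4,).
hat_w, hat_hh, hat_kk = [
    x.squeeze(axis=-1) for x in jnp.split(hat_y, 3, axis=-1)
]

# Plain `a, b, c = hat_y` would iterate over the FIRST axis (4 rows here) and
# fail to unpack. Moving the component axis to the front makes it work.
w2, hh2, kk2 = jnp.moveaxis(hat_y, -1, 0)
```

Both approaches give the same three arrays; the list comprehension simply avoids the axis move.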

-    ] = _Dopri5Interpolation
+    interpolation_cls: ClassVar[Callable[..., _Dopri5Interpolation]] = (
+        _Dopri5Interpolation
+    )
Owner

Can we skip the spurious reformatting?

(Let's put that in a separate PR if you like? Perhaps this is coming from bumping ruff.)

Contributor Author

Yes, I'll drop it, sorry.


## Levy areas

Brownian controls can return certain types of Levy areas. These are iterated integrals
Owner

Levy should have a diacritic: Lévy.

I don't always include this when writing informally (including in your name, I know!) but I try to get it right in persistent documentation!

Contributor Author

Hmm yes, I think there's many places where I should fix this.

For example if `solver.minimal_levy_area` returns an `AbstractSpaceTimeLevyArea`, then
the Brownian motion (which is either an `UnsafeBrownianPath` or
a `VirtualBrownianTree`) should be initialized with `levy_area=SpaceTimeLevyArea` or
`levy_area=SpaceTimeTimeLevyArea`. Note that for the BM, a concrete class must be
Owner

Nit: I try to avoid the "BM" abbreviation in documentation.

@andyElking andyElking force-pushed the sttla_pr branch 2 times, most recently from 9f33807 to fd946c5 Compare June 12, 2024 18:33
@andyElking
Contributor Author

Hi Patrick, sorry this round of corrections took so long. I made the fixes you mentioned, including adding diacritics on Lévy, which changed some files outside this PR, but we're still only on 12 files changed, so I hope it's fine 😅. I also rebased on top of the current main.

@patrick-kidger
Owner

No worries -- and the great news is that I think this is ready to be merged :D
Can you rebase on top of the latest main, and we can get this in!

@andyElking andyElking force-pushed the sttla_pr branch 2 times, most recently from d4a7e4c to fa83ab4 Compare June 18, 2024 09:39
@andyElking
Contributor Author

andyElking commented Jun 18, 2024

That's very good to hear, I rebased it and squashed some of the commits. I can squash all of them if needed. Hopefully the tests pass.

@andyElking
Contributor Author

Hi Patrick,
Sorry for yet another delay. Seems like the tests passed, so I hope this means this PR is ready for a merge. Let me know if there's any more fixes I should make.

@patrick-kidger patrick-kidger changed the base branch from main to dev June 24, 2024 21:18
@patrick-kidger patrick-kidger merged commit da5031d into patrick-kidger:dev Jun 24, 2024
2 checks passed
@patrick-kidger
Owner

patrick-kidger commented Jun 24, 2024

Aaaaaand merged! :D
This is really great stuff, it's awesome to see how much you're adding to Diffrax.

I've just merged this into a new dev branch (currently equal to main + your PR) in preparation for the next release! I'm aiming to merge in #387, and possibly #436, and then do a new release. What do your own future plans look like here?

@andyElking
Contributor Author

That's great news, thanks!

I will probably open a new PR with Langevin solvers sometime over the next few days. That one is a bit larger than this one, but in my opinion still more lightweight than the SRK PR. But feel free to finish the next release first and turn to Langevin PR afterwards if you prefer.

And after that I'll give you some peace as far as Diffrax is concerned. Upcoming projects are then Langevin MCMC, some signature stuff and finally diffusion models.

@andyElking andyElking deleted the sttla_pr branch June 26, 2024 09:09