STTLA PR #414
Force-pushed from `7b9b1e7` to `a1f2e2a`.
Okay, happy to start thinking about reviewing this PR now! If you can rebase on top of `main`, that would be great!
That's great news! I rebased it on top of main and changed the base branch :)
Awesome stuff! I've just quickly looked over this. I'll let you address my first round of comments, but my initial impression is that this is a fairly small change, so it should be pretty easy to land.
diffrax/_autocitation.py (Outdated)

```python
return r"""
% You are simulating Brownian motion using Levy area, the formulae for which
% are due to:
@misc{jelinčič2024singleseed,
    title={Single-seed generation of Brownian paths and integrals
    for adaptive and high order SDE solvers},
    author={Andraž Jelinčič and James Foster and Patrick Kidger},
    year={2024},
    eprint={2405.06464},
    archivePrefix={arXiv},
    primaryClass={math.NA}
}

% and Theorem 6.1.6 of
@phdthesis{foster2020a,
    publisher = {University of Oxford},
    school = {University of Oxford},
    title = {Numerical approximations for stochastic differential equations},
    author = {Foster, James M.},
    year = {2020}
}
```
I think these citations deserve to also go in `VirtualBrownianTree.__doc__`! :) You did a lot of work making this happen!
Do you mean in the autocitation for the VBT? It is already in the doc of VBT, but I did add it to autocitation as well.
What I mean is that I think these citations should appear in the docstring, and be parsed out with `_parse_reference_multi`.
test/test_brownian.py (Outdated)

```python
@@ -13,6 +13,11 @@
import scipy.stats as stats


levy_areas = (
```
Nit: should start with an underscore, it's private to this file.
test/test_brownian.py (Outdated)

```python
):
    # VBT with STTLA does not support float16 or complex dtypes
    # because it uses jax.random.multivariate_normal
    shapes_dtypes = shapes_dtypes[:6]
```
This seems a little error prone. Can this be switched into two groups `shapes_dtypes1` and `shapes_dtypes2`, which are then combined or not appropriately?
Yes, good point.
diffrax/_brownian/tree.py (Outdated)

```python
# in addition to the non-rescaled ones, for the purposes of
# taking the difference between two Levy areas.
class _AbstractLevyVal(eqx.Module):
    dt: eqx.AbstractVar[Inexact[Array, ""]]
    W: eqx.AbstractVar[Array]


class _BMLevyVal(_AbstractLevyVal):
    dt: Inexact[Array, ""]
    W: Array


class _AbstractSpaceTimeLevyVal(_AbstractLevyVal):
    H: eqx.AbstractVar[Array]
    bar_H: eqx.AbstractVar[Array]


class _SpaceTimeLevyVal(_AbstractSpaceTimeLevyVal):
    dt: Inexact[Array, ""]
    W: Array
    H: Array
    bar_H: Array


class _SpaceTimeTimeLevyVal(_AbstractSpaceTimeLevyVal):
    dt: Inexact[Array, ""]
    W: Array
    H: Array
    bar_H: Array
    K: Array
    bar_K: Array
```
Possibly slightly simpler might be something like:

```python
class _LevyVal(eqx.Module):
    brownian_increment: AbstractBrownianIncrement
    bar_H: Optional[Array]
    bar_K: Optional[Array]
```

? Possibly adding some `Generic[...]` typevars for each of the three arguments if required.
Fair enough, I suppose I was trying to do it very properly, but since it's never exposed to the user, you're right, I should go for the simpler option.
diffrax/_brownian/tree.py (Outdated)

```diff
 class _State(eqx.Module):
     level: IntScalarLike  # level of the tree
-    s: RealScalarLike  # starting time of the interval
+    s: Inexact[Array, ""]  # starting time of the interval
```
Times should always be real, not complex.
I was trying to avoid a myriad of `with jax.numpy_dtype_promotion("standard"):` blocks, which spaghettify the code a bit, and I thought the times, which are only used internally, could safely be complex. But I guess it could be a source of bugs, so fair enough.
Thanks for the quick review! I made the changes you suggested. I also added a little subsection on Levy areas in `docs/api/brownian.md`.
diffrax/_brownian/path.py (Outdated)

```python
    use_levy: bool,
):
    w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
    key_w, key_hh, key_kk = jr.split(key, 3)
```
Can we do the split inside the `if` blocks, to avoid the overhead in the common case of no Levy area?
I wanted to keep the generated path the same regardless of the Levy area setting, but fair, I guess this isn't really relevant for UBP.
And I think I was using the wrong key to generate the bm anyway, oops.
diffrax/_brownian/tree.py (Outdated)

```python
    K: Optional[Array]
    bar_K: Optional[Array]

    def __post_init__(self):
```
Nit: prefer `__check_init__` where possible. This is an Equinox-specific extension that (a) runs even when you inherit, (b) isn't ignored if you define a custom `__init__` method, and (c) doesn't allow you to mutate the class whilst it is running.
Oh, I understand, will do.
diffrax/_brownian/tree.py (Outdated)

```diff
-    - `x0`: `LevyVal` at time `s`.
-    - `x1`: `LevyVal` at time `u`.
+    - `x0`: `_AbstractLevyVal` at time `s`.
+    - `x1`: `_AbstractLevyVal` at time `u`.
```
Remove the `Abstract`.
Oops, will do.
diffrax/_brownian/tree.py (Outdated)

```python
    """
    dtype = jnp.dtype(x0.W)
```
I'd prefer using `jnp.result_type`, as the argument is an array. I forget exactly what goes wrong, but I recall getting bitten by `jnp.dtype` at some point before.
diffrax/_brownian/tree.py (Outdated)

```python
        )

        if self._spline == "sqrt":
            # NOTE: not compatible with jnp.float16
```
I think we should still aim to handle this. Probably this can be done by casting the dtype?

```python
dtype_atleast32 = jnp.result_type(dtype, jnp.float32)
hat_y = jr.multivariate_normal(..., dtype=dtype_atleast32)
hat_y = hat_y.astype(dtype)
```
FWIW I think generating a multivariate normal shouldn't be too tricky if we wanted to just do it manually.
I also don't love that this is apparently calling out to SVD. Admittedly, though, I have just done a quick google for an explicit form for the square root of a 3x3 posdef matrix and it looks like there isn't one...
I did the Cholesky decomp manually (with the help of some obscure symbolic algebra package), but it's very complicated and in my very basic experiments it took slightly longer to compute it with that formula (so computing the decomposition directly from s, r, u), as opposed to first computing the cov matrix, and then doing SVD on that. The issue with Cholesky is that it is very imprecise when the matrix is close to singular (which happens when r is close to s or to u). SVD is the only good option for near-singular matrices. Alternatively we'd have to do a separate case for when r-s or u-r is small, which is probably not ideal.
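The trade-off being discussed can be sketched in plain NumPy (a toy illustration written for this thread, not the PR's implementation): sampling a Gaussian through an SVD-based symmetric square root stays stable even when the covariance is nearly singular, which is where a Cholesky factorisation loses accuracy.

```python
import numpy as np

rng = np.random.default_rng(0)

# A nearly singular 2x2 covariance, analogous to r being close to s or u.
cov = np.array([[1.0, 1.0 - 1e-12],
                [1.0 - 1e-12, 1.0]])

# Symmetric square root via SVD; for a posdef matrix u == vh.T, so this
# equals u @ diag(sqrt(s)) @ u.T and squares back to cov.
u, s, vh = np.linalg.svd(cov)
sqrt_cov = u @ np.diag(np.sqrt(s)) @ vh

# Sampling N(0, cov): transform standard normals by the square root.
z = rng.standard_normal((100_000, 2))
samples = z @ sqrt_cov.T

# The empirical covariance should match cov closely.
emp_cov = samples.T @ samples / len(samples)
assert np.allclose(emp_cov, cov, atol=0.05)
```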
Yes, the type-casting idea is very good.
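The promotion behaviour that the suggested cast relies on can be checked with NumPy's `result_type`, which for these floating dtypes follows the same lattice as JAX's standard promotion (a small illustrative check, not code from the PR):

```python
import numpy as np

# Promoting float16 against float32 bumps it up to float32, so the
# multivariate-normal sampling can run at (at least) single precision...
assert np.result_type(np.float16, np.float32) == np.float32

# ...while wider dtypes pass through unchanged, so float32/float64
# inputs are unaffected by the extra promotion step.
assert np.result_type(np.float32, np.float32) == np.float32
assert np.result_type(np.float64, np.float32) == np.float64
```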
```python
hat_w_sr, hat_hh_sr, hat_kk_sr = [
    x.squeeze(axis=-1) for x in jnp.split(hat_y, 3, axis=-1)
]
```
I think just `hat_w_sr, hat_hh_sr, hat_kk_sr = hat_y` will work?
No, I don't think it will, because we're splitting by the last dimension, whereas tuple unpacking splits along the first.
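A quick NumPy illustration of the distinction (written for this thread, not code from the PR): plain unpacking iterates the first axis, whereas the code splits along the last one.

```python
import numpy as np

hat_y = np.arange(12.0).reshape(4, 3)  # stand-in for a (..., 3) sample

# Splitting along the LAST axis, as the PR does:
w, hh, kk = [x.squeeze(axis=-1) for x in np.split(hat_y, 3, axis=-1)]
assert w.shape == (4,) and np.array_equal(w, hat_y[..., 0])

# Plain unpacking walks the FIRST axis: four rows into three targets fails.
try:
    a, b, c = hat_y
except ValueError:
    pass  # "too many values to unpack"

# Unpacking only matches after moving the last axis to the front:
w2, hh2, kk2 = np.moveaxis(hat_y, -1, 0)
assert np.array_equal(w2, w)
```

(It would coincidentally work for an unbatched `hat_y` of shape `(3,)`, where the first and last axes coincide, which is presumably why the shorthand looks tempting.)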
diffrax/_solver/dopri5.py (Outdated)

```diff
-    ] = _Dopri5Interpolation
+    interpolation_cls: ClassVar[Callable[..., _Dopri5Interpolation]] = (
+        _Dopri5Interpolation
+    )
```
Can we skip the spurious reformatting? (Let's put that in a separate PR if you like? Perhaps this is coming from bumping `ruff`.)
Yes, I'll drop it, sorry.
docs/api/brownian.md (Outdated)

```markdown
## Levy areas

Brownian controls can return certain types of Levy areas. These are iterated integrals
```
Levy should have a diacritic: Lévy. I don't always include this when writing informally (including in your name, I know!) but I try to get it right in persistent documentation!
Hmm yes, I think there's many places where I should fix this.
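For reference (my reconstruction from the papers cited earlier in this thread, so normalisation conventions may differ from the PR), the space-time Lévy area over $[s, t]$ is the time average of the Brownian bridge:

```latex
H_{s,t} \;=\; \frac{1}{t-s} \int_s^t \left( W_u - W_s - \frac{u-s}{t-s}\,(W_t - W_s) \right) \mathrm{d}u ,
```

with the space-time-time area $K_{s,t}$ defined analogously by integrating the same bridge against an additional polynomial weight in $u$.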
docs/api/brownian.md (Outdated)

```markdown
For example if `solver.minimal_levy_area` returns an `AbstractSpaceTimeLevyArea`, then
the Brownian motion (which is either an `UnsafeBrownianPath` or
a `VirtualBrownianTree`) should be initialized with `levy_area=SpaceTimeLevyArea` or
`levy_area=SpaceTimeTimeLevyArea`. Note that for the BM, a concrete class must be
```
Nit: I try to avoid the "BM" abbreviation in documentation.
Force-pushed from `9f33807` to `fd946c5`.
Hi Patrick, sorry this round of corrections took so long. I made the fixes you mentioned, including adding diacritics on Lévy, which changed some files outside this PR, but we're still only on 12 files changed, so I hope it's fine 😅. I also rebased on top of the current main.
No worries -- and the great news is that I think this is ready to be merged :D
Force-pushed from `d4a7e4c` to `fa83ab4`.
That's very good to hear, I rebased it and squashed some of the commits. I can squash all of them if needed. Hopefully the tests pass.
Hi Patrick, …
Aaaaaand merged! :D I've just merged this into a new …
That's great news, thanks! I will probably open a new PR with Langevin solvers sometime over the next few days. That one is a bit larger than this one, but in my opinion still more lightweight than the SRK PR. But feel free to finish the next release first and turn to the Langevin PR afterwards if you prefer. And after that I'll give you some peace as far as Diffrax is concerned. Upcoming projects are then Langevin MCMC, some signature stuff, and finally diffusion models.
This PR adds space-time-time Levy area to `VirtualBrownianTree` and `UnsafeBrownianPath`. The changes to the overall logic are minor; I mostly just added the math for the space-time-time case.

The additions to `test_brownian.py` are more extensive:

- In `test_conditional_statistics` we test the p-values of the marginals for W, H, and K separately. This confirms that they are all Gaussian, so their joint distribution is exactly determined by their means and covariances, which we also test for. All of these are conditioned on the sample just before and the one just after the current sample (as it was before). I moved the code for computing the statistics into the `conditional_statistics` function.
- Since `test_conditional_statistics` itself is based on Thm 3.7 from the Single-seed paper, I chose to test its correctness from a different angle in `test_whk_interpolation`. This test tries three things: a) just final interpolation with no halving (Brownian bridge) steps, b) only one halving step followed by a final interpolation, c) 15 halving steps (i.e. `tol = 2**-15`) and no final interpolation (aka the "zero" spline). Since they all have the same distribution as predicted by Thm 3.7, I conclude that indeed everything is as it should be. This still relies on the assumption that the halving steps are correct (i.e. Thm 3.5), but that was computed by James in his thesis, so I suppose that's a fair assumption :).
- In the third commit I added shape annotations to arrays in `AbstractBrownianIncrement` and its subclasses. This is optional, so if you prefer we can remove this commit.
- In the fourth commit I just removed two unnecessary `# pyright: ignore` comments, so pyright doesn't complain. These have nothing to do with the rest of this PR.

I hope this PR is less challenging to review. No hurry though :)