Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KL divergence terms for Latent SDEs #402

Closed
wants to merge 32 commits into from
Closed

Conversation

lockwo
Copy link
Contributor

@lockwo lockwo commented Apr 17, 2024

Addresses #401. Revives #104. Based on that PR, I made the minimal requirements to get it up to current version (e.g. taking callables instead of ODE terms since we can't make these .vf becuase _broadcast_and_upcast requires that aug_y and drift(aug_y) are the same shape, but they aren't).

@lockwo
Copy link
Contributor Author

lockwo commented Apr 17, 2024

Before going further (there is a lot I am going to improve/polish) I wanted to check with your thoughts on the general approach of KL being terms and exposing the user to a function that converts their problem. An alternative could be something like in torchsde where it's part of the intregration method, i.e. the user flags it at integration time.

lockwo and others added 6 commits April 17, 2024 18:24
**Major**

- Renamed "bar"->"meter" to reflect the fact that e.g. the text output isn't a bar.
- Added docstrings for all progress meter methods, the new `diffeqsolve(..., progress_meter=...)` argument, and added everything into the docs.
- Fixed bug with t0!=0 producing the wrong results.
- Fixed vmap'd diffeqsolves crashing progress bars. Making this work actually necessitated a fair bit of rewriting, with a gamut of different callbacks and unvmaps.

**Minor**

- Progress updates now move over `[0, 1]` rather than `[0, 100]`.
- Tidied up tqdm bar formatting.
- `TqdmProgressMeter` now checks at trace-time if `tqdm` is installed, rather than crashing at runtime.
- The progress bar state is now a `PyTree[Array]`
* Added parametric control types

* Demo of AbstractTerm change

* Correct initial parametric control type implementation.

* Parametric AbstractTerm initial implementation.

* Update tests and fix hinting

* Implement review comments.

* Add parametric control check to integrator.

* Update and test parametric control check

* Introduce new LevyArea types

* Updated Brownian path LevyArea types

* Replace Union types in isinstance checks

* Remove rogue comment

* Revert _brownian_arch to single assignment

* Revert _evaluate_leaf key splitting

* Rename variables in test_term

* Update isinstance and issubclass checks

* Safer handling in _denormalise_bm_inc

* Fix style in integrate control type check

* Add draft vector_field typing

* Add draft vector_field typing

* Fix term test

* Revert extemporaneous modifications in _tree

* Rename TimeLevyArea to BrownianIncrement and simplify diff

* Rename AbstractLevyReturn to AbstractBrownianReturn

* Rename _LevyArea to _BrownianReturn

* Enhance _term_compatiblity checks

* Fix merge issues

* Bump pre-commit and fix type hints

* Clean up from self-review

* Explicitly add typeguard to deps

* Bump ruff config to new syntax

* Parameterised terms: fixed term compatibility + spurious pyright errors

Phew, this ended up being a complicated one!

Let's start with the easy stuff:
- Disabled spurious pyright errors due to incompatible between pyright and `eqx.AbstractVar`.
- Now using ruff.lint and pinned exact typeguard version.

Now on to the hard stuff:
- Fixed term compatibibility missing some edge cases.

Edge cases? What edge cases? Well, what we had before was basically predicated around doing
```python
vf, contr = get_args(term_cls)
```
recalling that we may have e.g. `term_cls = AbstractTerm[SomeVectorField, SomeControl]`. So far so simple: get the arguments of a subscripted generic, no big deal.

What this failed to account for is that we may also have subclasses of this generic, e.g. `term_cls = ODETerm[SomeVectorField]`, such that some of the type variables have already been filled in when defining it:
```python
class ODETerm(AbstractTerm[_VF, RealScaleLike]): ...
```
so in this case, `get_args(term_cls)` simply returns a 1-tuple of `(SomeVectorField,)`. Oh no! Somehow we have to traverse both the filled-in type variables (to find that one of our type variables is `SomeVectorField` due to subscripting) *and* the type hierarchy (to figure out that the other type variable was filled in during the definition).

Once again, for clarity: given a subscriptable base class `AbstractTerm[_VF, _Control]` and some arbitrary possible-subscripted subclass, we need to find the values of `_VF` and `_Control`, regardless of whehther they have been passed in via subscripting the final class (and are `get_args`-able) or have been filled in during subclassing (and require traversing pseudo-type-hierarchies of `__orig_bases__`).

Any sane implementation would simply... not bother. There is no way that the hassle of figuring this out was going to be worth the small amount of type safety this brings...

So anyway, after a few hours working on this *far* past the point I should be going to sleep, this problem this is now solved. This PR introduces a new `get_args_of` function, called as `get_args_of(superclass, subclass, error_msg_if_necessary)`. This acts analogous to `get_args`, but instead of looking up both parameters (the type variables we want filled in) and the arguments (the values those type variables have been filled in with) on the same class, it looks up the parameters on the superclass, and their filled-in-values on the subclass. Pure madness.

(I'm also tagging @leycec here because this is exactly the kind of insane typing hackery that he seems to really enjoy.)

Does anyone else remember the days when this was a package primarily concerned about solving differential equations?

---------

Co-authored-by: Patrick Kidger <[email protected]>
- set reportUnnecessaryTypeIgnoreComment=true, as required by the
static type "tests" in `test_term`;

- remove now redundant, commented out, beatype based tests in
`test_term`;

- remove all unnecessary type ignore comments from across the
codebase.
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On make these terms: I don't have super strong feelings, but compared to the original PR we have now more clearly defined what a term is in Diffrax, and I think there are other points on the design space.

To be precise: given a diffeq of the form

dy = f(y, z) da + g(y, z) db
dz = h(y, z) da + k(y, z) db

then this would be represented in Diffrax as

terms = (
    MultiTerm(f, g),
    MultiTerm(h, k),
)

In general: everything inside a MultiTerm(...) is all applied to the same dfoo. For example the SDE-specific solvers consume a MultiTerm[ODETerm, AbstractTerm], for the drift and diffusion.
Meanwhile the PyTree structure of terms themselves corresponds to different dfoo and dbar. For example semi-implicit Euler takes a pair of (AbstractTerm, AbstractTerm), corresponding to the two components that are being evolved.

In this case there is an argument that the extra KL-divergence term should really correspond to a new dfoo, and that as such the correct thing to do is to instead replace terms with (terms, kl_term), and then provide a wrapper solver which understands this alternative term structure.

diffrax/_kl_term.py Outdated Show resolved Hide resolved
@patrick-kidger
Copy link
Owner

On the topic of Lineax: indeed, this should definitely make handling PyTrees much easier.

lockwo and others added 2 commits April 21, 2024 22:25
* Fix complex casting and types issues

* Dependency version
@lockwo lockwo changed the base branch from main to dev April 23, 2024 19:02
@lockwo
Copy link
Contributor Author

lockwo commented Apr 24, 2024

I think your idea makes a lot of sense, and I made a fair amount of progress on the solver wrapper approach.

@lockwo lockwo marked this pull request as ready for review April 27, 2024 06:26
@lockwo
Copy link
Contributor Author

lockwo commented Apr 27, 2024

Ok, I polished things up. I went with a sort of hybrid approach where the users specifies the SDEs as you described, then just wraps a solver and everything works smoothly. However, I did create internal terms, in order to get an arbitrary solver to integrate through the KL computation, that was the best way I could think of to do so, but they are completely hidden from the user. I also added the example (can be modified to add more text, or remove pmap although I do like having an example with distribution especially since its painfully slow without it) and a test and updated the docs. Taking it off draft now since its a real PR.

@lockwo lockwo requested a review from patrick-kidger April 27, 2024 06:30
@frankschae
Copy link

This is a very cool feature/example! It looks like one needs to specify

levy_area=diffrax.BrownianIncrement

in diffrax.UnsafeBrownianPath

@lockwo
Copy link
Contributor Author

lockwo commented May 8, 2024

Thanks @frankschae , good catch!

lockwo and others added 4 commits May 8, 2024 17:57
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gosh, it's taken me a long time to get around to this 😅

Thank you for your patience!

As you can see, I think there's actually a way we can make this much much simpler. :)

Comment on lines +807 to 810
if isinstance(solver, KLSolver):
y0 = (y0, 0.0)
y0 = jtu.tree_map(_promote, y0)
# Normalises time: if t0 > t1 then flip things around.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a thing we should put in _integrate.py. Rather, we should have some kind of

def kl_foo_bar(...):
    ...

function, which can be used as

solver, y0 = kl_foo_bar(...)
diffeqsolve(solver=solver, y0,=y0 ...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you want y0 out of kl_foo_bar to be specific KL y0 class or would a tuple be fine? Also, would we want to modify the y0 and the solver at the same time or have a KL solver and a KL y0 transformer?

@@ -0,0 +1,299 @@
import operator
from typing import Optional, Tuple, Union
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned elsewhere, use tuple instead of Tuple!

)


_DiffusionTerm = Union[ControlTerm, WeaklyDiagonalControlTerm]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to specialise on specifically a weakly-diagonal kind of control.

Rather, I think any kind of control should be admissible! Then we use Lineax to do the linear solve against whatever kind of operator is returned.

You know, maybe what we should really be doing is implementing #370! That is to say, allow the vector_field of diffrax.ControlTerm(vector_field, ...) to instead return a lx.AbstractLinearOperator. That would then make things exceptionally easy over here.

WDYT?

(I'll hold off on doing the rest of the review yet as I can see this may change things.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes sense, that could reduce the parsing/if elsing complexity I have by a fair amount. I will make a PR for #370 then incorporate those changes into this and we can revisit this PR at that point.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants