KL Divergence for Latent SDEs #463

Open
wants to merge 23 commits into
base: main

Conversation

lockwo
Contributor

@lockwo lockwo commented Jul 12, 2024

A continuation of #402 with the new 0.6.0 lineax changes

@lockwo lockwo marked this pull request as ready for review July 12, 2024 18:14
Owner

@patrick-kidger patrick-kidger left a comment

Okay, I really like the example here!

I'm afraid this might still take a bit more iteration to get to something clean enough to merge, though -- see my comments. :)

diffrax/_integrate.py (outdated, resolved)
diffrax/_solver/kl.py (outdated, resolved)
diffrax/_solver/kl.py (outdated, resolved)
diffrax/_solver/kl.py (outdated, resolved)
diffrax/_solver/kl.py (outdated, resolved)
Comment on lines 132 to 135
The input must be a `MultiTerm` composed of the first SDE with drift `f`
and diffusion `g` and the second either an SDE or just the drift term
(since the diffusion is assumed to be the same). For example, a type
of: `MultiTerm(MultiTerm(ODETerm, _DiffusionTerm), ODETerm)`.
Owner

As per this comment:
#402 (review)
and also the updated term docs:
https://docs.kidger.site/diffrax/api/terms/
this outer MultiTerm isn't really in keeping with the abstraction: we're not adding all of these extra terms onto the same evolving state.

Bear in mind that the rest of Diffrax has to see this as just another SDE solve.

I think this one might take a bit more iteration to get to something that's obeying the abstractions in the way they're designed, I'm afraid.
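To make the point about "the same evolving state" concrete, here is a minimal library-free sketch. It assumes only that a `MultiTerm` means "sum the contributions of its terms onto one state"; `euler_step` and the `(vector_field, kind)` pairs are hypothetical stand-ins for illustration, not diffrax classes.

```python
# Hypothetical stand-ins for illustration only; these are not diffrax classes.
def euler_step(terms, t, y, dt, dw):
    """One Euler-Maruyama step where every term adds onto the same state y."""
    increment = 0.0
    for vector_field, kind in terms:
        control = dt if kind == "time" else dw  # dt for drift, dW for diffusion
        increment += vector_field(t, y) * control
    return y + increment


f = lambda t, y: 1.0  # posterior drift
g = lambda t, y: 2.0  # shared diffusion
h = lambda t, y: 3.0  # prior drift

# Reading MultiTerm(MultiTerm(f, g), h) as a sum gives three terms acting on
# one state, so the prior drift h is simply added to the posterior dynamics.
nested = [(f, "time"), (g, "noise"), (h, "time")]
y1 = euler_step(nested, t=0.0, y=0.0, dt=0.1, dw=0.5)
```

Under this reading the nested structure describes `dy = (f + h) dt + g dW`, which is presumably not the intended pair of SDEs, and that is why the outer `MultiTerm` does not fit.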

Contributor Author

I think I get what you're saying: a MultiTerm implies a single differential equation "unit", so composing MultiTerms is bad form. However, I'm not sure I see the difficulty going forward. I can replace it with a tuple (MultiTerm, MultiTerm) or even a tuple (MultiTerm, ODETerm). That seems to adhere to the principle of MultiTerm = SDE unit, since we are integrating two simultaneous SDEs, while also falling in line with other solvers (such as implicit Euler, as you remarked).

On the terms vs. solver approach, I am open to both. Over my many iterations and experiments I found the solver approach more in line with how I think about the problem. Specifically, I didn't find the original idea of (terms, kl_term) appealing, since the KL term relies on information from the other term and I didn't see a clean way to provide that. However, having terms with a term wrapper is very doable (though it may not mesh with the repo as well).
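For reference, the dependence of the KL term on the other term can be written out as follows. This is a sketch of the usual Girsanov-based formulation for latent SDEs in the scalar case, not something stated in this PR; `f` and `g` follow the docstring above, while the prior drift `h`, the ratio `u`, and the accumulator `\ell` are names assumed here.

```math
\begin{aligned}
dz_t &= f(t, z_t)\,dt + g(t, z_t)\,dW_t && \text{(posterior)} \\
d\tilde z_t &= h(t, \tilde z_t)\,dt + g(t, \tilde z_t)\,dW_t && \text{(prior)} \\
u(t, z_t) &= \frac{f(t, z_t) - h(t, z_t)}{g(t, z_t)} \\
d\ell_t &= \tfrac{1}{2}\, u(t, z_t)^2 \, dt, \qquad \mathrm{KL} = \mathbb{E}[\ell_T]
\end{aligned}
```

The integrand for `\ell` needs `f`, `h`, and `g` evaluated at the same posterior state `z_t`, which is why a standalone KL term cannot be computed without access to the SDE terms.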

Contributor Author

Maybe introducing a totally new term is OK (given the remarks in #453), in which case the approach of a KLTerm (rather than a solver) is doable. Given the restricted nature of terms so far, I originally thought that wasn't in line with the package.

diffrax/_solver/kl.py (outdated, resolved)
diffrax/_solver/kl.py (outdated, resolved)
diffrax/_solver/kl.py (outdated, resolved)
@lockwo
Contributor Author

lockwo commented Jul 22, 2024

> Okay, I really like the example here!
>
> I'm afraid this might still take a bit more iteration to get to something clean enough to merge, though -- see my comments. :)

Happy to iterate on cleaning it up. I think the biggest question is the design one (solvers vs. terms, and how to represent the problem in diffrax). Once that is resolved, I can iterate quickly to get the rest in :)

@lockwo
Contributor Author

lockwo commented Aug 15, 2024

OK, I took the feedback from Andraz's Langevin PR regarding terms and incorporated it into this PR. I think it made things simpler and more in line with the diffrax philosophy; let me know what you think. Basically, as in the Langevin PR, there is now just a function that accepts a MultiTerm and returns a MultiTerm of private terms that can be consumed by any solver.

The reason I went with returning a single MultiTerm is that you are really only solving the one SDE. You use the prior SDE to inform the KL divergence, but the prior is not itself integrated.
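A rough, library-free sketch of that shape, for readers following along. Everything here is an assumption for illustration: `make_kl_terms` and `euler_maruyama` are hypothetical names rather than the functions in this PR or in diffrax, the state is scalar, and the KL integrand is the usual `0.5 * ((f - h) / g) ** 2` with posterior drift `f`, prior drift `h`, and shared diffusion `g`.

```python
import math
import random


def make_kl_terms(posterior_drift, prior_drift, diffusion):
    """Hypothetical helper: build drift/diffusion for the augmented state (z, kl)."""

    def aug_drift(t, state):
        z, _ = state
        f, h, g = posterior_drift(t, z), prior_drift(t, z), diffusion(t, z)
        u = (f - h) / g  # the prior enters only here; it is never integrated itself
        return (f, 0.5 * u * u)

    def aug_diffusion(t, state):
        z, _ = state
        return (diffusion(t, z), 0.0)  # the KL component carries no noise

    return aug_drift, aug_diffusion


def euler_maruyama(drift, diffusion, state, t0, t1, n_steps, rng):
    """Generic solver: knows nothing about KL, it just steps a two-component SDE."""
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # one shared Brownian increment
        a, b = drift(t, state), diffusion(t, state)
        state = tuple(s + ai * dt + bi * dw for s, ai, bi in zip(state, a, b))
        t += dt
    return state


drift, diffusion = make_kl_terms(
    posterior_drift=lambda t, z: 1.0,
    prior_drift=lambda t, z: 0.0,
    diffusion=lambda t, z: 2.0,
)
z1, kl = euler_maruyama(drift, diffusion, (0.0, 0.0), 0.0, 1.0, 100, random.Random(0))
```

With these constant coefficients the KL integrand is `0.5 * (1 / 2) ** 2 = 0.125`, so over the unit interval `kl` should come out at about 0.125 regardless of the noise sample. The solver only ever sees a single augmented SDE, which matches the "just another SDE solve" requirement from the review.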

@lockwo lockwo requested a review from patrick-kidger August 15, 2024 19:21
@lockwo lockwo mentioned this pull request Aug 15, 2024

2 participants