
reverse mode AD #12

Open
thorek1 opened this issue Jun 1, 2024 · 6 comments
thorek1 commented Jun 1, 2024

I would like to try your algorithm, but my application requires reverse-mode AD (Zygote, to be more specific). Do you support it?
My reading of the code is that it supports only ForwardDiff.jl for now.

JaimeRZP (Owner) commented Jun 5, 2024

Hi Thorek! I am currently using MCHMC with Zygote so it should work. The problem might be at the Turing level.

thorek1 (Author) commented Jun 5, 2024

Gotcha. I will try that in the meantime then. Is the Turing-with-Zygote problem an issue on Turing's end?
Thanks

JaimeRZP (Owner) commented Jun 6, 2024

There's been a lot of work to get Zygote working properly in Turing over the last releases.
I would make sure you are using the latest version (0.32?).

thorek1 (Author) commented Jun 7, 2024

I use >=0.32 as well. Just to make sure, here is the alternative (from a user perspective): `samps = Turing.sample(loglikelihood_fn, NUTS(adtype = AutoZygote()), n_samples)`. Ideally the microcanonical HMC sampler would support a similar way of switching the AD backend.
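
For reference, a self-contained sketch of that NUTS-with-Zygote pattern (the toy model, data, and sample count below are illustrative stand-ins, not from this thread):

```julia
using Turing
using Zygote  # loaded so AutoZygote() can be used as the AD backend

# Hypothetical toy model standing in for loglikelihood_fn above
@model function toy_model(y)
    μ ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1); lower=0)
    y .~ Normal(μ, σ)
end

n_samples = 3_000  # illustrative
samps = Turing.sample(toy_model(randn(100)), NUTS(; adtype=AutoZygote()), n_samples)
```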

JaimeRZP (Owner) commented Jun 7, 2024

Yes, what you are looking for is:

```julia
# Define sampler
mchmc = MCHMC(n_adapts, tev; adaptive=true)
sampler = externalsampler(mchmc; adtype=AutoZygote())
# Sample
chain = Turing.sample(model, sampler, 10_000)
```

hope this helps!
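
For completeness, the same toy model from the sketch earlier in the thread can be sampled end to end this way (the n_adapts and tev values are illustrative and need tuning for a real problem):

```julia
using MicroCanonicalHMC  # provides MCHMC
using Turing, Zygote     # Turing exports externalsampler; Zygote is needed for AutoZygote()

n_adapts = 1_000  # illustrative number of adaptation steps
tev = 0.001       # illustrative target energy variance

mchmc = MCHMC(n_adapts, tev; adaptive=true)
sampler = externalsampler(mchmc; adtype=AutoZygote())
chain = Turing.sample(toy_model(randn(100)), sampler, 10_000)
```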

thorek1 (Author) commented Jun 9, 2024

I understand this isn't an issue with the package; hence, I would recommend closing the issue.

Now that I got it to work, I wanted to share that for my use case I need many more adaptation draws and samples to get an ESS comparable to NUTS [MCHMC (50000, 20000) vs NUTS (3000)], while MCHMC is faster.

Do you have any hint what to do when the tuning phase shows NaN? I can restart the procedure and it might find a valid point, but even then I saw it converging to epsilon = 0 and not recovering from there. Starting from the/a mode does not help either. NUTS does work in this case.
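
Not a fix for the NaN itself, but for reference, a sketch of the two workarounds mentioned above: many more adaptation draws/samples, and starting from a mode via the `initial_params` keyword accepted by `Turing.sample`. Here `model`, `mode_estimate`, and all numeric values are illustrative:

```julia
# Illustrative: the adaptation/sample counts from the comparison above
n_adapts, n_samples = 50_000, 20_000
tev = 0.001  # illustrative target energy variance

mchmc = MCHMC(n_adapts, tev; adaptive=true)
sampler = externalsampler(mchmc; adtype=AutoZygote())

# model: your Turing model; mode_estimate: hypothetical vector of initial
# parameter values, e.g. obtained from an optimiser
chain = Turing.sample(model, sampler, n_samples; initial_params=mode_estimate)
```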

check here for the example application I used with both NUTS and MCHMC
