Skip to content

Commit

Permalink
Merge pull request #374 from ptiede/ptiede-enzyme0.13
Browse files Browse the repository at this point in the history
Rework ad to prevent compiling the same gradient multiple times
  • Loading branch information
ptiede authored Oct 5, 2024
2 parents 05dd65f + 289bd34 commit 444ab63
Show file tree
Hide file tree
Showing 27 changed files with 182 additions and 284 deletions.
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ HypercubeTransform = "9ec9aee3-0fd3-44c2-8e61-a50acc66f3c8"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
PaddedViews = "5432bcbf-9aad-5242-b902-cca2824c8663"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Expand Down Expand Up @@ -75,14 +74,13 @@ DimensionalData = "0.27, 0.28"
Distributions = "0.25"
DocStringExtensions = "0.8, 0.9"
Dynesty = "0.4"
Enzyme = "0.12"
EnzymeCore = "0.7"
Enzyme = "0.13"
EnzymeCore = "0.8"
FillArrays = "1"
ForwardDiff = "0.9, 0.10"
HypercubeTransform = "0.4"
IntervalSets = "0.6, 0.7"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
Makie = "0.21"
NamedTupleTools = "0.13,0.14"
Optimization = "4"
Expand All @@ -101,7 +99,7 @@ StatsBase = "0.33,0.34"
StructArrays = "0.5,0.6"
Tables = "1"
TransformVariables = "0.8"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
VLBILikelihoods = "^0.2.6"
VLBISkyModels = "0.6"
julia = "1.10"
Expand Down
7 changes: 6 additions & 1 deletion docs/src/ext/ahmc.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@ To sample a user can use follow the standard `AdvancedHMC` interface, e.g.,
chain = sample(post, NUTS(0.8), 10_000; n_adapts=5_000)
```

!!! warning
To use HMC the `VLBIPosterior` must be created with a specific `admode` specified.
The `admode` can be a union of `Nothing` and `<:EnzymeCore.Mode` types. We recommend
using `Enzyme.set_runtime_activity(Enzyme.Reverse)`


In addition our sample call has a few additional keyword arguments:

- `adtype = Val(:Enzyme)`: The autodiff package to use. Currently the only options is `Enzyme`. Note that you must load Enzyme before calling `sample`.
- `saveto = MemoryStore()`: Specifies how to store the samples. The default is `MemoryStore` which stores the samples directly in RAM. For large models this is not a good idea. To save samples periodically to disk use [`DiskStore`](@ref), and then load the results with `load_samples`.

Note that like most `AbstractMCMC` samplers the initial location can be specified with the `initial_params` argument.
Expand Down
8 changes: 7 additions & 1 deletion docs/src/ext/optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ optimization algorithm.
To see what optimizers are available and what options are available, please see the `Optimizations.jl` [docs](http://optimization.sciml.ai/dev/).


!!! warning
To use use a gradient optimizer with AD, `VLBIPosterior` must be created with a specific `admode` specified.
The `admode` can be a union of `Nothing` and `<:EnzymeCore.Mode` types. We recommend
using `Enzyme.set_runtime_activity(Enzyme.Reverse)`


## Example

```julia
Expand All @@ -18,5 +24,5 @@ using Enzyme
# Some stuff to create a posterior object
post # of type Comrade.Posterior

xopt, sol = comrade_opt(post, LBFGS(); adtype=Val(:Enzyme))
xopt, sol = comrade_opt(post, LBFGS())
```
2 changes: 1 addition & 1 deletion examples/advanced/HybridImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
StatsBase = "0.34"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
6 changes: 3 additions & 3 deletions examples/advanced/HybridImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ skym = SkyModel(sky, skyprior, g; metadata=skymetadata)

# This is everything we need to specify our posterior distribution, which our is the main
# object of interest in image reconstructions when using Bayesian inference.
post = VLBIPosterior(skym, intmodel, dvis)
using Enzyme
post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))

# To sample from our prior we can do
xrand = prior_sample(rng, post)
Expand All @@ -179,8 +180,7 @@ fig |> DisplayAs.PNG |> DisplayAs.Text #hide
# To use this we use the [`comrade_opt`](@ref) function
using Optimization
using OptimizationOptimJL
using Enzyme
xopt, sol = comrade_opt(post, LBFGS(), AutoEnzyme(;mode=Enzyme.Reverse);
xopt, sol = comrade_opt(post, LBFGS();
initial_params=prior_sample(rng, post), maxiters=1000, g_tol=1e0)


Expand Down
2 changes: 1 addition & 1 deletion examples/beginner/GeometricModeling/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ Pigeons = "0.4"
Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
4 changes: 3 additions & 1 deletion examples/intermediate/ClosureImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -14,6 +15,7 @@ Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
VLBILikelihoods = "90db92cd-0007-4c0a-8e51-dbf0782ce592"

[compat]
CairoMakie = "0.12"
Expand All @@ -24,4 +26,4 @@ Pkg = "1"
Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
6 changes: 3 additions & 3 deletions examples/intermediate/ClosureImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ skym = SkyModel(sky, prior, grid; metadata=skymeta)

# Since we are fitting closures we do not need to include an instrument model, since
# the closure likelihood is approximately independent of gains in the high SNR limit.
post = VLBIPosterior(skym, dlcamp, dcphase)
using Enzyme
post = VLBIPosterior(skym, dlcamp, dcphase; admode=set_runtime_activity(Enzyme.Reverse))

# ## Reconstructing the Image

Expand All @@ -144,8 +145,7 @@ post = VLBIPosterior(skym, dlcamp, dcphase)
# OptimizationOptimJL. We also need to import Enzyme to allow for automatic differentiation.
using Optimization
using OptimizationOptimJL
using Enzyme
xopt, sol = comrade_opt(post, LBFGS(), AutoEnzyme(;mode=Enzyme.Reverse);
xopt, sol = comrade_opt(post, LBFGS();
maxiters=1000, initial_params=prior_sample(rng, post))


Expand Down
2 changes: 0 additions & 2 deletions examples/intermediate/PolarizedImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
18 changes: 9 additions & 9 deletions examples/intermediate/PolarizedImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ end
# image model. Our image will be a 10x10 raster with a 60μas FOV.
using Distributions
using VLBIImagePriors
fovx = μas2rad(150.0)
fovy = μas2rad(150.0)
fovx = μas2rad(200.0)
fovy = μas2rad(200.0)
nx = ny = 32
grid = imagepixels(fovx, fovy, nx, ny)

Expand All @@ -204,7 +204,7 @@ skymeta = (; mimg=mimg./flux(mimg), ftot=0.6)
cprior = corr_image_prior(grid, dvis)
skyprior = (
c = cprior,
σ = truncated(Normal(0.0, 0.5); lower=0.0),
σ = Exponential(0.1),
p = cprior,
p0 = Normal(-2.0, 2.0),
= truncated(Normal(0.0, 1.0); lower=0.01),
Expand Down Expand Up @@ -287,7 +287,7 @@ J = JonesSandwich(js, G, D, R)
intprior = (
lgR = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.2))),
lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.01))),
gpR = ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv^2))); refant=SEFDReference(0.0), phase=false),
gpR = ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv^2))); refant=SEFDReference(0.0), phase=true),
gprat= ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(0.1^2))); refant = SingleReference(:AA, 0.0), phase=false),
dRx = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
dRy = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
Expand All @@ -301,8 +301,9 @@ intmodel = InstrumentModel(J, intprior)

# intmodel = InstrumentModel(JonesR(;add_fr=true))
# Putting it all together, we form our likelihood and posterior objects for optimization and
# sampling.
post = VLBIPosterior(skym, intmodel, dvis)
# sampling, and specifying to use Enzyme.Reverse with runtime activity for AD.
using Enzyme
post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))

# ## Reconstructing the Image and Instrument Effects

Expand All @@ -323,8 +324,7 @@ tpost = asflat(post)
# work with the polarized Comrade posterior is Enzyme.
using Optimization
using OptimizationOptimisers
using Enzyme
xopt, sol = comrade_opt(post, Optimisers.Adam(), AutoEnzyme(;mode=Enzyme.Reverse);
xopt, sol = comrade_opt(post, Optimisers.Adam();
initial_params=prior_sample(rng, post), maxiters=25_000)


Expand Down Expand Up @@ -384,7 +384,7 @@ p |> DisplayAs.PNG |> DisplayAs.Text
# other imaging examples. For example
# ```julia
# using AdvancedHMC
# chain = sample(rng, post, NUTS(0.8), 10_000; adtype=AutoEnzyme(;mode=Enzyme.Reverse), n_adapts=5000, progress=true, initial_params=xopt)
# chain = sample(rng, post, NUTS(0.8), 10_000, n_adapts=5000, progress=true, initial_params=xopt)
# ```


Expand Down
2 changes: 1 addition & 1 deletion examples/intermediate/StokesIImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ Pkg = "1"
Plots = "1"
Pyehtim = "0.1"
StableRNGs = "1"
VLBIImagePriors = "0.8"
VLBIImagePriors = "0.9"
12 changes: 8 additions & 4 deletions examples/intermediate/StokesIImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,12 @@ intpr = (
intmodel = InstrumentModel(G, intpr)


post = VLBIPosterior(skym, intmodel, dvis)
# To form the posterior we just combine the skymodel, instrument model and the data. Additionally,
# since we want to use gradients we need to specify the AD mode. Essentially for all modes we recommend
# using `Enzyme.set_runtime_activity(Enzyme.Reverse)`. Eventually as Comrade and Enzyme matures we will
# no need `set_runtime_activity`.
using Enzyme
post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))
# done using the `asflat` function.
tpost = asflat(post)
ndim = dimension(tpost)
Expand All @@ -160,8 +165,7 @@ ndim = dimension(tpost)
# To initialize our sampler we will use optimize using Adam
using Optimization
using OptimizationOptimisers
using Enzyme
xopt, sol = comrade_opt(post, Optimisers.Adam(), AutoEnzyme(;mode=Enzyme.Reverse); initial_params=prior_sample(rng, post), maxiters=20_000, g_tol=1e-1)
xopt, sol = comrade_opt(post, Optimisers.Adam(); initial_params=prior_sample(rng, post), maxiters=20_000, g_tol=1e-1)

# !!! warning
# Fitting gains tends to be very difficult, meaning that optimization can take a lot longer.
Expand Down Expand Up @@ -208,7 +212,7 @@ plot(gt, layout=(3,3), size=(600,500)) |> DisplayAs.PNG |> DisplayAs.Text
# run
#-
using AdvancedHMC
chain = sample(rng, post, NUTS(0.8), 1_000; adtype=AutoEnzyme(;mode=Enzyme.Reverse), n_adapts=500, progress=false, initial_params=xopt)
chain = sample(rng, post, NUTS(0.8), 1_000; n_adapts=500, progress=false, initial_params=xopt)
#-
# !!! note
# The above sampler will store the samples in memory, i.e. RAM. For large models this
Expand Down
28 changes: 14 additions & 14 deletions ext/ComradeAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Accessors
using ArgCheck
using DocStringExtensions
using HypercubeTransform
using LogDensityProblems, LogDensityProblemsAD
using LogDensityProblems
using Printf
using Random
using StatsBase
Expand All @@ -36,15 +36,14 @@ end

function AbstractMCMC.Sample(
rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPosterior,
sampler::AbstractHMCSampler; adtype=Val(:Enzyme), initial_params=nothiing, kwargs...)
∇ℓ = ADgradient(adtype, tpost)
sampler::AbstractHMCSampler; initial_params=nothiing, kwargs...)
θ0 = initialize_params(tpost, initial_params)
model, smplr = make_sampler(rng, ∇ℓ, sampler, θ0)
model, smplr = make_sampler(rng, tpost, sampler, θ0)
return AbstractMCMC.Sample(rng, model, smplr; initial_params=θ0, kwargs...)
end

"""
sample(rng, post::VLBIPosterior, sampler::AbstractHMCSampler, nsamples, args...;saveto=MemoryStore(), adtype=Val(:Enzyme), initial_params=nothing, kwargs...)
sample(rng, post::VLBIPosterior, sampler::AbstractHMCSampler, nsamples, args...;saveto=MemoryStore(), initial_params=nothing, kwargs...)
Sample from the posterior `post` using the sampler `sampler` for `nsamples` samples. Additional
arguments are forwarded to AbstractMCMC.sample. If `saveto` is a DiskStore, the samples will be
Expand All @@ -59,23 +58,26 @@ saved to disk. If `initial_params` is not `nothing` then the sampler will start
## Keyword Arguments
- `saveto`: If a DiskStore, the samples will be saved to disk, if [`MemoryStore`](@ref) the samples will be stored in memory/ram.
- `adtype`: The automatic differentiation type to use. The default if Enzyme which is the recommended choice for Comrade currently.
- `initial_params`: The initial parameters to start the sampler from. If `nothing` then the sampler will start from a random point in the prior.
- `kwargs`: Additional keyword arguments to pass to the sampler. Examples include `n_adapts` which is the total number of samples to use for adaptation.
To see the others see the AdvancedHMC documentation.
"""
function AbstractMCMC.sample(
rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,
sampler::AbstractHMCSampler, nsamples, args...;
saveto=MemoryStore(), adtype=Val(:Enzyme), initial_params=nothing, kwargs...)
saveto=MemoryStore(), initial_params=nothing, kwargs...)

saveto isa DiskStore && return sample_to_disk(rng, post, sampler, nsamples, args...; outdir=saveto.name, output_stride=min(saveto.stride, nsamples), adtype, initial_params, kwargs...)
saveto isa DiskStore && return sample_to_disk(rng, post, sampler, nsamples, args...; outdir=saveto.name, output_stride=min(saveto.stride, nsamples), initial_params, kwargs...)

if isnothing(Comrade.admode(post))
throw(ArgumentError("You must specify an automatic differentiation type in VLBIPosterior with admode kwarg"))
else
tpost = asflat(post)
end

tpost = asflat(post)
∇ℓ = ADgradient(adtype, tpost)
θ0 = initialize_params(tpost, initial_params)
model, smplr = make_sampler(rng, ∇ℓ, sampler, θ0)
model, smplr = make_sampler(rng, tpost, sampler, θ0)

res = sample(rng, model, smplr, nsamples, args...;
initial_params=θ0, saveto=saveto, chain_type=Array, kwargs...)
Expand All @@ -90,7 +92,6 @@ end
function initialize(rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPosterior,
sampler::AbstractHMCSampler, nsamples, outbase, args...;
n_adapts = min(nsamples÷2, 1000),
adtype = Val(:Enzyme),
initial_params=nothing, outdir = "Results",
output_stride=min(100, nsamples),
restart = false,
Expand Down Expand Up @@ -119,7 +120,7 @@ function initialize(rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPoste
@warn "No starting location chosen, picking start from prior"
θ0 = prior_sample(rng, tpost)
end
t = Sample(rng, tpost, sampler; initial_params=θ0, adtype, n_adapts, kwargs...)(1:nsamples)
t = Sample(rng, tpost, sampler; initial_params=θ0, n_adapts, kwargs...)(1:nsamples)
pt = Iterators.partition(t, output_stride)
nscans = nsamples÷output_stride + (nsamples%output_stride!=0 ? 1 : 0)

Expand Down Expand Up @@ -158,7 +159,6 @@ end

function sample_to_disk(rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,
sampler::AbstractHMCSampler, nsamples, args...;
adtype = Val(:Enzyme),
n_adapts = min(nsamples÷2, 1000),
initial_params=nothing, outdir = "Results",
restart=false,
Expand All @@ -172,7 +172,7 @@ function sample_to_disk(rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,

pt, state, out, i = initialize(
rng, tpost, sampler, nsamples, outbase, args...;
n_adapts, adtype,
n_adapts,
initial_params, restart, outdir, output_stride, kwargs...
)

Expand Down
16 changes: 13 additions & 3 deletions ext/ComradeEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
module ComradeEnzymeExt

using Enzyme
using Comrade
using LogDensityProblems

function __init__()
# We need this to ensure than Enzyme can AD through the Comrade code base
Enzyme.API.runtimeActivity!(true)
LogDensityProblems.dimension(d::Comrade.TransformedVLBIPosterior) = dimension(d)
LogDensityProblems.capabilities(::Type{<:Comrade.TransformedVLBIPosterior}) = LogDensityProblems.LogDensityOrder{1}()


function LogDensityProblems.logdensity_and_gradient(d::Comrade.TransformedVLBIPosterior, x::AbstractArray)
mode = Enzyme.EnzymeCore.WithPrimal(Comrade.admode(d))
dx = zero(x)
(_, y) = autodiff(mode, Comrade.logdensityof, Active, Const(d), Duplicated(x, dx))
return y, dx
end



end
Loading

0 comments on commit 444ab63

Please sign in to comment.