From aed3a5e9d7fd2696fe6f6b40a1fec4c53471bba3 Mon Sep 17 00:00:00 2001
From: Paul Tiede
Date: Sat, 5 Oct 2024 13:51:27 -0400
Subject: [PATCH 1/6] Rework how AD works to prevent compiling the same
 gradient multiple times

---
 Project.toml                                  |   8 +-
 docs/src/ext/ahmc.md                          |   7 +-
 docs/src/ext/optimization.md                  |   8 +-
 examples/advanced/HybridImaging/main.jl       |   4 +-
 .../beginner/GeometricModeling/Project.toml   |   2 +-
 .../intermediate/ClosureImaging/Project.toml  |   4 +-
 examples/intermediate/ClosureImaging/main.jl  |   6 +-
 .../PolarizedImaging/Project.toml             |   2 -
 .../intermediate/PolarizedImaging/main.jl     |  18 +--
 .../intermediate/StokesIImaging/Project.toml  |   2 +-
 examples/intermediate/StokesIImaging/main.jl  |  12 +-
 ext/ComradeAdvancedHMCExt.jl                  |  28 ++--
 ext/ComradeEnzymeExt.jl                       |  16 ++-
 ext/ComradeOptimizationExt.jl                 |  34 +++--
 src/inference/optimization.jl                 |   2 +-
 src/posterior/transformed.jl                  |   2 +-
 src/posterior/vlbiposterior.jl                |  24 ++--
 src/rules.jl                                  | 130 +-----------------
 test/Core/bayes.jl                            |   6 +-
 test/ext/comradeahmc.jl                       |   4 +-
 test/ext/comradeoptimization.jl               |   4 +-
 21 files changed, 118 insertions(+), 205 deletions(-)

diff --git a/Project.toml b/Project.toml
index 73c9ecaf..20d1cec2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -22,7 +22,6 @@ HypercubeTransform = "9ec9aee3-0fd3-44c2-8e61-a50acc66f3c8"
 IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
 PaddedViews = "5432bcbf-9aad-5242-b902-cca2824c8663"
 ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
@@ -75,14 +74,13 @@ DimensionalData = "0.27, 0.28"
 Distributions = "0.25"
 DocStringExtensions = "0.8, 0.9"
 Dynesty = "0.4"
-Enzyme = "0.12"
-EnzymeCore = "0.7"
+Enzyme = "0.13"
+EnzymeCore = "0.8"
 FillArrays = "1"
 ForwardDiff = "0.9, 0.10"
 HypercubeTransform = "0.4"
 IntervalSets = "0.6, 0.7"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1"
 Makie = "0.21"
 NamedTupleTools = "0.13,0.14"
 Optimization = "4"
@@ -101,7 +99,7 @@ StatsBase = "0.33,0.34"
 StructArrays = "0.5,0.6"
 Tables = "1"
 TransformVariables = "0.8"
-VLBIImagePriors = "0.8"
+VLBIImagePriors = "0.9"
 VLBILikelihoods = "^0.2.6"
 VLBISkyModels = "0.6"
 julia = "1.10"
diff --git a/docs/src/ext/ahmc.md b/docs/src/ext/ahmc.md
index 586c89c6..3d6f89d1 100644
--- a/docs/src/ext/ahmc.md
+++ b/docs/src/ext/ahmc.md
@@ -13,9 +13,14 @@ To sample, a user can follow the standard `AdvancedHMC` interface, e.g.,
 chain = sample(post, NUTS(0.8), 10_000; n_adapts=5_000)
 ```
 
+!!! warning
+    To use HMC the `VLBIPosterior` must be constructed with a specific `admode`.
+    The `admode` can be a union of `Nothing` and `<:EnzymeCore.Mode` types. We recommend
+    using `Enzyme.set_runtime_activity(Enzyme.Reverse)`.
+
+
 In addition, our `sample` call has a few extra keyword arguments:
 
-- `adtype = Val(:Enzyme)`: The autodiff package to use. Currently the only options is `Enzyme`. Note that you must load Enzyme before calling `sample`.
 - `saveto = MemoryStore()`: Specifies how to store the samples. The default is `MemoryStore` which stores the samples directly in RAM. For large models this is not a good idea. To save samples periodically to disk use [`DiskStore`](@ref), and then load the results with `load_samples`.
 
 Note that like most `AbstractMCMC` samplers, the initial location can be specified with the `initial_params` argument.
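Concretely, the two changes above (the new `admode` requirement and the removal of `adtype` from `sample`) combine into the workflow in the following sketch. This is not part of the patch: `skym`, `intmodel`, and `dvis` are placeholders for a sky model, instrument model, and visibility data, and `load_samples` is assumed to accept the value returned by a `DiskStore` run, as the docs above suggest.

```julia
using Comrade, Enzyme, AdvancedHMC

# The AD mode is now attached to the posterior itself instead of being passed to `sample`.
post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))

# Samples are kept in RAM by default; for large models stream them to disk instead.
out   = sample(post, NUTS(0.8), 10_000; n_adapts=5_000, saveto=DiskStore(name="Results"))
chain = load_samples(out)
```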
diff --git a/docs/src/ext/optimization.md b/docs/src/ext/optimization.md
index bffabbfe..cce76ab6 100644
--- a/docs/src/ext/optimization.md
+++ b/docs/src/ext/optimization.md
@@ -7,6 +7,12 @@ optimization algorithm.
 To see which optimizers and options are available, please see the `Optimization.jl` [docs](http://optimization.sciml.ai/dev/).
 
+!!! warning
+    To use a gradient optimizer with AD, `VLBIPosterior` must be constructed with a specific `admode`.
+    The `admode` can be a union of `Nothing` and `<:EnzymeCore.Mode` types. We recommend
+    using `Enzyme.set_runtime_activity(Enzyme.Reverse)`.
+
+
 ## Example
 
 ```julia
@@ -18,5 +24,5 @@ using Enzyme
 # Some stuff to create a posterior object
 post # of type Comrade.Posterior
 
-xopt, sol = comrade_opt(post, LBFGS(); adtype=Val(:Enzyme))
+xopt, sol = comrade_opt(post, LBFGS())
 ```
\ No newline at end of file
diff --git a/examples/advanced/HybridImaging/main.jl b/examples/advanced/HybridImaging/main.jl
index 9cbecc06..4b7eb2c8 100644
--- a/examples/advanced/HybridImaging/main.jl
+++ b/examples/advanced/HybridImaging/main.jl
@@ -156,7 +156,7 @@ skym = SkyModel(sky, skyprior, g; metadata=skymetadata)
 
 # This is everything we need to specify our posterior distribution, which is the main
 # object of interest in image reconstructions when using Bayesian inference.
-post = VLBIPosterior(skym, intmodel, dvis)
+post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))
 
 # To sample from our prior we can do
 xrand = prior_sample(rng, post)
@@ -180,7 +180,7 @@ fig |> DisplayAs.PNG |> DisplayAs.Text #hide
 using Optimization
 using OptimizationOptimJL
 using Enzyme
-xopt, sol = comrade_opt(post, LBFGS(), AutoEnzyme(;mode=Enzyme.Reverse);
+xopt, sol = comrade_opt(post, LBFGS();
                         initial_params=prior_sample(rng, post), maxiters=1000, g_tol=1e0)
 
diff --git a/examples/beginner/GeometricModeling/Project.toml b/examples/beginner/GeometricModeling/Project.toml
index 8ee225ea..7e7be586 100644
--- a/examples/beginner/GeometricModeling/Project.toml
+++ b/examples/beginner/GeometricModeling/Project.toml
@@ -20,4 +20,4 @@ Pigeons = "0.4"
 Plots = "1"
 Pyehtim = "0.1"
 StableRNGs = "1"
-VLBIImagePriors = "0.8"
+VLBIImagePriors = "0.9"
diff --git a/examples/intermediate/ClosureImaging/Project.toml b/examples/intermediate/ClosureImaging/Project.toml
index 10dc6fd5..918f8c78 100644
--- a/examples/intermediate/ClosureImaging/Project.toml
+++ b/examples/intermediate/ClosureImaging/Project.toml
@@ -6,6 +6,7 @@ DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -14,6 +15,7 @@ Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
+VLBILikelihoods = "90db92cd-0007-4c0a-8e51-dbf0782ce592"
 
 [compat]
 CairoMakie = "0.12"
@@ -24,4 +26,4 @@ Pkg = "1"
 Plots = "1"
 Pyehtim = "0.1"
 StableRNGs = "1"
-VLBIImagePriors = "0.8"
+VLBIImagePriors = "0.9"
diff --git a/examples/intermediate/ClosureImaging/main.jl b/examples/intermediate/ClosureImaging/main.jl
index a166a0e9..44c0d5b0 100644
--- a/examples/intermediate/ClosureImaging/main.jl
+++ b/examples/intermediate/ClosureImaging/main.jl
@@ -129,7 +129,8 @@ skym = SkyModel(sky, prior, grid; metadata=skymeta)
 
 # Since we are fitting closures we do not need to include an instrument model, because
 # the closure likelihood is approximately independent of gains in the high-SNR limit.
-post = VLBIPosterior(skym, dlcamp, dcphase)
+using Enzyme
+post = VLBIPosterior(skym, dlcamp, dcphase; admode=set_runtime_activity(Enzyme.Reverse))
 
 # ## Reconstructing the Image
 
@@ -144,8 +145,7 @@ post = VLBIPosterior(skym, dlcamp, dcphase)
 # OptimizationOptimJL. Enzyme is already loaded above to allow for automatic differentiation.
 using Optimization
 using OptimizationOptimJL
-using Enzyme
-xopt, sol = comrade_opt(post, LBFGS(), AutoEnzyme(;mode=Enzyme.Reverse);
+xopt, sol = comrade_opt(post, LBFGS();
                         maxiters=1000, initial_params=prior_sample(rng, post))
 
diff --git a/examples/intermediate/PolarizedImaging/Project.toml b/examples/intermediate/PolarizedImaging/Project.toml
index b5e8f895..f3163f05 100644
--- a/examples/intermediate/PolarizedImaging/Project.toml
+++ b/examples/intermediate/PolarizedImaging/Project.toml
@@ -6,13 +6,11 @@ DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
-Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
-StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
diff --git a/examples/intermediate/PolarizedImaging/main.jl b/examples/intermediate/PolarizedImaging/main.jl
index abf726c1..8739ff15 100644
--- a/examples/intermediate/PolarizedImaging/main.jl
+++ b/examples/intermediate/PolarizedImaging/main.jl
@@ -180,8 +180,8 @@ end
 # image model. Our image will be a 32x32 raster with a 200μas FOV.
 using Distributions
 using VLBIImagePriors
-fovx = μas2rad(150.0)
-fovy = μas2rad(150.0)
+fovx = μas2rad(200.0)
+fovy = μas2rad(200.0)
 nx = ny = 32
 grid = imagepixels(fovx, fovy, nx, ny)
 
@@ -204,7 +204,7 @@ skymeta = (; mimg=mimg./flux(mimg), ftot=0.6)
 cprior = corr_image_prior(grid, dvis)
 skyprior = (
     c = cprior,
-    σ  = truncated(Normal(0.0, 0.5); lower=0.0),
+    σ  = Exponential(0.1),
     p  = cprior,
     p0 = Normal(-2.0, 2.0),
     pσ = truncated(Normal(0.0, 1.0); lower=0.01),
@@ -287,7 +287,7 @@ J = JonesSandwich(js, G, D, R)
 intprior = (
     lgR  = ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.2))),
     lgrat= ArrayPrior(IIDSitePrior(ScanSeg(), Normal(0.0, 0.01))),
-    gpR  = ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(π^2))); refant=SEFDReference(0.0), phase=false),
+    gpR  = ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(π^2))); refant=SEFDReference(0.0), phase=true),
     gprat= ArrayPrior(IIDSitePrior(ScanSeg(), DiagonalVonMises(0.0, inv(0.1^2))); refant = SingleReference(:AA, 0.0), phase=false),
     dRx  = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
     dRy  = ArrayPrior(IIDSitePrior(TrackSeg(), Normal(0.0, 0.2))),
@@ -301,8 +301,9 @@ intmodel = InstrumentModel(J, intprior)
 # intmodel = InstrumentModel(JonesR(;add_fr=true))
 
 # Putting it all together, we form our likelihood and posterior objects for optimization and
-# sampling.
-post = VLBIPosterior(skym, intmodel, dvis)
+# sampling, specifying that we will use Enzyme's reverse mode with runtime activity for AD.
+using Enzyme
+post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))
 
 # ## Reconstructing the Image and Instrument Effects
 
@@ -323,8 +324,7 @@ tpost = asflat(post)
 # work with the polarized Comrade posterior is Enzyme.
 using Optimization
 using OptimizationOptimisers
-using Enzyme
-xopt, sol = comrade_opt(post, Optimisers.Adam(), AutoEnzyme(;mode=Enzyme.Reverse);
+xopt, sol = comrade_opt(post, Optimisers.Adam();
                         initial_params=prior_sample(rng, post), maxiters=25_000)
 
@@ -384,7 +384,7 @@ p |> DisplayAs.PNG |> DisplayAs.Text
 # other imaging examples. For example
 # ```julia
 # using AdvancedHMC
-# chain = sample(rng, post, NUTS(0.8), 10_000; adtype=AutoEnzyme(;mode=Enzyme.Reverse), n_adapts=5000, progress=true, initial_params=xopt)
+# chain = sample(rng, post, NUTS(0.8), 10_000; n_adapts=5000, progress=true, initial_params=xopt)
 # ```
 
diff --git a/examples/intermediate/StokesIImaging/Project.toml b/examples/intermediate/StokesIImaging/Project.toml
index 6bdbe99a..a36a695a 100644
--- a/examples/intermediate/StokesIImaging/Project.toml
+++ b/examples/intermediate/StokesIImaging/Project.toml
@@ -23,4 +23,4 @@ Pkg = "1"
 Plots = "1"
 Pyehtim = "0.1"
 StableRNGs = "1"
-VLBIImagePriors = "0.8"
+VLBIImagePriors = "0.9"
diff --git a/examples/intermediate/StokesIImaging/main.jl b/examples/intermediate/StokesIImaging/main.jl
index 42a69c59..4933b901 100644
--- a/examples/intermediate/StokesIImaging/main.jl
+++ b/examples/intermediate/StokesIImaging/main.jl
@@ -145,7 +145,12 @@ intpr = (
 
 intmodel = InstrumentModel(G, intpr)
 
-post = VLBIPosterior(skym, intmodel, dvis)
+# To form the posterior we just combine the sky model, the instrument model, and the data. Additionally,
+# since we want to use gradients, we need to specify the AD mode. For essentially all models we recommend
+# using `Enzyme.set_runtime_activity(Enzyme.Reverse)`. Eventually, as Comrade and Enzyme mature, we will
+# no longer need `set_runtime_activity`.
+using Enzyme
+post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))
 
 # done using the `asflat` function.
 tpost = asflat(post)
 ndim = dimension(tpost)
@@ -160,8 +165,7 @@ ndim = dimension(tpost)
 # To initialize our sampler we will optimize using Adam
 using Optimization
 using OptimizationOptimisers
-using Enzyme
-xopt, sol = comrade_opt(post, Optimisers.Adam(), AutoEnzyme(;mode=Enzyme.Reverse); initial_params=prior_sample(rng, post), maxiters=20_000, g_tol=1e-1)
+xopt, sol = comrade_opt(post, Optimisers.Adam(); initial_params=prior_sample(rng, post), maxiters=20_000, g_tol=1e-1)
 
 # !!! warning
 #    Fitting gains tends to be very difficult, meaning that optimization can take a lot longer.
@@ -208,7 +212,7 @@ plot(gt, layout=(3,3), size=(600,500)) |> DisplayAs.PNG |> DisplayAs.Text
 # run
 #-
 using AdvancedHMC
-chain = sample(rng, post, NUTS(0.8), 1_000; adtype=AutoEnzyme(;mode=Enzyme.Reverse), n_adapts=500, progress=false, initial_params=xopt)
+chain = sample(rng, post, NUTS(0.8), 1_000; n_adapts=500, progress=false, initial_params=xopt)
 #-
 # !!! note
 #    The above sampler will store the samples in memory, i.e. RAM. For large models this
diff --git a/ext/ComradeAdvancedHMCExt.jl b/ext/ComradeAdvancedHMCExt.jl
index d00aa047..e346323e 100644
--- a/ext/ComradeAdvancedHMCExt.jl
+++ b/ext/ComradeAdvancedHMCExt.jl
@@ -10,7 +10,7 @@ using Accessors
 using ArgCheck
 using DocStringExtensions
 using HypercubeTransform
-using LogDensityProblems, LogDensityProblemsAD
+using LogDensityProblems
 using Printf
 using Random
 using StatsBase
@@ -36,15 +36,14 @@ end
 function AbstractMCMC.Sample(
     rng::Random.AbstractRNG,
     tpost::Comrade.TransformedVLBIPosterior,
-    sampler::AbstractHMCSampler; adtype=Val(:Enzyme), initial_params=nothiing, kwargs...)
-    ∇ℓ = ADgradient(adtype, tpost)
+    sampler::AbstractHMCSampler; initial_params=nothing, kwargs...)
     θ0 = initialize_params(tpost, initial_params)
-    model, smplr = make_sampler(rng, ∇ℓ, sampler, θ0)
+    model, smplr = make_sampler(rng, tpost, sampler, θ0)
     return AbstractMCMC.Sample(rng, model, smplr; initial_params=θ0, kwargs...)
 end
 
 """
-    sample(rng, post::VLBIPosterior, sampler::AbstractHMCSampler, nsamples, args...;saveto=MemoryStore(), adtype=Val(:Enzyme), initial_params=nothing, kwargs...)
+    sample(rng, post::VLBIPosterior, sampler::AbstractHMCSampler, nsamples, args...;saveto=MemoryStore(), initial_params=nothing, kwargs...)
 
 Sample from the posterior `post` using the sampler `sampler` for `nsamples` samples. Additional
 arguments are forwarded to AbstractMCMC.sample. If `saveto` is a DiskStore, the samples will be
 saved to disk. If `initial_params` is not `nothing` then the sampler will start from that point.
 
 ## Keyword Arguments
 
   - `saveto`: If a DiskStore, the samples will be saved to disk, if [`MemoryStore`](@ref) the samples will be stored in memory/ram.
-  - `adtype`: The automatic differentiation type to use. The default if Enzyme which is the recommended choice for Comrade currently.
   - `initial_params`: The initial parameters to start the sampler from. If `nothing` then the sampler will start from a random point in the prior.
   - `kwargs`: Additional keyword arguments to pass to the sampler. Examples include `n_adapts` which is the total number of samples to use for adaptation. To see the others see the AdvancedHMC documentation.
"""
function AbstractMCMC.sample(
     rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,
     sampler::AbstractHMCSampler, nsamples, args...;
-    saveto=MemoryStore(), adtype=Val(:Enzyme), initial_params=nothing, kwargs...)
+    saveto=MemoryStore(), initial_params=nothing, kwargs...)
 
-    saveto isa DiskStore && return sample_to_disk(rng, post, sampler, nsamples, args...; outdir=saveto.name, output_stride=min(saveto.stride, nsamples), adtype, initial_params, kwargs...)
+    saveto isa DiskStore && return sample_to_disk(rng, post, sampler, nsamples, args...; outdir=saveto.name, output_stride=min(saveto.stride, nsamples), initial_params, kwargs...)
 
+    if isnothing(Comrade.admode(post))
+        throw(ArgumentError("You must specify an automatic differentiation type in VLBIPosterior with the admode kwarg"))
+    end
     tpost = asflat(post)
-    ∇ℓ = ADgradient(adtype, tpost)
     θ0 = initialize_params(tpost, initial_params)
-    model, smplr = make_sampler(rng, ∇ℓ, sampler, θ0)
+    model, smplr = make_sampler(rng, tpost, sampler, θ0)
 
     res = sample(rng, model, smplr, nsamples, args...; initial_params=θ0, saveto=saveto, chain_type=Array, kwargs...)
 
@@ -90,7 +92,6 @@ end
 function initialize(rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPosterior,
                     sampler::AbstractHMCSampler, nsamples, outbase, args...;
                     n_adapts = min(nsamples÷2, 1000),
-                    adtype = Val(:Enzyme),
                     initial_params=nothing, outdir = "Results",
                     output_stride=min(100, nsamples),
                     restart = false,
@@ -119,7 +120,7 @@ function initialize(rng::Random.AbstractRNG, tpost::Comrade.TransformedVLBIPosterior,
         @warn "No starting location chosen, picking start from prior"
         θ0 = prior_sample(rng, tpost)
     end
-    t = Sample(rng, tpost, sampler; initial_params=θ0, adtype, n_adapts, kwargs...)(1:nsamples)
+    t = Sample(rng, tpost, sampler; initial_params=θ0, n_adapts, kwargs...)(1:nsamples)
     pt = Iterators.partition(t, output_stride)
     nscans = nsamples÷output_stride + (nsamples%output_stride!=0 ? 1 : 0)
@@ -158,7 +159,6 @@ end
 function sample_to_disk(rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,
                         sampler::AbstractHMCSampler, nsamples, args...;
-                        adtype = Val(:Enzyme),
                         n_adapts = min(nsamples÷2, 1000),
                         initial_params=nothing, outdir = "Results",
                         restart=false,
@@ -172,7 +172,7 @@ function sample_to_disk(rng::Random.AbstractRNG, post::Comrade.VLBIPosterior,
 
     pt, state, out, i = initialize(
                 rng, tpost, sampler, nsamples, outbase, args...;
-                n_adapts, adtype,
+                n_adapts,
                 initial_params, restart, outdir, output_stride, kwargs...
     )
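The guard added in `AbstractMCMC.sample` above changes the failure mode: a posterior built without an `admode` now errors immediately rather than silently selecting an AD backend. A small sketch of what that looks like, with `skym`, `intmodel`, and `dvis` again standing in for a sky model, instrument model, and data:

```julia
post_noad = VLBIPosterior(skym, intmodel, dvis)   # admode defaults to `nothing`

try
    sample(post_noad, NUTS(0.8), 1_000; n_adapts=500)
catch e
    # "You must specify an automatic differentiation type in VLBIPosterior with the admode kwarg"
    @assert e isa ArgumentError
end
```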
diff --git a/ext/ComradeEnzymeExt.jl b/ext/ComradeEnzymeExt.jl
index 3180c945..d2cae43f 100644
--- a/ext/ComradeEnzymeExt.jl
+++ b/ext/ComradeEnzymeExt.jl
@@ -1,10 +1,20 @@
 module ComradeEnzymeExt
 
 using Enzyme
+using Comrade
+using LogDensityProblems
 
-function __init__()
-    # We need this to ensure than Enzyme can AD through the Comrade code base
-    Enzyme.API.runtimeActivity!(true)
+LogDensityProblems.dimension(d::Comrade.TransformedVLBIPosterior) = dimension(d)
+LogDensityProblems.capabilities(::Type{<:Comrade.TransformedVLBIPosterior}) = LogDensityProblems.LogDensityOrder{1}()
+
+
+function LogDensityProblems.logdensity_and_gradient(d::Comrade.TransformedVLBIPosterior, x::AbstractArray)
+    # Wrap the user-specified mode so the primal log density is returned along with the gradient
+    mode = Enzyme.EnzymeCore.WithPrimal(Comrade.admode(d))
+    dx = zero(x)
+    # Reverse pass: `d` is constant, `x` is active with shadow `dx` accumulating the gradient
+    (_, y) = autodiff(mode, Comrade.logdensityof, Active, Const(d), Duplicated(x, dx))
+    return y, dx
 end
+
+
 end
\ No newline at end of file
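With the extension above, a transformed posterior is itself a first-order `LogDensityProblems` object, so the Enzyme gradient can be evaluated by hand, which is handy when debugging a model. A sketch, assuming `post` was built with an Enzyme `admode` as recommended:

```julia
using LogDensityProblems

tpost = asflat(post)
x0    = prior_sample(tpost)
ℓ, g  = LogDensityProblems.logdensity_and_gradient(tpost, x0)
# `ℓ` is the log density at `x0` and `g` its gradient, both from a single Enzyme reverse pass.
```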
diff --git a/ext/ComradeOptimizationExt.jl b/ext/ComradeOptimizationExt.jl
index 294a1e1d..52ebfc9b 100644
--- a/ext/ComradeOptimizationExt.jl
+++ b/ext/ComradeOptimizationExt.jl
@@ -6,11 +6,21 @@ using Optimization
 using Distributions
 using LinearAlgebra
 using HypercubeTransform
-
+using LogDensityProblems
 
 function Optimization.OptimizationFunction(post::Comrade.TransformedVLBIPosterior, args...; kwargs...)
-    ℓ(x,p) = -logdensityof(p, x)
-    return SciMLBase.OptimizationFunction(ℓ, args...; kwargs...)
+    ℓ(x,p=post) = -logdensityof(p, x)
+    if isnothing(Comrade.admode(post))
+        return SciMLBase.OptimizationFunction(ℓ, args...; kwargs...)
+    else
+        function grad(G, x, p)
+            (_, dx) = LogDensityProblems.logdensity_and_gradient(post, x)
+            # Optimization minimizes, so flip the sign of the log-density gradient
+            dx .*= -1
+            G .= dx
+            return G
+        end
+        return SciMLBase.OptimizationFunction(ℓ, args...; grad=grad, kwargs...)
+    end
 end
 
 # """
@@ -36,17 +46,20 @@ end
 
 """
-    comrade_opt(post::VLBIPosterior, opt, adtype=nothing, args...; initial_params=nothing, kwargs...)
+    comrade_opt(post::VLBIPosterior, opt, args...; initial_params=nothing, kwargs...)
 
 Optimize the posterior `post` using the `opt` optimizer.
 
+!!! warning
+    To use a gradient optimizer with AD, `VLBIPosterior` must be constructed with a specific `admode`.
+    The `admode` can be a union of `Nothing` and `<:EnzymeCore.Mode` types. We recommend
+    using `Enzyme.set_runtime_activity(Enzyme.Reverse)`.
+
+
 ## Arguments
 
  - `post` : The posterior to optimize.
  - `opt` : The optimizer to use. This can be any optimizer from `Optimization.jl`.
- - `adtype` : The automatic differentiation type to use. The default is `nothing` which means
-    no automatic differentiation is used. To specify to use automatic differentiation
-    set `adtype`. For example if you wish to use `Enzyme` set `adtype=Optimization.AutoEnzyme(;mode=Enzyme.Reverse)`.
 - `args` : Additional arguments passed to the `Optimization` `solve` method
 
 ## Keyword Arguments
@@ -57,15 +70,14 @@ Optimize the posterior `post` using the `opt` optimizer.
 - `kwargs` : Additional keyword arguments passed to the `Optimization.jl` `solve` method.
 
 """
-function Comrade.comrade_opt(post::VLBIPosterior, opt, adtype=nothing, args...; initial_params=nothing, kwargs...)
-    if isnothing(adtype)
-        adtype = Optimization.SciMLBase.NoAD()
+function Comrade.comrade_opt(post::VLBIPosterior, opt, args...; initial_params=nothing, kwargs...)
+    if isnothing(Comrade.admode(post))
         tpost = ascube(post)
     else
         tpost = asflat(post)
     end
 
-    f = OptimizationFunction(tpost, adtype)
+    f = OptimizationFunction(tpost)
 
     if isnothing(initial_params)
         initial_params = prior_sample(tpost)
diff --git a/src/inference/optimization.jl b/src/inference/optimization.jl
index ec5b49b4..929f353b 100644
--- a/src/inference/optimization.jl
+++ b/src/inference/optimization.jl
@@ -1,7 +1,7 @@
 export comrade_opt, comrade_laplace
 
 """
-    comrade_opt(post::VLBIPosterior, opt, adtype=Optimization.NoAD(), args...; initial_params=nothing, kwargs...)
+    comrade_opt(post::VLBIPosterior, opt, adtype=nothing, args...; initial_params=nothing, kwargs...)
 
 Optimize the posterior `post` using the `opt` optimizer. The `adtype` specifies the automatic differentiation.
 The `args/kwargs` are forwarded to the specific optimization package.
diff --git a/src/posterior/transformed.jl b/src/posterior/transformed.jl
index f01e8574..e8859070 100644
--- a/src/posterior/transformed.jl
+++ b/src/posterior/transformed.jl
@@ -11,7 +11,7 @@ struct TransformedVLBIPosterior{P<:VLBIPosterior,T} <: AbstractVLBIPosterior
     transform::T
 end
 (post::TransformedVLBIPosterior)(θ) = logdensityof(post, θ)
-
+admode(post::TransformedVLBIPosterior) = admode(post.lpost)
 
 function prior_sample(rng, tpost::TransformedVLBIPosterior, args...)
     inv = Base.Fix1(HypercubeTransform.inverse, tpost)
diff --git a/src/posterior/vlbiposterior.jl b/src/posterior/vlbiposterior.jl
index a1c0b63f..1329170d 100644
--- a/src/posterior/vlbiposterior.jl
+++ b/src/posterior/vlbiposterior.jl
@@ -1,26 +1,31 @@
-struct VLBIPosterior{D, T, P, MS<:ObservedSkyModel, MI<:AbstractInstrumentModel} <: AbstractVLBIPosterior
+struct VLBIPosterior{D, T, P, MS<:ObservedSkyModel, MI<:AbstractInstrumentModel, ADMode<:Union{Nothing,EnzymeCore.Mode}} <: AbstractVLBIPosterior
     data::D
     lklhds::T
     prior::P
     skymodel::MS
     instrumentmodel::MI
+    admode::ADMode
 end
 (post::VLBIPosterior)(θ) = logdensityof(post, θ)
-
+admode(post::VLBIPosterior) = post.admode
 
 """
-    VLBIPosterior(skymodel::SkyModel, instumentmodel::InstrumentModel, dataproducts::EHTObservationTable...)
+    VLBIPosterior(skymodel::SkyModel, instrumentmodel::InstrumentModel, dataproducts::EHTObservationTable...; admode=nothing)
 
 Creates a VLBILikelihood using the `skymodel`, its related metadata `skymeta`,
-and the `instrumentmodel` and its metadata `instumentmeta`.
-. The `model`
-is a function that converts from parameters `θ` to a Comrade
+and the `instrumentmodel` and its metadata `instrumentmeta`. The `model` is a
+function that converts from parameters `θ` to a Comrade
 AbstractModel which can be used to compute [`visibilitymap`](@ref) and a set of
 `metadata` that is used by `model` to compute the model.
 
+To enable automatic differentiation, the `admode` keyword argument can be set to any `EnzymeCore.Mode` type,
+or, if no AD is desired, `nothing`. We recommend using `Enzyme.set_runtime_activity(Enzyme.Reverse)`
+for essentially every problem. Note that runtime activity does have a performance cost, and as Enzyme and
+Comrade mature we expect this to no longer be needed.
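To make the new dispatch concrete: `admode` now selects both the parameter transform and the gradient path inside `comrade_opt`. A sketch under the usual placeholder assumptions (`skym` and `dvis` stand in for a sky model and data; the optimizers come from OptimizationOptimJL):

```julia
post_grad   = VLBIPosterior(skym, dvis; admode=set_runtime_activity(Enzyme.Reverse))
post_nograd = VLBIPosterior(skym, dvis)                 # admode === nothing

xopt, sol = comrade_opt(post_grad, LBFGS())             # asflat(post) + Enzyme gradients
xopt, sol = comrade_opt(post_nograd, NelderMead())      # ascube(post), gradient-free search
```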
+ # Warning The `model` itself must be a two argument function where the first argument is the set @@ -56,6 +61,7 @@ function VLBIPosterior( skymodel::AbstractSkyModel, instrumentmodel::AbstractInstrumentModel, dataproducts::EHTObservationTable...; + admode = nothing ) @@ -69,11 +75,11 @@ function VLBIPosterior( return VLBIPosterior{ typeof(dataproducts),typeof(ls),typeof(total_prior), - typeof(sky), typeof(int)}(dataproducts, ls, total_prior, sky, int) + typeof(sky), typeof(int), typeof(admode)}(dataproducts, ls, total_prior, sky, int, admode) end -VLBIPosterior(skymodel::AbstractSkyModel, dataproducts::EHTObservationTable...) = - VLBIPosterior(skymodel, IdealInstrumentModel(), dataproducts...) +VLBIPosterior(skymodel::AbstractSkyModel, dataproducts::EHTObservationTable...; admode=nothing) = + VLBIPosterior(skymodel, IdealInstrumentModel(), dataproducts...; admode) function combine_prior(skyprior, instrumentmodelprior) return NamedDist((sky=skyprior, instrument=instrumentmodelprior)) diff --git a/src/rules.jl b/src/rules.jl index c502e078..3821e17b 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -31,132 +31,4 @@ # end # Enzyme.EnzymeRules.inactive(::typeof(Base.dataids), u::StructArray) = nothing -# Enzyme.EnzymeRules.inactive(::typeof(Base.unalias), u::StructArray, args...) = nothing - - -## Temporary rule for sparse matmuls. Will be removed once Enzyme merges https://github.com/EnzymeAD/Enzyme.jl/pull/1792 -using SparseArrays: SparseMatrixCSCUnion -using LinearAlgebra -using EnzymeCore: Annotation - -function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth, - func::Const{typeof(LinearAlgebra.mul!)}, - ::Type{RT}, - C::Annotation{<:StridedVecOrMat}, - A::Const{<:SparseMatrixCSCUnion}, - B::Annotation{<:StridedVecOrMat}, - α::Annotation{<:Number}, - β::Annotation{<:Number} - ) where {RT} - - cache_C = !(isa(β, Const)) ? copy(C.val) : nothing - # Always need to do forward pass otherwise primal may not be correct - func.val(C.val, A.val, B.val, α.val, β.val) - - primal = if EnzymeRules.needs_primal(config) - C.val - else - nothing - end - - shadow = if EnzymeRules.needs_shadow(config) - C.dval - else - nothing - end - - # Check if A is overwritten and B is active (and thus required) - cache_A = ( EnzymeRules.overwritten(config)[5] - && !(typeof(B) <: Const) - && !(typeof(C) <: Const) - ) ? copy(A.val) : nothing - - # cache_B = ( EnzymeRules.overwritten(config)[6]) ? copy(B.val) : nothing - - if !isa(α, Const) - cache_α = A.val*B.val - else - cache_α = nothing - end - - cache = (cache_C, cache_A, cache_α) - - return EnzymeRules.AugmentedReturn(primal, shadow, cache) -end - -function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth, - func::Const{typeof(LinearAlgebra.mul!)}, - ::Type{RT}, cache, - C::Annotation{<:StridedVecOrMat}, - A::Const{<:SparseMatrixCSCUnion}, - B::Annotation{<:StridedVecOrMat}, - α::Annotation{<:Number}, - β::Annotation{<:Number} - ) where {RT} - - cache_C, cache_A, cache_α = cache - Cval = !isnothing(cache_C) ? cache_C : C.val - Aval = !isnothing(cache_A) ? cache_A : A.val - # Bval = !isnothing(cache_B) ? cache_B : B.val - - N = EnzymeRules.width(config) - if !isa(C, Const) - dCs = C.dval - dBs = isa(B, Const) ? 
dCs : B.dval - - dα = if !isa(α, Const) - if N == 1 - LinearAlgebra.dot(C.dval, cache_α) - else - ntuple(Val(N)) do i - Base.@_inline_meta - LinearAlgebra.dot(C.dval[i], cache_α) - end - end - else - nothing - end - - dβ = if !isa(β, Const) - if N == 1 - LinearAlgebra.dot(C.dval, Cval) - else - ntuple(Val(N)) do i - Base.@_inline_meta - LinearAlgebra.dot(C.dval[i], Cval) - end - end - else - nothing - end - - for i in 1:N - - # This rule is incorrect since I need to project dA to have the same - # sparsity pattern as A. - # if !isa(A, Const) - # dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - # #dA .+= α*dC*B' - # mul!(dA, dC, Bval', α.val, true) - # end - - if !isa(B, Const) - #dB .+= α*A'*dC - if N ==1 - func.val(dBs, Aval', dCs, α.val, true) - else - func.val(dBs[i], Aval', dCs[i], α.val, true) - end - end - - if N==1 - dCs .*= β.val - else - dCs[i] .*= β.val - end - end - end - - return (nothing, nothing, nothing, dα, dβ) -end - +# Enzyme.EnzymeRules.inactive(::typeof(Base.unalias), u::StructArray, args...) = nothing \ No newline at end of file diff --git a/test/Core/bayes.jl b/test/Core/bayes.jl index 9c41f85c..731b04f7 100644 --- a/test/Core/bayes.jl +++ b/test/Core/bayes.jl @@ -14,8 +14,8 @@ using Enzyme _,vis, amp, lcamp, cphase = load_data() g = imagepixels(μas2rad(150.0), μas2rad(150.0), 256, 256) skym = SkyModel(test_model, test_prior(), g) - post_cl = VLBIPosterior(skym, lcamp, cphase) - post = VLBIPosterior(skym, vis) + post_cl = VLBIPosterior(skym, lcamp, cphase; admode=set_runtime_activity(Enzyme.Reverse)) + post = VLBIPosterior(skym, vis; admode=set_runtime_activity(Enzyme.Reverse)) prior = test_prior() @@ -54,7 +54,7 @@ using Enzyme show(IOBuffer(), MIME"text/plain"(), tpostf) - f = OptimizationFunction(tpostf, Optimization.AutoEnzyme(;mode=Enzyme.Reverse)) + f = OptimizationFunction(tpostf) x0 = transform(tpostf, [ 0.1, 0.4, 0.5, diff --git a/test/ext/comradeahmc.jl b/test/ext/comradeahmc.jl index 7948f44f..53dcf30f 100644 --- a/test/ext/comradeahmc.jl +++ b/test/ext/comradeahmc.jl @@ -7,7 +7,7 @@ using Enzyme _, _, _, lcamp, cphase = load_data() g = imagepixels(μas2rad(150.0), μas2rad(150.0), 256, 256) skym = SkyModel(test_model, test_prior(), g) - post = VLBIPosterior(skym, lcamp, cphase) + post = VLBIPosterior(skym, lcamp, cphase; admode=set_runtime_activity(Enzyme.Reverse)) x0 = (sky = (f1 = 1.0916271439905998, σ1 = 8.230088139590025e-11, @@ -20,7 +20,7 @@ using Enzyme x = 1.451956089157719e-10, y = 1.455983181049137e-10),) s1 = NUTS(0.65) - hchain = sample(post, s1, 1_000; n_adapts=500, progress=false, adtype=Val(:Enzyme)) + hchain = sample(post, s1, 1_000; n_adapts=500, progress=false) hchain = sample(post, s1, 1_000; n_adapts=500, progress=false, initial_params=x0) out = sample(post, s1, 1_000; n_adapts=500, saveto=DiskStore(name=joinpath(@__DIR__, "Test")), initial_params=x0) out = sample(post, s1, 1_200; n_adapts=500, saveto=DiskStore(name=joinpath(@__DIR__, "Test")), initial_params=x0, restart=true) diff --git a/test/ext/comradeoptimization.jl b/test/ext/comradeoptimization.jl index 64d1c4ec..814ef69c 100644 --- a/test/ext/comradeoptimization.jl +++ b/test/ext/comradeoptimization.jl @@ -8,7 +8,7 @@ using Test _, _, _, lcamp, cphase = load_data() g = imagepixels(μas2rad(150.0), μas2rad(150.0), 256, 256) skym = SkyModel(test_model, test_prior(), g) - post = VLBIPosterior(skym, lcamp, cphase) + post = VLBIPosterior(skym, lcamp, cphase; admode=set_runtime_activity(Enzyme.Reverse)) tpost = asflat(post) x0 = transform(tpost, [ 0.0, @@ -23,7 +23,7 @@ using 
Test 2.0, ]) - xopt2, sol = comrade_opt(post, LBFGS(), AutoEnzyme(;mode=Enzyme.Reverse); initial_params=x0, maxiters=10_000) + xopt2, sol = comrade_opt(post, LBFGS(); initial_params=x0, maxiters=10_000) xopt = xopt2.sky @test isapprox(xopt.f1/xopt.f2, 2.0, atol=1e-3) From fd1a436a46fc9c9ce8c0110e325d9f63fe6d57b7 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 5 Oct 2024 14:36:05 -0400 Subject: [PATCH 2/6] Refactor likelihoods to prevent anonymous functions --- src/posterior/likelihood.jl | 68 +++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/src/posterior/likelihood.jl b/src/posterior/likelihood.jl index 493c9542..e861131b 100644 --- a/src/posterior/likelihood.jl +++ b/src/posterior/likelihood.jl @@ -15,51 +15,86 @@ are simulated data. likelihood(d::ConditionedLikelihood, μ) = d.kernel(μ) +struct _Visibility{S,L} + S::S + L::L +end + +function (v::_Visibility)(μ) + return ComplexVisLikelihood(baseimage(μ), v.S, v.L) +end # internal function that creates the likelihood for a set of complex visibilities function makelikelihood(data::Comrade.EHTObservationTable{<:Comrade.EHTVisibilityDatum}) Σ = noise(data).^2 vis = measurement(data) lnorm = VLBILikelihoods.lognorm(ComplexVisLikelihood(vis, Σ)) - ℓ = ConditionedLikelihood(vis) do μ - ComplexVisLikelihood(μ, Σ, lnorm) - end + ℓ = ConditionedLikelihood(_Visibility(Σ, lnorm), vis) return ℓ end +struct _Coherency{S,L} + S::S + L::L +end + +function (c::_Coherency)(μ) + return CoherencyLikelihood(baseimage(μ), c.S, c.L) +end + function makelikelihood(data::Comrade.EHTObservationTable{<:Comrade.EHTCoherencyDatum}) Σ = map(x->x.^2, noise(data)) vis = measurement(data) - # lnorm = VLBILikelihoods.lognorm(CoherencyLikelihood(vis, Σ)) - ℓ = ConditionedLikelihood(vis) do μ - # @info typeof(μ) - CoherencyLikelihood(baseimage(μ), Σ, 0.0) - end + lnorm = VLBILikelihoods.lognorm(CoherencyLikelihood(vis, Σ)) + ℓ = ConditionedLikelihood(_Coherency(Σ, lnorm), vis) return ℓ end +struct _VisAmp{S} + S::S +end + +function (v::_VisAmp)(μ) + return RiceAmplitudeLikelihood(abs.(baseimage(μ)), v.S) +end + # internal function that creates the likelihood for a set of visibility amplitudes function makelikelihood(data::Comrade.EHTObservationTable{<:Comrade.EHTVisibilityAmplitudeDatum}) Σ = noise(data).^2 amp = measurement(data) - ℓ = ConditionedLikelihood(amp) do μ - RiceAmplitudeLikelihood(abs.(μ), Σ) - end + ℓ = ConditionedLikelihood(_VisAmp(Σ), amp) return ℓ end +struct _LCamp{F,S,L} + f::F + S::S + L::L +end + +function (c::_LCamp)(μ) + return AmplitudeLikelihood(c.f(baseimage(μ)), c.S, c.L) +end + # internal function that creates the likelihood for a set of log closure amplitudes function makelikelihood(data::Comrade.EHTObservationTable{<:Comrade.EHTLogClosureAmplitudeDatum}) Σlca = factornoisecovariance(arrayconfig(data)) f = Base.Fix2(logclosure_amplitudes, designmat(arrayconfig(data))) amp = measurement(data) lnorm = VLBILikelihoods.lognorm(AmplitudeLikelihood(amp, Σlca)) - ℓ = ConditionedLikelihood(amp) do μ - AmplitudeLikelihood(f(μ), Σlca, lnorm) - end + ℓ = ConditionedLikelihood(_LCamp(f, Σlca, lnorm), amp) return ℓ end +struct _CPhase{F,S,L} + f::F + S::S + L::L +end + +function (c::_CPhase)(μ) + return ClosurePhaseLikelihood(c.f(baseimage(μ)), c.S, c.L) +end # internal function that creates the likelihood for a set of closure phase datum function makelikelihood(data::Comrade.EHTObservationTable{<:Comrade.EHTClosurePhaseDatum}) @@ -67,9 +102,6 @@ function 
makelikelihood(data::Comrade.EHTObservationTable{<:Comrade.EHTClosurePh f = Base.Fix2(closure_phases, designmat(arrayconfig(data))) phase = measurement(data) lnorm = VLBILikelihoods.lognorm(ClosurePhaseLikelihood(phase, Σcp)) - ℓ = ConditionedLikelihood(phase) do μ - ClosurePhaseLikelihood(f(μ), Σcp, lnorm) - end - + ℓ = ConditionedLikelihood(_CPhase(f, Σcp, lnorm), phase) return ℓ end From 145c098e8fe02547f2e029d1a3968a110fcf6bc8 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 5 Oct 2024 15:45:33 -0400 Subject: [PATCH 3/6] Fix tests --- src/instrument/instrument_transforms.jl | 8 ++--- test/Core/bayes.jl | 8 ++--- test/Core/rules.jl | 45 ------------------------- 3 files changed, 8 insertions(+), 53 deletions(-) delete mode 100644 test/Core/rules.jl diff --git a/src/instrument/instrument_transforms.jl b/src/instrument/instrument_transforms.jl index 857b380f..62da9bf2 100644 --- a/src/instrument/instrument_transforms.jl +++ b/src/instrument/instrument_transforms.jl @@ -46,11 +46,11 @@ end @inline function site_sum(y, site_map::SiteLookup) yout = similar(y) - for site in site_map.lookup - ys = @inbounds @view y[site] + @inbounds for site in site_map.lookup + ys = @view y[site] # y should never alias so we should be fine here. - youts = @inbounds @view yout[site] - @inline cumsum!(youts, ys) + youts = @view yout[site] + cumsum!(youts, ys) end return yout end diff --git a/test/Core/bayes.jl b/test/Core/bayes.jl index 731b04f7..b2ae78c0 100644 --- a/test/Core/bayes.jl +++ b/test/Core/bayes.jl @@ -15,7 +15,7 @@ using Enzyme g = imagepixels(μas2rad(150.0), μas2rad(150.0), 256, 256) skym = SkyModel(test_model, test_prior(), g) post_cl = VLBIPosterior(skym, lcamp, cphase; admode=set_runtime_activity(Enzyme.Reverse)) - post = VLBIPosterior(skym, vis; admode=set_runtime_activity(Enzyme.Reverse)) + post = VLBIPosterior(skym, vis) prior = test_prior() @@ -135,7 +135,7 @@ using FiniteDifferences tpost = asflat(post) x = prior_sample(tpost) - gz = Enzyme.gradient(Enzyme.Reverse, Const(tpost), x) + gz = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpost), x) mfd = central_fdm(5,1) gfd, = FiniteDifferences.grad(mfd, tpost, x) @test gz ≈ gfd @@ -152,7 +152,7 @@ using FiniteDifferences x = prior_sample(tpost) fj = instrumentmodel(post, prior_sample(post)) residual(post, Comrade.transform(tpost, x)) - gz = Enzyme.gradient(Enzyme.Reverse, Const(tpost), x) + gz = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpost), x) mfd = central_fdm(5,1) gfd, = FiniteDifferences.grad(mfd, tpost, x) @test gz ≈ gfd @@ -218,7 +218,7 @@ end x0 = prior_sample(tpostf) @inferred logdensityof(tpostf, x0) - gz = Enzyme.gradient(Enzyme.Reverse, Const(tpostf), x0) + gz = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpostf), x0) gn, = FiniteDifferences.grad(mfd, tpostf, x0) @test gz ≈ gn end diff --git a/test/Core/rules.jl b/test/Core/rules.jl deleted file mode 100644 index fa441444..00000000 --- a/test/Core/rules.jl +++ /dev/null @@ -1,45 +0,0 @@ -using SparseArrays -using LinearAlgebra -using EnzymeTestUtils - -@testset "SparseArrays spmatvec reverse rule" begin - C = zeros(18) - M = sprand(18, 9, 0.1) - v = randn(9) - α = 2.0 - β = 1.0 - - for Tret in (Duplicated,), Tv in (Const, Duplicated,), - Tα in (Const, Active), Tβ in (Const, Active) - - are_activities_compatible(Tret, Tret, Tv, Tα, Tβ) || continue - test_reverse(mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) - - end - - - for Tret in (Duplicated,), Tv in (Const, Duplicated,), bα in (true, false), bβ in (true, 
false) - are_activities_compatible(Tret, Tret, Tv) || continue - test_reverse(mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) - end -end - -@testset "SparseArrays spmatmat reverse rule" begin - C = zeros(18, 11) - M = sprand(18, 9, 0.1) - v = randn(9, 11) - α = 2.0 - β = 1.0 - - for Tret in (Duplicated, ), Tv in (Const, Duplicated, ), - Tα in (Const, Active), Tβ in (Const, Active) - - are_activities_compatible(Tret, Tv, Tα, Tβ) || continue - test_reverse(mul!, Tret, (C, Tret), (M, Const), (v, Tv), (α, Tα), (β, Tβ)) - end - - for Tret in (Duplicated, ), Tv in (Const, Duplicated, ), bα in (true, false), bβ in (true, false) - are_activities_compatible(Tret, Tv) || continue - test_reverse(mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const)) - end -end From 64c0a9178da93b122af9a54687d2200d4fca76fa Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 5 Oct 2024 17:35:33 -0400 Subject: [PATCH 4/6] Fix tutorial --- examples/advanced/HybridImaging/Project.toml | 2 +- ext/ComradeOptimizationExt.jl | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/advanced/HybridImaging/Project.toml b/examples/advanced/HybridImaging/Project.toml index 28a8fe93..11b0baee 100644 --- a/examples/advanced/HybridImaging/Project.toml +++ b/examples/advanced/HybridImaging/Project.toml @@ -21,4 +21,4 @@ Plots = "1" Pyehtim = "0.1" StableRNGs = "1" StatsBase = "0.34" -VLBIImagePriors = "0.8" +VLBIImagePriors = "0.9" diff --git a/ext/ComradeOptimizationExt.jl b/ext/ComradeOptimizationExt.jl index 52ebfc9b..aed367c1 100644 --- a/ext/ComradeOptimizationExt.jl +++ b/ext/ComradeOptimizationExt.jl @@ -70,8 +70,8 @@ Optimize the posterior `post` using the `opt` optimizer. - `kwargs` : Additional keyword arguments passed `Optimization.jl` `solve` method. """ -function Comrade.comrade_opt(post::VLBIPosterior, opt, args...; initial_params=nothing, kwargs...) - if isnothing(Comrade.admode(post)) +function Comrade.comrade_opt(post::VLBIPosterior, opt, args...; initial_params=nothing, lb=nothing, ub=nothing, cube=false, kwargs...) 
+ if isnothing(Comrade.admode(post)) || cube tpost = ascube(post) else tpost = asflat(post) @@ -85,11 +85,9 @@ function Comrade.comrade_opt(post::VLBIPosterior, opt, args...; initial_params=n initial_params = Comrade.inverse(tpost, initial_params) end - lb = nothing - ub = nothing if tpost.transform isa HypercubeTransform.AbstractHypercubeTransform lb=fill(0.0001, dimension(tpost)) - ub = fill(0.9999, dimension(tpost)) + ub=fill(0.9999, dimension(tpost)) end prob = OptimizationProblem(f, initial_params, tpost; lb, ub) From 0ea755cea83663df98f8f55d81ffc936fe359912 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sat, 5 Oct 2024 17:40:52 -0400 Subject: [PATCH 5/6] update test for enzyme 0.13 --- test/Core/bayes.jl | 10 +++++----- test/Core/core.jl | 2 +- test/Core/partially_fixed.jl | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/Core/bayes.jl b/test/Core/bayes.jl index b2ae78c0..ab8926f2 100644 --- a/test/Core/bayes.jl +++ b/test/Core/bayes.jl @@ -104,8 +104,8 @@ using Enzyme @test LogDensityProblems.dimension(tpostc) == length(c0) @test LogDensityProblems.capabilities(typeof(post)) === LogDensityProblems.LogDensityOrder{0}() - @test LogDensityProblems.capabilities(typeof(tpostf)) === LogDensityProblems.LogDensityOrder{0}() - @test LogDensityProblems.capabilities(typeof(tpostc)) === LogDensityProblems.LogDensityOrder{0}() + @test LogDensityProblems.capabilities(typeof(tpostf)) === LogDensityProblems.LogDensityOrder{1}() + @test LogDensityProblems.capabilities(typeof(tpostc)) === LogDensityProblems.LogDensityOrder{1}() end @testset "corr image prior" begin @@ -135,7 +135,7 @@ using FiniteDifferences tpost = asflat(post) x = prior_sample(tpost) - gz = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpost), x) + gz, = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpost), x) mfd = central_fdm(5,1) gfd, = FiniteDifferences.grad(mfd, tpost, x) @test gz ≈ gfd @@ -152,7 +152,7 @@ using FiniteDifferences x = prior_sample(tpost) fj = instrumentmodel(post, prior_sample(post)) residual(post, Comrade.transform(tpost, x)) - gz = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpost), x) + gz, = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpost), x) mfd = central_fdm(5,1) gfd, = FiniteDifferences.grad(mfd, tpost, x) @test gz ≈ gfd @@ -218,7 +218,7 @@ end x0 = prior_sample(tpostf) @inferred logdensityof(tpostf, x0) - gz = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpostf), x0) + gz, = Enzyme.gradient(set_runtime_activity(Enzyme.Reverse), Const(tpostf), x0) gn, = FiniteDifferences.grad(mfd, tpostf, x0) @test gz ≈ gn end diff --git a/test/Core/core.jl b/test/Core/core.jl index 1ef7e53e..7523c531 100644 --- a/test/Core/core.jl +++ b/test/Core/core.jl @@ -13,4 +13,4 @@ include(joinpath(@__DIR__, "observation.jl")) include(joinpath(@__DIR__, "partially_fixed.jl")) include(joinpath(@__DIR__, "models.jl")) include(joinpath(@__DIR__, "bayes.jl")) -include(joinpath(@__DIR__, "rules.jl")) +# include(joinpath(@__DIR__, "rules.jl")) diff --git a/test/Core/partially_fixed.jl b/test/Core/partially_fixed.jl index f8213667..d73e33e9 100644 --- a/test/Core/partially_fixed.jl +++ b/test/Core/partially_fixed.jl @@ -28,8 +28,8 @@ using Enzyme gfdf, = grad(fdm, f, x) gfdlj, = grad(fdm, flj, x) - gzf = Enzyme.gradient(Enzyme.Reverse, Const(f), x) - gzflj = Enzyme.gradient(Enzyme.Reverse, Const(flj), x) + gzf, = Enzyme.gradient(Enzyme.Reverse, Const(f), x) + gzflj, = Enzyme.gradient(Enzyme.Reverse, Const(flj), x) @test gzf ≈ gfdf 
    @test gzflj ≈ gfdlj
 
From 289bd34139840678c54f8d76a9fd2c3f68d73142 Mon Sep 17 00:00:00 2001
From: Paul Tiede
Date: Sat, 5 Oct 2024 18:32:36 -0400
Subject: [PATCH 6/6] Update hybrid

---
 examples/advanced/HybridImaging/main.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/advanced/HybridImaging/main.jl b/examples/advanced/HybridImaging/main.jl
index 4b7eb2c8..e4043f11 100644
--- a/examples/advanced/HybridImaging/main.jl
+++ b/examples/advanced/HybridImaging/main.jl
@@ -156,6 +156,7 @@ skym = SkyModel(sky, skyprior, g; metadata=skymetadata)
 
 # This is everything we need to specify our posterior distribution, which is the main
 # object of interest in image reconstructions when using Bayesian inference.
+using Enzyme
 post = VLBIPosterior(skym, intmodel, dvis; admode=set_runtime_activity(Enzyme.Reverse))
 
 # To sample from our prior we can do
@@ -179,7 +180,6 @@ fig |> DisplayAs.PNG |> DisplayAs.Text #hide
 # To do this we use the [`comrade_opt`](@ref) function
 using Optimization
 using OptimizationOptimJL
-using Enzyme
 xopt, sol = comrade_opt(post, LBFGS();
                         initial_params=prior_sample(rng, post), maxiters=1000, g_tol=1e0)
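Taken together, the series moves the AD mode into the posterior and, in patch 2, removes the per-call closures that forced Enzyme to recompile a gradient for every freshly constructed likelihood. The closure-to-callable-struct rewrite is the standard Julia pattern for this; a self-contained sketch with hypothetical names, independent of Comrade's internals:

```julia
# Before: each call builds a new anonymous function, which is a new type,
# so the AD framework compiles a new gradient for every posterior.
make_loss(Σ) = μ -> sum(abs2, μ) / Σ

# After: a callable struct has a stable, nameable type, so a compiled
# gradient can be reused across posteriors built from the same data types.
struct Loss{T}
    Σ::T
end
(ℓ::Loss)(μ) = sum(abs2, μ) / ℓ.Σ

loss = Loss(2.0)
loss([1.0, 2.0])   # 5.0 / 2.0 == 2.5
```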