From 175e2c982d16c1248100e8deff79aa7e735037f8 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 29 Oct 2023 12:50:44 -0400 Subject: [PATCH] Add pigeons as an extension --- Project.toml | 6 ++- examples/Project.toml | 1 + examples/geometric_modeling.jl | 16 ++++---- ext/ComradePigeonsExt.jl | 66 ++++++++++++++++++++++++++++++++ src/observations/observations.jl | 2 +- 5 files changed, 81 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index e8712750..0d26933b 100644 --- a/Project.toml +++ b/Project.toml @@ -45,10 +45,12 @@ VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca" [weakdeps] Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31" Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468" [extensions] ComradeMakieExt = "Makie" +ComradePigeonsExt = "Pigeons" ComradePyehtimExt = "Pyehtim" [compat] @@ -71,6 +73,7 @@ Makie = "0.19" NamedTupleTools = "0.13,0.14" PaddedViews = "0.5" ParameterHandling = "0.4" +Pigeons = "0.2" PolarizedTypes = "0.1" PrettyTables = "1, 2" Pyehtim = "0.1" @@ -93,8 +96,9 @@ julia = "1.8" [extras] CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31" Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Makie", "Pyehtim"] +test = ["Test", "Makie", "Pigeons", "Pyehtim"] diff --git a/examples/Project.toml b/examples/Project.toml index 91095f95..76a937d6 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -20,6 +20,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" +Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/examples/geometric_modeling.jl b/examples/geometric_modeling.jl index 964fcc5c..935a03ea 100644 --- a/examples/geometric_modeling.jl +++ b/examples/geometric_modeling.jl @@ -182,18 +182,18 @@ fig, ax, plt = CM.image(g, model(xopt); axis=(xreversed=true, aspect=1, xlabel=" # # Comrade provides several sampling and other posterior approximation tools. To see the # list, please see the Libraries section of the docs. For this example, we will be using -# [AdvancedHMC.jl](https://github.com/TuringLang/AdvancedHMC.jl), which uses -# an adaptive Hamiltonian Monte Carlo sampler called NUTS to approximate the posterior. -# Most of Comrade's external libraries follow a similar interface. To use AdvancedHMC -# do the following: +# [Pigeons.jl](https://github.com/Julia-Tempering/Pigeons.jl) which is a state-of-the-art +# parallel tempering sampler that enables global exploration of the posterior. For smaller dimension +# problems (< 100) we recommend using this sampler especially if you have access to > 1 thread/core. +using Pigeons +pt = pigeons(target=cpost, explorer=SliceSampler(), record=[traces, round_trip, log_sum_ratio], n_chains=18, n_rounds=9) +chain = sample_array(cpost, pt) -using ComradeAHMC, Zygote -chain, stats = sample(rng, post, AHMC(metric=DiagEuclideanMetric(ndim), autodiff=Val(:Zygote)), 2000; nadapts=1000, init_params=xopt) # That's it! To finish it up we can then plot some simple visual fit diagnostics. # First to plot the image we call -imgs = intensitymap.(skymodel.(Ref(post), sample(chain[1000:end], 100)), μas2rad(200.0), μas2rad(200.0), 128, 128) +imgs = intensitymap.(skymodel.(Ref(post), sample(chain, 100)), μas2rad(200.0), μas2rad(200.0), 128, 128) imageviz(imgs[end], colormap=:afmhot) # What about the mean image? Well let's grab 100 images from the chain, where we first remove the @@ -212,7 +212,7 @@ plot(model(xopt), dlcamp, label="MAP") p = plot(dlcamp); uva = [sqrt.(uvarea(dlcamp[i])) for i in 1:length(dlcamp)] for i in 1:10 - m = simulate_observation(post, chain[rand(rng, 1000:2000)])[1] + m = simulate_observation(post, sample(chain, 1)[1])[1] scatter!(uva, m, color=:grey, label=:none, alpha=0.1) end p diff --git a/ext/ComradePigeonsExt.jl b/ext/ComradePigeonsExt.jl index d52cccdb..38eecd3a 100644 --- a/ext/ComradePigeonsExt.jl +++ b/ext/ComradePigeonsExt.jl @@ -4,10 +4,76 @@ using Comrade if isdefined(Base, :get_extension) using Pigeons + using AbstractMCMC + using LogDensityProblems + using HypercubeTransform + using TypedTables + using TransformVariables + using Random + else using ..Pigeons + using ..AbstractMCMC + using ..LogDensityProblems + using ..HypercubeTransform + using ..TypedTables + using ..TransformVariables + using ..Random +end + +Pigeons.initialization(tpost::Comrade.TransformedPosterior, rng::Random.AbstractRNG, ::Int) = prior_sample(rng, tpost) + +struct PriorRef{P,T} + model::P + transform::T +end + +function (p::PriorRef{P,<:TransformVariables.AbstractTransform})(x) where {P} + y, lj = TransformVariables.transform_and_logjac(p.transform, x) + logdensityof(p.model, y) + lj +end + +function (p::PriorRef{P,<:HypercubeTransform.AbstractHypercubeTransform})(x) where {P} + for xx in x + (xx > 1 || xx < 0) && return convert(eltype(x), -Inf) + end + return zero(eltype(x)) +end + +Pigeons.default_explorer(::Comrade.TransformedPosterior{P,<:HypercubeTransform.AbstractHypercubeTransform}) where {P} = + SliceSampler() + +Pigeons.default_explorer(::Comrade.TransformedPosterior{P,<:TransformVariables.AbstractTransform}) where {P} = + Pigeons.AutoMALA(;default_autodiff_backend = :Zygote) + +function Pigeons.default_reference(tpost::Comrade.TransformedPosterior) + t = tpost.transform + p = tpost.lpost.prior + return PriorRef(p, t) +end + +function Pigeons.sample_iid!(target::Comrade.TransformedPosterior, replica, shared) + replica.state = Pigeons.initialization(target, replica.rng, replica.replica_index) +end + +function Pigeons.sample_iid!(target::PriorRef{P, <:TransformVariables.AbstractTransform}, replica, shared) where {P} + replica.state .= Comrade.inverse(target.transform, rand(replica.rng, target.model)) +end + +function Pigeons.sample_iid!(::PriorRef{P,<:HypercubeTransform.AbstractHypercubeTransform}, replica, shared) where {P} + rand!(replica.rng, replica.state) +end + + +function Pigeons.sample_array(tpost::Comrade.TransformedPosterior, pt::Pigeons.PT) + samples = sample_array(pt) + arr = reshape(samples, size(samples, 1), size(samples, 2)) + return Table(map(x->Comrade.transform(tpost, x), eachrow(arr))) end +LogDensityProblems.dimension(t::PriorRef) = Comrade.dimension(t.transform) +LogDensityProblems.logdensity(t::PriorRef, x) = t(x) +LogDensityProblems.capabilities(::Type{<:PriorRef}) = LogDensityProblems.LogDensityOrder{0}() end diff --git a/src/observations/observations.jl b/src/observations/observations.jl index ab52d392..d64c8658 100755 --- a/src/observations/observations.jl +++ b/src/observations/observations.jl @@ -495,7 +495,7 @@ Base.@kwdef struct EHTVisibilityDatum{S<:Number} <: AbstractVisibilityDatum{S} """ Complex Vis. measurement (Jy) """ - measurement::S + measurement::Complex{S} """ error of the complex vis (Jy) """