Skip to content

Commit

Permalink
Add pigeons as an extension
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Oct 29, 2023
1 parent bda86ae commit 175e2c9
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 10 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca"

[weakdeps]
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"

[extensions]
ComradeMakieExt = "Makie"
ComradePigeonsExt = "Pigeons"
ComradePyehtimExt = "Pyehtim"

[compat]
Expand All @@ -71,6 +73,7 @@ Makie = "0.19"
NamedTupleTools = "0.13,0.14"
PaddedViews = "0.5"
ParameterHandling = "0.4"
Pigeons = "0.2"
PolarizedTypes = "0.1"
PrettyTables = "1, 2"
Pyehtim = "0.1"
Expand All @@ -93,8 +96,9 @@ julia = "1.8"
[extras]
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Makie", "Pyehtim"]
test = ["Test", "Makie", "Pigeons", "Pyehtim"]
1 change: 1 addition & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
OptimizationBBO = "3e6eede4-6085-4f62-9a71-46d9bc1eb92b"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
Pigeons = "0eb8d820-af6a-4919-95ae-11206f830c31"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
16 changes: 8 additions & 8 deletions examples/geometric_modeling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,18 +182,18 @@ fig, ax, plt = CM.image(g, model(xopt); axis=(xreversed=true, aspect=1, xlabel="
#
# Comrade provides several sampling and other posterior approximation tools. To see the
# list, please see the Libraries section of the docs. For this example, we will be using
# [AdvancedHMC.jl](https://github.com/TuringLang/AdvancedHMC.jl), which uses
# an adaptive Hamiltonian Monte Carlo sampler called NUTS to approximate the posterior.
# Most of Comrade's external libraries follow a similar interface. To use AdvancedHMC
# do the following:
# [Pigeons.jl](https://github.com/Julia-Tempering/Pigeons.jl) which is a state-of-the-art
# parallel tempering sampler that enables global exploration of the posterior. For smaller dimension
# problems (< 100) we recommend using this sampler especially if you have access to > 1 thread/core.
using Pigeons
pt = pigeons(target=cpost, explorer=SliceSampler(), record=[traces, round_trip, log_sum_ratio], n_chains=18, n_rounds=9)
chain = sample_array(cpost, pt)

using ComradeAHMC, Zygote
chain, stats = sample(rng, post, AHMC(metric=DiagEuclideanMetric(ndim), autodiff=Val(:Zygote)), 2000; nadapts=1000, init_params=xopt)

# That's it! To finish it up we can then plot some simple visual fit diagnostics.

# First to plot the image we call
imgs = intensitymap.(skymodel.(Ref(post), sample(chain[1000:end], 100)), μas2rad(200.0), μas2rad(200.0), 128, 128)
imgs = intensitymap.(skymodel.(Ref(post), sample(chain, 100)), μas2rad(200.0), μas2rad(200.0), 128, 128)
imageviz(imgs[end], colormap=:afmhot)

# What about the mean image? Well let's grab 100 images from the chain, where we first remove the
Expand All @@ -212,7 +212,7 @@ plot(model(xopt), dlcamp, label="MAP")
p = plot(dlcamp);
uva = [sqrt.(uvarea(dlcamp[i])) for i in 1:length(dlcamp)]
for i in 1:10
m = simulate_observation(post, chain[rand(rng, 1000:2000)])[1]
m = simulate_observation(post, sample(chain, 1)[1])[1]
scatter!(uva, m, color=:grey, label=:none, alpha=0.1)
end
p
Expand Down
66 changes: 66 additions & 0 deletions ext/ComradePigeonsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,76 @@ using Comrade

if isdefined(Base, :get_extension)
using Pigeons
using AbstractMCMC
using LogDensityProblems
using HypercubeTransform
using TypedTables
using TransformVariables
using Random

else
using ..Pigeons
using ..AbstractMCMC
using ..LogDensityProblems
using ..HypercubeTransform
using ..TypedTables
using ..TransformVariables
using ..Random
end

Pigeons.initialization(tpost::Comrade.TransformedPosterior, rng::Random.AbstractRNG, ::Int) = prior_sample(rng, tpost)

struct PriorRef{P,T}
model::P
transform::T
end

function (p::PriorRef{P,<:TransformVariables.AbstractTransform})(x) where {P}
y, lj = TransformVariables.transform_and_logjac(p.transform, x)
logdensityof(p.model, y) + lj
end

function (p::PriorRef{P,<:HypercubeTransform.AbstractHypercubeTransform})(x) where {P}
for xx in x
(xx > 1 || xx < 0) && return convert(eltype(x), -Inf)
end
return zero(eltype(x))
end

Pigeons.default_explorer(::Comrade.TransformedPosterior{P,<:HypercubeTransform.AbstractHypercubeTransform}) where {P} =
SliceSampler()

Pigeons.default_explorer(::Comrade.TransformedPosterior{P,<:TransformVariables.AbstractTransform}) where {P} =
Pigeons.AutoMALA(;default_autodiff_backend = :Zygote)

function Pigeons.default_reference(tpost::Comrade.TransformedPosterior)
t = tpost.transform
p = tpost.lpost.prior
return PriorRef(p, t)
end

function Pigeons.sample_iid!(target::Comrade.TransformedPosterior, replica, shared)
replica.state = Pigeons.initialization(target, replica.rng, replica.replica_index)
end

function Pigeons.sample_iid!(target::PriorRef{P, <:TransformVariables.AbstractTransform}, replica, shared) where {P}
replica.state .= Comrade.inverse(target.transform, rand(replica.rng, target.model))
end

function Pigeons.sample_iid!(::PriorRef{P,<:HypercubeTransform.AbstractHypercubeTransform}, replica, shared) where {P}
rand!(replica.rng, replica.state)
end


function Pigeons.sample_array(tpost::Comrade.TransformedPosterior, pt::Pigeons.PT)
samples = sample_array(pt)
arr = reshape(samples, size(samples, 1), size(samples, 2))
return Table(map(x->Comrade.transform(tpost, x), eachrow(arr)))
end


LogDensityProblems.dimension(t::PriorRef) = Comrade.dimension(t.transform)
LogDensityProblems.logdensity(t::PriorRef, x) = t(x)
LogDensityProblems.capabilities(::Type{<:PriorRef}) = LogDensityProblems.LogDensityOrder{0}()

end
2 changes: 1 addition & 1 deletion src/observations/observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ Base.@kwdef struct EHTVisibilityDatum{S<:Number} <: AbstractVisibilityDatum{S}
"""
Complex Vis. measurement (Jy)
"""
measurement::S
measurement::Complex{S}
"""
error of the complex vis (Jy)
"""
Expand Down

0 comments on commit 175e2c9

Please sign in to comment.