Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cvi projection marginal rule (with proposal distribution) #430

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 60 additions & 34 deletions ext/ReactiveMPProjectionExt/rules/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,61 @@ end
return FactorizedJoint((q,))
end

"""
    create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, m_in)

Build the log-density used for projecting the `i`-th input marginal.

When the projection target form already matches the incoming message form
(`forms_match == true`), the density is the non-conjugate log-likelihood term
alone; otherwise the log-density of the incoming message `m_in` is added so
that the projection accounts for it explicitly.
"""
function create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, m_in)
    base_density = z -> logp_nc_drop_index(z, i, pre_samples)
    # Only add the message correction term when the forms do not match.
    return forms_match ? base_density : (z -> base_density(z) + logpdf(m_in, z))
end

"""
    optimize_parameters(i, pre_samples, m_ins, logp_nc_drop_index, method)

Project the `i`-th input marginal onto an exponential-family form via
`ExponentialFamilyProjection.project_to`.

Returns the projected distribution. When the projection target type and
dimensions match the incoming message `m_ins[i]`, the message is also passed
to `project_to` as the initial point of the optimization.
"""
function optimize_parameters(i, pre_samples, m_ins, logp_nc_drop_index, method)
    m_in = m_ins[i]
    prj = create_project_to_ins(method, m_in, i)

    # The target form matches the incoming message when both the
    # exponential-family type tag and the dimensionality agree.
    target_type = ExponentialFamilyProjection.get_projected_to_type(prj)
    target_dims = ExponentialFamilyProjection.get_projected_to_dims(prj)
    forms_match = target_type === ExponentialFamily.exponential_family_typetag(m_in) && target_dims == size(m_in)

    density = create_density_function(forms_match, i, pre_samples, logp_nc_drop_index, m_in)
    logp = convert(
        promote_variate_type(variate_form(typeof(m_in)), BayesBase.AbstractContinuousGenericLogPdf),
        UnspecifiedDomain(),
        density
    )

    if forms_match
        # Matching forms: seed the projection with the incoming message.
        return project_to(prj, logp, m_in)
    else
        return project_to(prj, logp)
    end
end

# Draw `n_samples` joint samples directly from the incoming messages when no
# proposal distribution is available (full-sampling strategy). Returns a
# zipped iterator of per-input sample tuples.
# Fix: removed the duplicated `return return` (the inner `return` was dead).
function generate_samples(rng, ::Nothing, m_ins, n_samples, ::Val{FullSampling})
    return zip(map(m_in -> ReactiveMP.cvilinearize(rand(rng, m_in, n_samples)), m_ins)...)
end

# Mean-based strategy without a proposal distribution: use a single sample per
# input, namely the mean of each incoming message.
# Fix: removed the duplicated `return return` (the inner `return` was dead).
function generate_samples(::Any, ::Nothing, m_ins, n_samples, ::Val{MeanBased})
    return zip(map(m_in -> [mean(m_in)], m_ins)...)
end

# Full-sampling strategy with an available proposal distribution: draw
# `n_samples` joint samples from the factorized proposal components instead of
# the incoming messages.
# Fix: removed the duplicated `return return` (the inner `return` was dead).
function generate_samples(rng, proposal_distribution::FactorizedJoint, ::Any, n_samples, ::Val{FullSampling})
    return zip(map(q_in -> ReactiveMP.cvilinearize(rand(rng, q_in, n_samples)), proposal_distribution.multipliers)...)
end

# Mean-based strategy with an available proposal distribution: use a single
# sample per input, namely the mean of each proposal component.
# Fix: removed the duplicated `return return` (the inner `return` was dead).
function generate_samples(::Any, proposal_distribution::FactorizedJoint, m_ins, n_samples, ::Val{MeanBased})
    return zip(map(q_in -> [mean(q_in)], proposal_distribution.multipliers)...)
end

@marginalrule DeltaFn(:ins) (m_out::Any, m_ins::ManyOf{N, Any}, meta::DeltaMeta{M}) where {N, M <: CVIProjection} = begin
method = ReactiveMP.getmethod(meta)
rng = method.rng
pre_samples = zip(map(m_in_k -> ReactiveMP.cvilinearize(rand(rng, m_in_k, method.marginalsamples)), m_ins)...)

proposal_distribution_container = method.proposal_distribution

pre_samples = generate_samples(
rng,
proposal_distribution_container.distribution,
m_ins,
method.marginalsamples,
Val(method.sampling_strategy) # Wrap in Val{}
)

logp_nc_drop_index = let g = getnodefn(meta, Val(:out)), pre_samples = pre_samples
(z, i, pre_samples) -> begin
samples = map(ttuple -> ReactiveMP.TupleTools.insertat(ttuple, i, (z,)), pre_samples)
Expand All @@ -66,37 +116,13 @@ end
end
end

optimize_natural_parameters = let m_ins = m_ins, logp_nc_drop_index = logp_nc_drop_index
(i, pre_samples) -> begin
m_in = m_ins[i]
default_type = ExponentialFamily.exponential_family_typetag(m_in)

prj = create_project_to_ins(method, m_in, i)

typeform = ExponentialFamilyProjection.get_projected_to_type(prj)
dims = ExponentialFamilyProjection.get_projected_to_dims(prj)
forms_match = typeform === default_type && dims == size(m_in)

# Create log probability function
df = if forms_match
let i = i, pre_samples = pre_samples, logp_nc_drop_index = logp_nc_drop_index
(z) -> logp_nc_drop_index(z, i, pre_samples)
end
else
let i = i, pre_samples = pre_samples, logp_nc_drop_index = logp_nc_drop_index, m_in = m_in
(z) -> logp_nc_drop_index(z, i, pre_samples) + logpdf(m_in, z)
end
end

logp = convert(
promote_variate_type(variate_form(typeof(m_in)), BayesBase.AbstractContinuousGenericLogPdf),
UnspecifiedDomain(),
df
)

return forms_match ? project_to(prj, logp, m_in) : project_to(prj, logp)
end
end
result = FactorizedJoint(
ntuple(
i -> optimize_parameters(i, pre_samples, m_ins, logp_nc_drop_index, method),
length(m_ins)
)
)

return FactorizedJoint(ntuple(i -> optimize_natural_parameters(i, pre_samples), length(m_ins)))
proposal_distribution_container.distribution = result
return result
end
17 changes: 16 additions & 1 deletion src/approximations/cvi_projection.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
export CVIProjection

export CVISamplingStrategy, FullSampling, MeanBased

# Sampling strategy used by the CVI projection marginal rule to approximate
# the logpdf: `FullSampling` draws multiple samples from the proposal or the
# incoming messages; `MeanBased` uses a single mean-point sample per input.
@enum CVISamplingStrategy begin
FullSampling
MeanBased
end

# Mutable wrapper around the proposal distribution so the marginal rule can
# update it in place between invocations (the rule assigns the latest
# factorized result back into `distribution`).
mutable struct ProposalDistributionContainer{PD}
distribution::PD
end

"""
CVIProjection(; parameters...)

Expand All @@ -16,17 +27,21 @@ This structure is a subtype of `AbstractApproximationMethod` and is used to conf
- `outsamples::S`: The number of samples used for approximating output message distributions. Default is `100`.
- `out_prjparams::OF`: the form parameter used to select the distribution form onto which the out edge is projected; if not provided, it is inferred from the marginal form
- `in_prjparams::IFS`: a namedtuple-like object used to select the form onto which each input edge is projected; if not provided, it is inferred from the incoming message on that edge
- `proposal_distribution::PD`: the proposal distribution used for generating samples; if not provided, it is inferred from the incoming message on the edge
- `sampling_strategy::SS`: the sampling strategy for the logpdf approximation

!!! note
The `CVIProjection` method is an experimental enhancement of the now-deprecated `CVI`, offering better stability and improved accuracy.
Note that the parameters of this structure, as well as their defaults, are subject to change during the experimentation phase.
"""
Base.@kwdef struct CVIProjection{R, S, OF, IFS} <: AbstractApproximationMethod
Base.@kwdef struct CVIProjection{R, S, OF, IFS, PD, SS} <: AbstractApproximationMethod
rng::R = Random.MersenneTwister(42) # RNG used for all sampling in the rules; fixed seed by default for reproducibility
marginalsamples::S = 10 # number of samples used to approximate input marginals
outsamples::S = 100 # number of samples used to approximate the output message
out_prjparams::OF = nothing # projection form for the out edge; `nothing` means infer from the marginal form
in_prjparams::IFS = nothing # per-input projection forms; `nothing` means infer from the incoming messages
proposal_distribution::PD = ProposalDistributionContainer{Any}(nothing) # mutable container, updated in place by the marginal rule
sampling_strategy::SS = FullSampling # `FullSampling` or `MeanBased` (see `CVISamplingStrategy`)
end

function get_kth_in_form(::CVIProjection{R, S, OF, Nothing}, ::Int) where {R, S, OF}
Expand Down
87 changes: 87 additions & 0 deletions test/ext/ReactiveMPProjectionExt/rules/marginals_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,90 @@ end
@test isa(result[2], MvNormalMeanScalePrecision)
end
end

@testitem "CVIProjection proposal distribution convergence tests" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase, LinearAlgebra
using Random, Distributions

@testset "Posterior approximation quality" begin
# Fixed seed so the stochastic projection is reproducible across runs.
rng = MersenneTwister(123)
method = CVIProjection(rng = rng, marginalsamples = 2000)
meta = DeltaMeta(method = method, inverse = nothing)

f(x, y) = x * y

# Define distributions
m_out = NormalMeanVariance(2.0, 0.1)
m_in1 = NormalMeanVariance(0.0, 2.0)
m_in2 = NormalMeanVariance(0.0, 2.0)

# Function to compute unnormalized log posterior for a sample
function log_posterior(x, y)
return logpdf(m_in1, x) + logpdf(m_in2, y) + logpdf(m_out, f(x, y))
end

# Estimate KL divergence using samples
# Monte-Carlo estimate of KL(q || p) up to the constant log-normalizer of p;
# the constant cancels when comparing estimates across iterations.
function estimate_kl_divergence(q_result)
n_samples = 10000
samples_q = [(rand(rng, q_result[1]), rand(rng, q_result[2])) for _ in 1:n_samples]

# Compute E_q[log q(x,y) - log p(x,y)]
log_q_terms = [logpdf(q_result[1], x) + logpdf(q_result[2], y) for (x, y) in samples_q]
log_p_terms = [log_posterior(x, y) for (x, y) in samples_q]

return mean(log_q_terms .- log_p_terms)
end

# Run multiple iterations and collect KL divergences
# Each call reuses the proposal stored in the method's mutable container,
# so later iterations should start from a better proposal.
n_iterations = 10
kl_divergences = Vector{Float64}(undef, n_iterations)

for i in 1:n_iterations
result = @call_marginalrule DeltaFn{f}(:ins) (m_out = m_out, m_ins = ManyOf(m_in1, m_in2), meta = meta)
kl_divergences[i] = estimate_kl_divergence(result)
end

# The approximation should improve (KL decrease) from first to last iteration.
@test kl_divergences[1] > kl_divergences[end]
end
end

@testitem "Basic checks for marginal rule with mean based approximation" begin
using ExponentialFamily, ExponentialFamilyProjection, BayesBase
import ReactiveMP: @test_rules, @test_marginalrules

@testset "f(x, y) -> [x, y], x~Normal, y~Normal, out~MvNormal (marginalization)" begin
f(x, y) = [x, y]
# MeanBased strategy: the rule uses a single mean-point sample per input.
meta = DeltaMeta(method = CVIProjection(sampling_strategy = MeanBased), inverse = nothing)
# Loose tolerance (atol = 1e-1) since the projection is approximate.
@test_marginalrules [check_type_promotion = false, atol = 1e-1] DeltaFn{f}(:ins) [(
input = (m_out = MvGaussianMeanCovariance(ones(2), [2 0; 0 2]), m_ins = ManyOf(NormalMeanVariance(0, 1), NormalMeanVariance(1, 2)), meta = meta),
output = FactorizedJoint((NormalMeanVariance(1 / 3, 2 / 3), NormalMeanVariance(1.0, 1.0)))
)]
end
end

@testitem "DeltaNode - CVI sampling strategy performance comparison" begin
using Test
using BenchmarkTools
using BayesBase, ExponentialFamily, ExponentialFamilyProjection

f(x, y) = [x, y]

# Benchmark a single marginal-rule invocation under the given sampling
# strategy and return the elapsed time in seconds.
function run_marginal_test(strategy)
meta = DeltaMeta(method = CVIProjection(sampling_strategy = strategy))
m_out = MvGaussianMeanCovariance(ones(2), [2 0; 0 2])
m_in1 = NormalMeanVariance(0.0, 2.0)
m_in2 = NormalMeanVariance(0.0, 2.0)
# samples = 2 keeps the benchmark cheap; timings are therefore noisy —
# NOTE(review): this comparison may be flaky on loaded CI machines.
return @belapsed begin
@call_marginalrule DeltaFn{f}(:ins) (m_out = $m_out, m_ins = ManyOf($m_in1, $m_in2), meta = $meta)
end samples = 2
end

# Run benchmarks
full_time = run_marginal_test(FullSampling)
mean_time = run_marginal_test(MeanBased)

# Mean-based sampling uses a single point per input, so it should be faster.
@test mean_time < full_time

# Optional: Print the actual times for verification
@info "Sampling strategy performance" full_time mean_time ratio = (full_time / mean_time)
end
Loading