diff --git a/Project.toml b/Project.toml
index 6b8abb913..928327b08 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.23.7"
+version = "0.23.8"

 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/docs/Project.toml b/docs/Project.toml
index b6cce8b37..2e7de3456 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -17,3 +18,4 @@ LogDensityProblems = "2"
 MLUtils = "0.3, 0.4"
 Setfield = "0.7.1, 0.8, 1"
 StableRNGs = "1"
+MCMCChains = "5"
diff --git a/docs/make.jl b/docs/make.jl
index 9e0694019..c71187ee1 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -10,9 +10,7 @@ using DynamicPPL: AbstractPPL
 using Distributions

 # Doctest setup
-DocMeta.setdocmeta!(
-    DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
-)
+DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)

 makedocs(;
     sitename="DynamicPPL",
diff --git a/src/model.jl b/src/model.jl
index 1613efab6..c6858041a 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1059,6 +1059,42 @@ function logjoint(model::Model, varinfo::AbstractVarInfo)
     return getlogp(last(evaluate!!(model, varinfo, DefaultContext())))
 end

+"""
+    logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
+
+Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
+
+# Examples
+
+```jldoctest
+julia> using MCMCChains, Distributions
+
+julia> @model function demo_model(x)
+           s ~ InverseGamma(2, 3)
+           m ~ Normal(0, sqrt(s))
+           for i in eachindex(x)
+               x[i] ~ Normal(m, sqrt(s))
+           end
+       end;
+
+julia> # construct a chain of samples using MCMCChains
+       chain = Chains(rand(10, 2, 3), [:s, :m]);
+
+julia> logjoint(demo_model([1., 2.]), chain);
+```
+"""
+function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
+    var_info = VarInfo(model) # extract variable info from the model
+    map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
+        argvals_dict = OrderedDict(
+            vn_parent =>
+                values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
+            vn_parent in keys(var_info)
+        )
+        DynamicPPL.logjoint(model, argvals_dict)
+    end
+end
+
 """
     logprior(model::Model, varinfo::AbstractVarInfo)

@@ -1070,6 +1106,42 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
     return getlogp(last(evaluate!!(model, varinfo, PriorContext())))
 end

+"""
+    logprior(model::Model, chain::AbstractMCMC.AbstractChains)
+
+Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
+
+# Examples
+
+```jldoctest
+julia> using MCMCChains, Distributions
+
+julia> @model function demo_model(x)
+           s ~ InverseGamma(2, 3)
+           m ~ Normal(0, sqrt(s))
+           for i in eachindex(x)
+               x[i] ~ Normal(m, sqrt(s))
+           end
+       end;
+
+julia> # construct a chain of samples using MCMCChains
+       chain = Chains(rand(10, 2, 3), [:s, :m]);
+
+julia> logprior(demo_model([1., 2.]), chain);
+```
+"""
+function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
+    var_info = VarInfo(model) # extract variable info from the model
+    map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
+        argvals_dict = OrderedDict(
+            vn_parent =>
+                values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
+            vn_parent in keys(var_info)
+        )
+        DynamicPPL.logprior(model, argvals_dict)
+    end
+end
+
 """
     loglikelihood(model::Model, varinfo::AbstractVarInfo)

@@ -1081,6 +1153,42 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
     return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext())))
 end

+"""
+    loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
+
+Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
+
+# Examples
+
+```jldoctest
+julia> using MCMCChains, Distributions
+
+julia> @model function demo_model(x)
+           s ~ InverseGamma(2, 3)
+           m ~ Normal(0, sqrt(s))
+           for i in eachindex(x)
+               x[i] ~ Normal(m, sqrt(s))
+           end
+       end;
+
+julia> # construct a chain of samples using MCMCChains
+       chain = Chains(rand(10, 2, 3), [:s, :m]);
+
+julia> loglikelihood(demo_model([1., 2.]), chain);
+```
+"""
+function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
+    var_info = VarInfo(model) # extract variable info from the model
+    map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
+        argvals_dict = OrderedDict(
+            vn_parent =>
+                values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
+            vn_parent in keys(var_info)
+        )
+        loglikelihood(model, argvals_dict)
+    end
+end
+
 """
     generated_quantities(model::Model, chain::AbstractChains)

diff --git a/test/model.jl b/test/model.jl
index b9ad9fc7a..481aa4e38 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -27,7 +27,7 @@ end

 @testset "model.jl" begin
     @testset "convenience functions" begin
-        model = gdemo_default
+        model = gdemo_default # defined in test/test_util.jl

         # sample from model and extract variables
         vi = VarInfo(model)
@@ -49,6 +49,77 @@ end
         ljoint = logjoint(model, vi)
         @test ljoint ≈ lprior + llikelihood
         @test ljoint ≈ lp
+
+        #### logprior, logjoint, loglikelihood for MCMC chains ####
+        for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
+            var_info = VarInfo(model)
+            vns = DynamicPPL.TestUtils.varnames(model)
+            syms = unique(DynamicPPL.getsym.(vns))
+
+            # generate a chain of sample parameter values.
+            N = 200
+            vals_OrderedDict = mapreduce(hcat, 1:N) do _
+                rand(OrderedDict, model)
+            end
+            vals_mat = mapreduce(hcat, 1:N) do i
+                [vals_OrderedDict[i][vn] for vn in vns]
+            end
+            i = 1
+            for col in eachcol(vals_mat)
+                col_flattened = []
+                [push!(col_flattened, x...) for x in col]
+                if i == 1
+                    chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
+                else
+                    chain_mat = vcat(
+                        chain_mat, reshape(col_flattened, 1, length(col_flattened))
+                    )
+                end
+                i += 1
+            end
+            chain_mat = convert(Matrix{Float64}, chain_mat)
+
+            # devise parameter names for chain
+            sample_values_vec = collect(values(vals_OrderedDict[1]))
+            symbol_names = []
+            chain_sym_map = Dict()
+            for k in 1:length(keys(var_info))
+                vn_parent = keys(var_info)[k]
+                sym = DynamicPPL.getsym(vn_parent)
+                vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
+                for vn_child in vn_children
+                    chain_sym_map[Symbol(vn_child)] = sym
+                    symbol_names = [symbol_names; Symbol(vn_child)]
+                end
+            end
+            chain = Chains(chain_mat, symbol_names)
+
+            # calculate the log prior, log likelihood, and log joint for the whole chain using the new methods
+            logpriors = logprior(model, chain)
+            loglikelihoods = loglikelihood(model, chain)
+            logjoints = logjoint(model, chain)
+            # compare them with the true values
+            for i in 1:N
+                samples_dict = Dict()
+                for chain_key in keys(chain)
+                    value = chain[i, chain_key, 1]
+                    key = chain_sym_map[chain_key]
+                    existing_value = get(samples_dict, key, Float64[])
+                    push!(existing_value, value)
+                    samples_dict[key] = existing_value
+                end
+                samples = (; samples_dict...)
+                samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl
+                @test logpriors[i] ≈
+                    DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
+                @test loglikelihoods[i] ≈ DynamicPPL.TestUtils.loglikelihood_true(
+                    model, samples[:s], samples[:m]
+                )
+                @test logjoints[i] ≈
+                    DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
+            end
+            println("\n model $(model) passed !!! \n")
+        end
     end

     @testset "rng" begin
diff --git a/test/test_util.jl b/test/test_util.jl
index f3e54c437..892f7221a 100644
--- a/test/test_util.jl
+++ b/test/test_util.jl
@@ -82,3 +82,18 @@ short_varinfo_name(::TypedVarInfo) = "TypedVarInfo"
 short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
 short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
 short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
+
+# convenience functions for testing model.jl
+# function to modify the representation of values based on their length
+function modify_value_representation(nt::NamedTuple)
+    modified_nt = NamedTuple()
+    for (key, value) in zip(keys(nt), values(nt))
+        if length(value) == 1 # Scalar value
+            modified_value = value[1]
+        else # Non-scalar value
+            modified_value = value
+        end
+        modified_nt = merge(modified_nt, (key => modified_value,))
+    end
+    return modified_nt
+end
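Usage sketch for the new chain-based methods (not part of the diff above): it reuses the `demo_model` from the new docstrings and fills a `Chains` object with random placeholder values where real sampler output would normally go, so the exact numbers are arbitrary.

```julia
using DynamicPPL, Distributions, MCMCChains

@model function demo_model(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = demo_model([1.0, 2.0])

# 10 iterations × 2 parameters (:s, :m) × 3 chains; random placeholder values
# stand in for the output of an actual sampler.
chain = Chains(rand(10, 2, 3), [:s, :m])

# Each call maps over (iteration, chain) pairs and returns a 10×3 matrix of log densities.
lpriors = logprior(model, chain)
llikes = loglikelihood(model, chain)
ljoints = logjoint(model, chain)

# Sanity check: the log joint should equal log prior + log likelihood elementwise.
all(ljoints .≈ lpriors .+ llikes)
```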