From f6f02aca7198a9e58cbea9beb56a242b198f12a1 Mon Sep 17 00:00:00 2001 From: YongchaoHuang <34540771+YongchaoHuang@users.noreply.github.com> Date: Mon, 24 Jul 2023 14:47:48 +0100 Subject: [PATCH] log probability interface for post-inference analysis (#438) * extended methods for `logprior`, `loglikelihood`, `logposterior` for chains. * accept Github Actions. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * typed `AbstractChains`; removed Array inputs. * re-formatting. * removed comments to pass formatting test. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * 1. removed the import statements in `lop.jl`; 2. removed the DynamicPPL. namespace declarations of the functions; 3. used Distributions.loglikelihood (instead of `StatsBase.loglikelihood`); 4. moved the tests to test/logp.jl and included them in test/runtests.jl. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: David Widmann * Update src/logp.jl Co-authored-by: David Widmann * Update src/logp.jl Co-authored-by: David Widmann * modified src/logp.jl; added test/logp.jl * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Modified Docstrings; Changed names; Modified methods following Tor's suggestions. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renamed `chain_logprior`,`chain_loglikelihood',`chain_logposterior' to `logprior`,`loglikelihood',`logposterior' . * added `include("logdensityfunction.jl")` to `DynamicPPL.jl` * formatted `test/logp.jl`. * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatted `scr/logp.jl`. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatted `scr/logp.jl`. * Removed comments. * Update src/logp.jl Co-authored-by: David Widmann * Update src/logp.jl Co-authored-by: David Widmann * Update src/logp.jl Co-authored-by: David Widmann * Update src/logp.jl Co-authored-by: David Widmann * Update src/logp.jl Co-authored-by: David Widmann * Update src/logp.jl Co-authored-by: David Widmann * Update test/logp.jl Co-authored-by: David Widmann * removed redundant methods (NamedTuples and Array inputs). * added REPL examples to docstrings. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added `start_idx` into `src/logp.jl`; rewrite `test/logp.jl` using `map-do`; added `MCMCChains` & `StableRNGs` to `DynamicPPL.jl`. * added `start_idx` to the 3 methods. * Reduced chainn size in the docstrings example. * applied formatting. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * upated signatures in docstrings. * applied formatting. * formatted again. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix doctests setup * Update docs/make.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/runtests.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * applied formatting. * Formatting. * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/logp.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix doc tests again. * Fixed formatting. * Merged `logp.jl` into `model.jl` * CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat) (#448) * CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat) * Update test/turing/Project.toml Co-authored-by: CompatHelper Julia Co-authored-by: Tor Erlend Fjelde * CompatHelper: bump compat for Turing to 0.23 for package turing, (keep existing compat) (#439) * CompatHelper: bump compat for Turing to 0.23 for package turing, (keep existing compat) * Update test/turing/Project.toml Co-authored-by: CompatHelper Julia Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Fixed obsolete `TArray` reference. * Fixed incorrect code. * More bugfixes in logp tests. * Avoid calling Turing sampler. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Replace SampleFromPrior with synthetic chain. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Minor bugfix. * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update Project.toml * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update test/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Update src/model.jl Co-authored-by: David Widmann * Added `logprior_true(model,NamedTuple)' and `loglikelihood_true(model, NamedTuple)' methods; revised test/model.jl accordingly (removed `MCMCChains.get_param()' ). * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fixed missing prefix and imports. * Move tests into convenience functions. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Removed constraints on floating number precision. * Fix type constraint again. * Apply suggestions from code review Co-authored-by: David Widmann * 1. removed `StableRNGs`; 2. replaced `map(1:N) do i` in test/model.jl by `for i in 1:N`. * Bugfix. * Import TestUtils -- it is not exported by DPPL. * Specialise on model type. * Improve test. * Update src/test_utils.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde Co-authored-by: David Widmann * Apply suggestions from code review Co-authored-by: David Widmann * midified the way chain value was extracted in all 3 methods. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: David Widmann * Update src/model.jl * Update src/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/utils.jl * rewrote the tests (mainly the way extracting parameter values from chain). * removed BangBang from doctest setup; fixed imcomplete end in test/model.jl. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed a naming bug (argvals_mat_dict) in src/model.jl. * fixed a typo - missing `var_info`. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: David Widmann * Apply suggestions from code review Co-authored-by: David Widmann * Explicitly added `using Distributions` in doctests; Accepted suggestion in test/model.jl, tests passed locally. * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * rm unnecessary deps * replace contains with subsumes. * rm redundant deps in docs build script. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix format. * Replaced `subsumes` by `contains`. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * replaced 'contains' by a new, temporary method 'subsumes_sym', just for testing purpose. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * modified `/test/model.jl`: 1. build a map between model parameter symbol (`s`,`m`) and chain parameter names (which is obtained via `varname_leaves`.) 2. use this naming map to collect sample values from chain and drop into `log_true` for validation. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fixed a mistake in `modify_value_representation`. * fixed `gdemo_default` * assigned `model=gdemo_default`. * src/model.jl: added `DynamicPPL.` to `logprior` and `logjoint`. * commented out `gdemo_d()` as a trial test. * used `Symbol(vn_child)` as keys in `chain_sym_map`. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * explicitly loaded `varname_leaves` and `values_from_chain`. * added `print` statements for temporary diagnosis purpose. * added 'print` statements for temporary diagnostics purpose. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * diagnostics again. * Removed some `print` statements as it's working. * Update test/model.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Update test/model.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Update test/model.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Update test/model.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * 1. moved helper functions to `test_util.jl`; 2. re-wrote the way `chain_mat` can be generated. * Update test/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatting. * formatting. * Update utils.jl * Update test_util.jl * Update Project.toml * replaced 'varname_leaves' by 'DynamicPPL.varname_leaves'. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann Co-authored-by: Hong Ge Co-authored-by: CompatHelper Julia Co-authored-by: Tor Erlend Fjelde Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Jose Storopoli --- Project.toml | 2 +- docs/Project.toml | 2 + docs/make.jl | 4 +- src/model.jl | 108 ++++++++++++++++++++++++++++++++++++++++++++++ test/model.jl | 73 ++++++++++++++++++++++++++++++- test/test_util.jl | 15 +++++++ 6 files changed, 199 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 6b8abb913..928327b08 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.7" +version = "0.23.8" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/docs/Project.toml b/docs/Project.toml index b6cce8b37..2e7de3456 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -17,3 +18,4 @@ LogDensityProblems = "2" MLUtils = "0.3, 0.4" Setfield = "0.7.1, 0.8, 1" StableRNGs = "1" +MCMCChains = "5" diff --git a/docs/make.jl b/docs/make.jl index 9e0694019..c71187ee1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -10,9 +10,7 @@ using DynamicPPL: AbstractPPL using Distributions # Doctest setup -DocMeta.setdocmeta!( - DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true -) +DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true) makedocs(; sitename="DynamicPPL", diff --git a/src/model.jl b/src/model.jl index 1613efab6..c6858041a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1059,6 +1059,42 @@ function logjoint(model::Model, varinfo::AbstractVarInfo) return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) end +""" + logjoint(model::Model, chain::AbstractMCMC.AbstractChains) + +Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. + +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # construct a chain of samples using MCMCChains + chain = Chains(rand(10, 2, 3), [:s, :m]); + +julia> logjoint(demo_model([1., 2.]), chain); +``` +""" +function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) + var_info = VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = OrderedDict( + vn_parent => + values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for + vn_parent in keys(var_info) + ) + DynamicPPL.logjoint(model, argvals_dict) + end +end + """ logprior(model::Model, varinfo::AbstractVarInfo) @@ -1070,6 +1106,42 @@ function logprior(model::Model, varinfo::AbstractVarInfo) return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) end +""" + logprior(model::Model, chain::AbstractMCMC.AbstractChains) + +Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. + +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # construct a chain of samples using MCMCChains + chain = Chains(rand(10, 2, 3), [:s, :m]); + +julia> logprior(demo_model([1., 2.]), chain); +``` +""" +function logprior(model::Model, chain::AbstractMCMC.AbstractChains) + var_info = VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = OrderedDict( + vn_parent => + values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for + vn_parent in keys(var_info) + ) + DynamicPPL.logprior(model, argvals_dict) + end +end + """ loglikelihood(model::Model, varinfo::AbstractVarInfo) @@ -1081,6 +1153,42 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) end +""" + loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) + +Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. + +# Examples + +```jldoctest +julia> using MCMCChains, Distributions + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + end; + +julia> # construct a chain of samples using MCMCChains + chain = Chains(rand(10, 2, 3), [:s, :m]); + +julia> loglikelihood(demo_model([1., 2.]), chain); +``` +""" +function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) + var_info = VarInfo(model) # extract variables info from the model + map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) + argvals_dict = OrderedDict( + vn_parent => + values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for + vn_parent in keys(var_info) + ) + loglikelihood(model, argvals_dict) + end +end + """ generated_quantities(model::Model, chain::AbstractChains) diff --git a/test/model.jl b/test/model.jl index b9ad9fc7a..481aa4e38 100644 --- a/test/model.jl +++ b/test/model.jl @@ -27,7 +27,7 @@ end @testset "model.jl" begin @testset "convenience functions" begin - model = gdemo_default + model = gdemo_default # defined in test/test_util.jl # sample from model and extract variables vi = VarInfo(model) @@ -49,6 +49,77 @@ end ljoint = logjoint(model, vi) @test ljoint ≈ lprior + llikelihood @test ljoint ≈ lp + + #### logprior, logjoint, loglikelihood for MCMC chains #### + for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12 + var_info = VarInfo(model) + vns = DynamicPPL.TestUtils.varnames(model) + syms = unique(DynamicPPL.getsym.(vns)) + + # generate a chain of sample parameter values. + N = 200 + vals_OrderedDict = mapreduce(hcat, 1:N) do _ + rand(OrderedDict, model) + end + vals_mat = mapreduce(hcat, 1:N) do i + [vals_OrderedDict[i][vn] for vn in vns] + end + i = 1 + for col in eachcol(vals_mat) + col_flattened = [] + [push!(col_flattened, x...) for x in col] + if i == 1 + chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened))) + else + chain_mat = vcat( + chain_mat, reshape(col_flattened, 1, length(col_flattened)) + ) + end + i += 1 + end + chain_mat = convert(Matrix{Float64}, chain_mat) + + # devise parameter names for chain + sample_values_vec = collect(values(vals_OrderedDict[1])) + symbol_names = [] + chain_sym_map = Dict() + for k in 1:length(keys(var_info)) + vn_parent = keys(var_info)[k] + sym = DynamicPPL.getsym(vn_parent) + vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl + for vn_child in vn_children + chain_sym_map[Symbol(vn_child)] = sym + symbol_names = [symbol_names; Symbol(vn_child)] + end + end + chain = Chains(chain_mat, symbol_names) + + # calculate the pointwise loglikelihoods for the whole chain using the newly written functions + logpriors = logprior(model, chain) + loglikelihoods = loglikelihood(model, chain) + logjoints = logjoint(model, chain) + # compare them with true values + for i in 1:N + samples_dict = Dict() + for chain_key in keys(chain) + value = chain[i, chain_key, 1] + key = chain_sym_map[chain_key] + existing_value = get(samples_dict, key, Float64[]) + push!(existing_value, value) + samples_dict[key] = existing_value + end + samples = (; samples_dict...) + samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl + @test logpriors[i] ≈ + DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m]) + @test loglikelihoods[i] ≈ DynamicPPL.TestUtils.loglikelihood_true( + model, samples[:s], samples[:m] + ) + @test logjoints[i] ≈ + DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m]) + end + println("\n model $(model) passed !!! \n") + end end @testset "rng" begin diff --git a/test/test_util.jl b/test/test_util.jl index f3e54c437..892f7221a 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -82,3 +82,18 @@ short_varinfo_name(::TypedVarInfo) = "TypedVarInfo" short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" + +# convenient functions for testing model.jl +# function to modify the representation of values based on their length +function modify_value_representation(nt::NamedTuple) + modified_nt = NamedTuple() + for (key, value) in zip(keys(nt), values(nt)) + if length(value) == 1 # Scalar value + modified_value = value[1] + else # Non-scalar value + modified_value = value + end + modified_nt = merge(modified_nt, (key => modified_value,)) + end + return modified_nt +end