Skip to content

Commit

Permalink
log probability interface for post-inference analysis (#438)
Browse files Browse the repository at this point in the history
* extended methods for `logprior`, `loglikelihood`, `logposterior` for chains.

* accept Github Actions.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* typed `AbstractChains`; 
removed Array inputs.

* re-formatting.

* removed comments to pass formatting test.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* 1. removed the import statements in `lop.jl`;
2. removed the DynamicPPL. namespace declarations of the functions; 
3. used Distributions.loglikelihood (instead of `StatsBase.loglikelihood`); 
4. moved the tests to test/logp.jl and included them in test/runtests.jl.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* modified src/logp.jl; added test/logp.jl

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Modified Docstrings;
Changed names;
Modified methods following Tor's suggestions.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* renamed `chain_logprior`,`chain_loglikelihood',`chain_logposterior' to `logprior`,`loglikelihood',`logposterior' .

* added `include("logdensityfunction.jl")` to `DynamicPPL.jl`

* formatted `test/logp.jl`.

* Update test/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* formatted `scr/logp.jl`.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* formatted `scr/logp.jl`.

* Removed comments.

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/logp.jl

Co-authored-by: David Widmann <[email protected]>

* Update test/logp.jl

Co-authored-by: David Widmann <[email protected]>

* removed redundant methods (NamedTuples and Array inputs).

* added REPL examples to docstrings.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added `start_idx` into `src/logp.jl`; rewrite `test/logp.jl` using `map-do`; added `MCMCChains` & `StableRNGs` to `DynamicPPL.jl`.

* added `start_idx` to the 3 methods.

* Reduced chainn size in the docstrings example.

* applied formatting.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* upated signatures in docstrings.

* applied formatting.

* formatted again.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix doctests setup

* Update docs/make.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/runtests.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* applied formatting.

* Formatting.

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logp.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix doc tests again.

* Fixed formatting.

* Merged `logp.jl` into `model.jl`

* CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat) (#448)

* CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat)

* Update test/turing/Project.toml

Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>

* CompatHelper: bump compat for Turing to 0.23 for package turing, (keep existing compat) (#439)

* CompatHelper: bump compat for Turing to 0.23 for package turing, (keep existing compat)

* Update test/turing/Project.toml

Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: Hong Ge <[email protected]>

* Fixed obsolete `TArray` reference.

* Fixed incorrect code.

* More bugfixes in logp tests.

* Avoid calling Turing sampler.

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Replace SampleFromPrior with synthetic chain.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Minor bugfix.

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update Project.toml

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update test/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

Co-authored-by: David Widmann <[email protected]>

* Added `logprior_true(model,NamedTuple)' and
`loglikelihood_true(model, NamedTuple)' methods;
revised test/model.jl accordingly (removed `MCMCChains.get_param()' ).

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fixed missing prefix and imports.

* Move tests into convenience functions.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Removed constraints on floating number precision.

* Fix type constraint again.

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* 1. removed `StableRNGs`;
2. replaced `map(1:N) do i` in test/model.jl by `for i in 1:N`.

* Bugfix.

* Import TestUtils -- it is not exported by DPPL.

* Specialise on model type.

* Improve test.

* Update src/test_utils.jl

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* midified the way chain value was extracted in all 3 methods.

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Update src/model.jl

* Update src/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/utils.jl

* rewrote the tests (mainly the way extracting parameter values from chain).

* removed BangBang from doctest setup; fixed imcomplete end in test/model.jl.

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed a naming bug (argvals_mat_dict) in src/model.jl.

* fixed a typo - missing `var_info`.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Explicitly added `using Distributions` in doctests; Accepted suggestion in test/model.jl, tests passed locally.

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* rm unnecessary deps

* replace contains with subsumes.

* rm redundant deps in docs build script.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix format.

* Replaced `subsumes` by `contains`.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* replaced 'contains' by a new, temporary method 'subsumes_sym', just for testing purpose.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* modified `/test/model.jl`:
1. build a map between model parameter symbol (`s`,`m`) and chain parameter names (which is obtained via `varname_leaves`.)
2. use this naming map to collect sample values from chain and drop into `log_true` for validation.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fixed a mistake in `modify_value_representation`.

* fixed `gdemo_default`

* assigned `model=gdemo_default`.

* src/model.jl: added `DynamicPPL.` to  `logprior` and `logjoint`.

* commented out `gdemo_d()` as a trial test.

* used `Symbol(vn_child)` as keys in `chain_sym_map`.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* explicitly loaded  `varname_leaves` and  `values_from_chain`.

* added `print` statements for temporary diagnosis purpose.

* added 'print` statements for temporary diagnostics purpose.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* diagnostics again.

* Removed some `print` statements as it's working.

* Update test/model.jl

Co-authored-by: Hong Ge <[email protected]>

* Update test/model.jl

Co-authored-by: Hong Ge <[email protected]>

* Update test/model.jl

Co-authored-by: Hong Ge <[email protected]>

* Update test/model.jl

Co-authored-by: Hong Ge <[email protected]>

* 1. moved helper functions to `test_util.jl`; 2. re-wrote the way `chain_mat` can be generated.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* formatting.

* formatting.

* Update utils.jl

* Update test_util.jl

* Update Project.toml

* replaced 'varname_leaves' by 'DynamicPPL.varname_leaves'.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Jose Storopoli <[email protected]>
  • Loading branch information
8 people authored Jul 24, 2023
1 parent ba206f4 commit f6f02ac
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.7"
version = "0.23.8"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -17,3 +18,4 @@ LogDensityProblems = "2"
MLUtils = "0.3, 0.4"
Setfield = "0.7.1, 0.8, 1"
StableRNGs = "1"
MCMCChains = "5"
4 changes: 1 addition & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ using DynamicPPL: AbstractPPL
using Distributions

# Doctest setup
DocMeta.setdocmeta!(
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
)
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)

makedocs(;
sitename="DynamicPPL",
Expand Down
108 changes: 108 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,42 @@ function logjoint(model::Model, varinfo::AbstractVarInfo)
return getlogp(last(evaluate!!(model, varinfo, DefaultContext())))
end

"""
logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;
julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logjoint(demo_model([1., 2.]), chain);
```
"""
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
end
end

"""
logprior(model::Model, varinfo::AbstractVarInfo)
Expand All @@ -1070,6 +1106,42 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
return getlogp(last(evaluate!!(model, varinfo, PriorContext())))
end

"""
logprior(model::Model, chain::AbstractMCMC.AbstractChains)
Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;
julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> logprior(demo_model([1., 2.]), chain);
```
"""
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
end
end

"""
loglikelihood(model::Model, varinfo::AbstractVarInfo)
Expand All @@ -1081,6 +1153,42 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext())))
end

"""
loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
# Examples
```jldoctest
julia> using MCMCChains, Distributions
julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end
end;
julia> # construct a chain of samples using MCMCChains
chain = Chains(rand(10, 2, 3), [:s, :m]);
julia> loglikelihood(demo_model([1., 2.]), chain);
```
"""
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
vn_parent in keys(var_info)
)
loglikelihood(model, argvals_dict)
end
end

"""
generated_quantities(model::Model, chain::AbstractChains)
Expand Down
73 changes: 72 additions & 1 deletion test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end

@testset "model.jl" begin
@testset "convenience functions" begin
model = gdemo_default
model = gdemo_default # defined in test/test_util.jl

# sample from model and extract variables
vi = VarInfo(model)
Expand All @@ -49,6 +49,77 @@ end
ljoint = logjoint(model, vi)
@test ljoint lprior + llikelihood
@test ljoint lp

#### logprior, logjoint, loglikelihood for MCMC chains ####
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
var_info = VarInfo(model)
vns = DynamicPPL.TestUtils.varnames(model)
syms = unique(DynamicPPL.getsym.(vns))

# generate a chain of sample parameter values.
N = 200
vals_OrderedDict = mapreduce(hcat, 1:N) do _
rand(OrderedDict, model)
end
vals_mat = mapreduce(hcat, 1:N) do i
[vals_OrderedDict[i][vn] for vn in vns]
end
i = 1
for col in eachcol(vals_mat)
col_flattened = []
[push!(col_flattened, x...) for x in col]
if i == 1
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
else
chain_mat = vcat(
chain_mat, reshape(col_flattened, 1, length(col_flattened))
)
end
i += 1
end
chain_mat = convert(Matrix{Float64}, chain_mat)

# devise parameter names for chain
sample_values_vec = collect(values(vals_OrderedDict[1]))
symbol_names = []
chain_sym_map = Dict()
for k in 1:length(keys(var_info))
vn_parent = keys(var_info)[k]
sym = DynamicPPL.getsym(vn_parent)
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
for vn_child in vn_children
chain_sym_map[Symbol(vn_child)] = sym
symbol_names = [symbol_names; Symbol(vn_child)]
end
end
chain = Chains(chain_mat, symbol_names)

# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
logpriors = logprior(model, chain)
loglikelihoods = loglikelihood(model, chain)
logjoints = logjoint(model, chain)
# compare them with true values
for i in 1:N
samples_dict = Dict()
for chain_key in keys(chain)
value = chain[i, chain_key, 1]
key = chain_sym_map[chain_key]
existing_value = get(samples_dict, key, Float64[])
push!(existing_value, value)
samples_dict[key] = existing_value
end
samples = (; samples_dict...)
samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl
@test logpriors[i]
DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
@test loglikelihoods[i] DynamicPPL.TestUtils.loglikelihood_true(
model, samples[:s], samples[:m]
)
@test logjoints[i]
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
end
println("\n model $(model) passed !!! \n")
end
end

@testset "rng" begin
Expand Down
15 changes: 15 additions & 0 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,18 @@ short_varinfo_name(::TypedVarInfo) = "TypedVarInfo"
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"

# convenient functions for testing model.jl
# function to modify the representation of values based on their length
function modify_value_representation(nt::NamedTuple)
modified_nt = NamedTuple()
for (key, value) in zip(keys(nt), values(nt))
if length(value) == 1 # Scalar value
modified_value = value[1]
else # Non-scalar value
modified_value = value
end
modified_nt = merge(modified_nt, (key => modified_value,))
end
return modified_nt
end

1 comment on commit f6f02ac

@torfjelde
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yebai This shouldn't have been merged.

Please sign in to comment.