Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert to using setval_and_resample! in generated_quantities #709

Merged
merged 4 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.30.2"
version = "0.30.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
86 changes: 7 additions & 79 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities(
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
if DynamicPPL.supports_varname_indexing(chain)
varname_pairs = _varname_pairs_with_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
else
varname_pairs = _varname_pairs_without_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
end
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
return fixed_model()
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
end
end

"""
_varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation assumes `chain` can be indexed using variable names, and is the
preffered implementation.
"""
function _varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
vns = DynamicPPL.varnames(chain)
vn_parents = Iterators.map(vns) do vn
# The call nested_setindex_maybe! is used to handle cases where vn is not
# the variable name used in the model, but rather subsumed by one. Except
# for the subsumption part, this could be
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
DynamicPPL.nested_setindex_maybe!(
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
)
end
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
vn_parent => varinfo[vn_parent]
end
return varname_pairs
end

"""
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.

The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
won't catch all cases. We should get rid of this if we can.
"""
# TODO(mhauru) See docstring above.
function _vcat_subsumed_values(vn_string, values, key_strings)
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
end

"""
_varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation does not assume that `chain` can be indexed using variable names. It is
thus not guaranteed to work in cases where the variable names have complex subsumption
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
"""
function _varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
values = chain.value[sample_idx, :, chain_idx]
keys = Base.keys(chain)
keys_strings = map(string, keys)
varname_pairs = [
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
vn in Base.keys(varinfo)
]
return varname_pairs
end

end
Loading