Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into torfjelde/varnameve…
Browse files Browse the repository at this point in the history
…ctor
  • Loading branch information
mhauru committed Sep 4, 2024
2 parents b5677b4 + bf73fd0 commit 9d1c8d3
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ else
using ..MCMCChains: MCMCChains
end

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("$(typeof(c)) do not support indexing using varnmes.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
Expand All @@ -26,10 +20,17 @@ function DynamicPPL.loadstate(chain::MCMCChains.Chains)
return chain.info[:samplerstate]
end

# A few methods needed.
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names

function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
)
Expand Down

0 comments on commit 9d1c8d3

Please sign in to comment.