Skip to content

Commit

Permalink
Attach varname_to_symbol mapping to Chains (#2078)
Browse files Browse the repository at this point in the history
* _params_to_array now returns varnames and values instead of symbols
and values

* updated other uses of _params_to_array

* Update Project.toml

* make inclusion of varname_to_symbol mapping in chains optional

* Update Project.toml

* Update Project.toml

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai authored Sep 12, 2023
1 parent c0c8130 commit 8d8416a
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.29"
version = "0.29.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -56,7 +56,7 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.23.15"
DynamicPPL = "0.23.17"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
Expand Down
2 changes: 1 addition & 1 deletion ext/TuringOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ function _optimize(
Turing.Inference.getparams(model, f.varinfo),
DynamicPPL.getlogp(f.varinfo)
)]
varnames, _ = Turing.Inference._params_to_array(model, ts)
varnames = map(Symbol, first(Turing.Inference._params_to_array(model, ts)))

# Store the parameters and their names in an array.
vmat = NamedArrays.NamedArray(vals, varnames)
Expand Down
23 changes: 14 additions & 9 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,18 +310,17 @@ end


function _params_to_array(model::DynamicPPL.Model, ts::Vector)
# TODO: Do we really need to use `Symbol` here?
names_set = OrderedSet{Symbol}()
names_set = OrderedSet{VarName}()
# Extract the parameter names and values from each transition.
dicts = map(ts) do t
nms_and_vs = getparams(model, t)
nms = map(Symbol first, nms_and_vs)
nms = map(first, nms_and_vs)
vs = map(last, nms_and_vs)
for nm in nms
push!(names_set, nm)
end
# Convert the names and values to a single dictionary.
return Dict(nms[j] => vs[j] for j in 1:length(vs))
return OrderedDict(zip(nms, vs))
end
names = collect(names_set)
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
Expand Down Expand Up @@ -379,29 +378,35 @@ function AbstractMCMC.bundle_samples(
save_state = false,
stats = missing,
sort_chain = false,
include_varname_to_symbol = true,
discard_initial = 0,
thinning = 1,
kwargs...
)
# Convert transitions to array format.
# Also retrieve the variable names.
nms, vals = _params_to_array(model, ts)
varnames, vals = _params_to_array(model, ts)
varnames_symbol = map(Symbol, varnames)

# Get the values of the extra parameters in each transition.
extra_params, extra_values = get_transition_extras(ts)

# Extract names & construct param array.
nms = [nms; extra_params]
nms = [varnames_symbol; extra_params]
parray = hcat(vals, extra_values)

# Get the average or final log evidence, if it exists.
le = getlogevidence(ts, spl, state)

# Set up the info tuple.
info = NamedTuple()

if include_varname_to_symbol
info = merge(info, (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),))
end

if save_state
info = (model = model, sampler = spl, samplerstate = state)
else
info = NamedTuple()
info = merge(info, (model = model, sampler = spl, samplerstate = state))
end

# Merge in the timing info, if available
Expand Down
10 changes: 5 additions & 5 deletions src/mcmc/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ function AbstractMCMC.bundle_samples(
params_vec = map(Base.Fix1(_params_to_array, model), samples)

# Extract names and values separately.
nms = params_vec[1][1]
varnames = params_vec[1][1]
varnames_symbol = map(Symbol, varnames)
vals_vec = [p[2] for p in params_vec]

# Get the values of the extra parameters in each transition.
Expand All @@ -120,7 +121,7 @@ function AbstractMCMC.bundle_samples(
extra_values_vec = [e[2] for e in extra_vec]

# Extract names & construct param array.
nms = [nms; extra_params]
nms = [varnames_symbol; extra_params]
# `hcat` first to ensure we get the right `eltype`.
x = hcat(first(vals_vec), first(extra_values_vec))
# Pre-allocate to minimize memory usage.
Expand All @@ -133,10 +134,9 @@ function AbstractMCMC.bundle_samples(
le = getlogevidence(samples, state, spl)

# Set up the info tuple.
info = (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),)
if save_state
info = (model = model, sampler = spl, samplerstate = state)
else
info = NamedTuple()
info = merge(info, (model = model, sampler = spl, samplerstate = state))
end

# Concretize the array before giving it to MCMCChains.
Expand Down

2 comments on commit 8d8416a

@ParadaCarleton
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91310

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.29.1 -m "<description of version>" 8d8416ac6c7363c6003ee6ea1fbaac26b4fc8dc3
git push origin v0.29.1

Please sign in to comment.