From 83467226ca01a3806c5e616e3ba2e581d1db4a5a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 6 Sep 2023 18:51:55 +0100 Subject: [PATCH 1/6] _params_to_array now returns varnames and values instead of symbols and values --- src/mcmc/Inference.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 7f4fad950..ffd84a5f0 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -310,18 +310,17 @@ end function _params_to_array(model::DynamicPPL.Model, ts::Vector) - # TODO: Do we really need to use `Symbol` here? - names_set = OrderedSet{Symbol}() + names_set = OrderedSet{VarName}() # Extract the parameter names and values from each transition. dicts = map(ts) do t nms_and_vs = getparams(model, t) - nms = map(Symbol ∘ first, nms_and_vs) + nms = map(first, nms_and_vs) vs = map(last, nms_and_vs) for nm in nms push!(names_set, nm) end # Convert the names and values to a single dictionary. - return Dict(nms[j] => vs[j] for j in 1:length(vs)) + return OrderedDict(zip(nms, vs)) end names = collect(names_set) vals = [get(dicts[i], key, missing) for i in eachindex(dicts), @@ -385,23 +384,23 @@ function AbstractMCMC.bundle_samples( ) # Convert transitions to array format. # Also retrieve the variable names. - nms, vals = _params_to_array(model, ts) + varnames, vals = _params_to_array(model, ts) + varnames_symbol = map(Symbol, varnames) # Get the values of the extra parameters in each transition. extra_params, extra_values = get_transition_extras(ts) # Extract names & construct param array. - nms = [nms; extra_params] + nms = [varnames_symbol; extra_params] parray = hcat(vals, extra_values) # Get the average or final log evidence, if it exists. le = getlogevidence(ts, spl, state) # Set up the info tuple. + info = (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),) if save_state - info = (model = model, sampler = spl, samplerstate = state) - else - info = NamedTuple() + info = merge(info, (model = model, sampler = spl, samplerstate = state)) end # Merge in the timing info, if available From 898675f8f107d3a462caba0e7e2c0f356a3afa25 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 6 Sep 2023 18:55:52 +0100 Subject: [PATCH 2/6] updated other uses of _params_to_array --- ext/TuringOptimExt.jl | 2 +- src/mcmc/emcee.jl | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index a0710893e..eb594929d 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -253,7 +253,7 @@ function _optimize( Turing.Inference.getparams(model, f.varinfo), DynamicPPL.getlogp(f.varinfo) )] - varnames, _ = Turing.Inference._params_to_array(model, ts) + varnames = map(Symbol, first(Turing.Inference._params_to_array(model, ts))) # Store the parameters and their names in an array. vmat = NamedArrays.NamedArray(vals, varnames) diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index f89cad955..d41596075 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -109,7 +109,8 @@ function AbstractMCMC.bundle_samples( params_vec = map(Base.Fix1(_params_to_array, model), samples) # Extract names and values separately. - nms = params_vec[1][1] + varnames = params_vec[1][1] + varnames_symbol = map(Symbol, varnames) vals_vec = [p[2] for p in params_vec] # Get the values of the extra parameters in each transition. @@ -120,7 +121,7 @@ function AbstractMCMC.bundle_samples( extra_values_vec = [e[2] for e in extra_vec] # Extract names & construct param array. - nms = [nms; extra_params] + nms = [varnames_symbol; extra_params] # `hcat` first to ensure we get the right `eltype`. x = hcat(first(vals_vec), first(extra_values_vec)) # Pre-allocate to minimize memory usage. @@ -133,10 +134,9 @@ function AbstractMCMC.bundle_samples( le = getlogevidence(samples, state, spl) # Set up the info tuple. + info = (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),) if save_state - info = (model = model, sampler = spl, samplerstate = state) - else - info = NamedTuple() + info = merge(info, (model = model, sampler = spl, samplerstate = state)) end # Concretize the array before giving it to MCMCChains. From 0b9247283c137244631a61eaaa01e2b1a16f0fea Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sat, 9 Sep 2023 13:38:08 +0100 Subject: [PATCH 3/6] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a89e72eec..134d97528 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.23.15" +DynamicPPL = "0.23.16" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" From 05f625f56475c060329016980f214441351c6589 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Sep 2023 17:00:22 +0100 Subject: [PATCH 4/6] make inclusion of varname_to_symbol mapping in chains optional --- src/mcmc/Inference.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index ffd84a5f0..1b09668ec 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -378,6 +378,7 @@ function AbstractMCMC.bundle_samples( save_state = false, stats = missing, sort_chain = false, + include_varname_to_symbol = true, discard_initial = 0, thinning = 1, kwargs... @@ -398,7 +399,12 @@ function AbstractMCMC.bundle_samples( le = getlogevidence(ts, spl, state) # Set up the info tuple. - info = (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),) + info = NamedTuple() + + if include_varname_to_symbol + info = merge(info, (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),)) + end + if save_state info = merge(info, (model = model, sampler = spl, samplerstate = state)) end From 87dea58b37cfcafb6b4dd3941d8e7515c1b913cd Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 12 Sep 2023 19:22:18 +0100 Subject: [PATCH 5/6] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 134d97528..aaf16304d 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.23.16" +DynamicPPL = "0.23.17" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" From c35b04167a03b9c8310925e06428ac61e34de5e1 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 12 Sep 2023 19:22:36 +0100 Subject: [PATCH 6/6] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index aaf16304d..c0231afc6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.29" +version = "0.29.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"