From 52a9b76b6dabcca568e17c19191002cb3ed59ffc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Jul 2024 09:30:51 +0100 Subject: [PATCH 01/10] Initial work on `NamedTuple` as `initial_params` --- src/sampler.jl | 51 ++++++++++++++++++++++++++++---------------------- src/varinfo.jl | 4 ++++ 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index b2fc6f4ec..35822371f 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -142,38 +142,45 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function initialize_parameters!!( - vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model -) - @debug "Using passed-in initial variable values" initial_params - - # Flatten parameters. - init_theta = mapreduce(vcat, initial_params) do x - vec([x;]) - end - - # Get all values. - linked = islinked(vi, spl) - if linked - vi = invlink!!(vi, spl, model) - end - theta = vi[spl] - length(theta) == length(init_theta) || throw( +function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Real}, spl::AbstractSampler) + theta = varinfo[spl] + length(theta) == length(initial_params) || throw( DimensionMismatch( - "Provided initial value size ($(length(init_theta))) doesn't match the model size ($(length(theta)))", + "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(theta)))", ), ) # Update values that are provided. - for i in eachindex(init_theta) - x = init_theta[i] + for i in eachindex(initial_params) + x = initial_params[i] if x !== missing theta[i] = x end end - # Update in `vi`. - vi = setindex!!(vi, theta, spl) + # Update in `varinfo`. + return setindex!!(varinfo, theta, spl) +end + +function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler) + return DynamicPPL.TestUtils.update_values!!(varinfo, initial_params, keys(varinfo, spl)) +end + +function initialize_parameters!!( + vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model +) + @debug "Using passed-in initial variable values" initial_params + + # `link` the varinfo if needed. + linked = islinked(vi, spl) + if linked + vi = invlink!!(vi, spl, model) + end + + # Set the values in `vi`. + vi = set_values!!(vi, initial_params, spl) + + # `invlink` if needed. if linked vi = link!!(vi, spl, model) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 68d36141e..f471e2099 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -892,6 +892,10 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] return expr end +# FIXME(torfjelde): Don't use `_getvns`. +Base.keys(vi::UntypedVarInfo, spl::AbstractSampler) = _getvns(vi, spl) +Base.keys(vi::TypedVarInfo, spl::AbstractSampler) = mapreduce(values, vcat, _getvns(vi, spl)) + """ setgid!(vi::VarInfo, gid::Selector, vn::VarName) From 94c2ef74169bbb0a1aba389c1a7e4950ac379b21 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Jul 2024 10:11:09 +0100 Subject: [PATCH 02/10] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sampler.jl | 8 ++++++-- src/varinfo.jl | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 35822371f..8aa6db33c 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -142,7 +142,9 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Real}, spl::AbstractSampler) +function set_values!!( + varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Real}, spl::AbstractSampler +) theta = varinfo[spl] length(theta) == length(initial_params) || throw( DimensionMismatch( @@ -162,7 +164,9 @@ function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector{< return setindex!!(varinfo, theta, spl) end -function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler) +function set_values!!( + varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler +) return DynamicPPL.TestUtils.update_values!!(varinfo, initial_params, keys(varinfo, spl)) end diff --git a/src/varinfo.jl b/src/varinfo.jl index f471e2099..903789325 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -894,7 +894,9 @@ end # FIXME(torfjelde): Don't use `_getvns`. Base.keys(vi::UntypedVarInfo, spl::AbstractSampler) = _getvns(vi, spl) -Base.keys(vi::TypedVarInfo, spl::AbstractSampler) = mapreduce(values, vcat, _getvns(vi, spl)) +function Base.keys(vi::TypedVarInfo, spl::AbstractSampler) + return mapreduce(values, vcat, _getvns(vi, spl)) +end """ setgid!(vi::VarInfo, gid::Selector, vn::VarName) From 5b5f4dd2b53b42d8f2eed7cfb15d1f60e4a69b78 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 15 Jul 2024 15:54:38 +0100 Subject: [PATCH 03/10] add some tests --- src/sampler.jl | 8 +++- test/sampler.jl | 112 +++++++++++++++++++++++++----------------------- 2 files changed, 66 insertions(+), 54 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 8aa6db33c..675278537 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -164,10 +164,16 @@ function set_values!!( return setindex!!(varinfo, theta, spl) end +# if initialize with scalar, convert to vector +function set_values!!(varinfo::AbstractVarInfo, initial_params, spl::AbstractSampler) + return set_values!!(varinfo, [initial_params], spl) +end + function set_values!!( varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler ) - return DynamicPPL.TestUtils.update_values!!(varinfo, initial_params, keys(varinfo, spl)) + initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) + return DynamicPPL.TestUtils.update_values!!(varinfo, initial_params, map(VarName, keys(initial_params))) end function initialize_parameters!!( diff --git a/test/sampler.jl b/test/sampler.jl index b52a9c921..714362195 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -84,23 +84,25 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - chain = sample(model, sampler, 1; initial_params=0.2, progress=false) - @test chain[1].metadata.p.vals == [0.2] - @test getlogp(chain[1]) == lptrue - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill(0.2, 10), - progress=false, - ) - for c in chains - @test c[1].metadata.p.vals == [0.2] - @test getlogp(c[1]) == lptrue + for inits in (0.2, (; p=0.2)) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) + @test chain[1].metadata.p.vals == [0.2] + @test getlogp(chain[1]) == lptrue + + # parallel sampling + chains = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test c[1].metadata.p.vals == [0.2] + @test getlogp(c[1]) == lptrue + end end # model with two variables: initialization s = 4, m = -1 @@ -110,45 +112,49 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false) - @test chain[1].metadata.s.vals == [4] - @test chain[1].metadata.m.vals == [-1] - @test getlogp(chain[1]) == lptrue - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill([4, -1], 10), - progress=false, - ) - for c in chains - @test c[1].metadata.s.vals == [4] - @test c[1].metadata.m.vals == [-1] - @test getlogp(c[1]) == lptrue + for inits in ([4, -1], (; s=4, m=-1)) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) + @test chain[1].metadata.s.vals == [4] + @test chain[1].metadata.m.vals == [-1] + @test getlogp(chain[1]) == lptrue + + # parallel sampling + chains = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test c[1].metadata.s.vals == [4] + @test c[1].metadata.m.vals == [-1] + @test getlogp(c[1]) == lptrue + end end # set only m = -1 - chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false) - @test !ismissing(chain[1].metadata.s.vals[1]) - @test chain[1].metadata.m.vals == [-1] - - # parallel sampling - chains = sample( - model, - sampler, - MCMCThreads(), - 1, - 10; - initial_params=fill([missing, -1], 10), - progress=false, - ) - for c in chains - @test !ismissing(c[1].metadata.s.vals[1]) - @test c[1].metadata.m.vals == [-1] + for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) + @test !ismissing(chain[1].metadata.s.vals[1]) + @test chain[1].metadata.m.vals == [-1] + + # parallel sampling + chains = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test !ismissing(c[1].metadata.s.vals[1]) + @test c[1].metadata.m.vals == [-1] + end end # specify `initial_params=nothing` From 4693fa210db73bbcbd54ae15303f92e2ee841cc7 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Mon, 15 Jul 2024 22:58:42 +0800 Subject: [PATCH 04/10] Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sampler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index 675278537..3feedfd4e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -173,7 +173,9 @@ function set_values!!( varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler ) initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return DynamicPPL.TestUtils.update_values!!(varinfo, initial_params, map(VarName, keys(initial_params))) + return DynamicPPL.TestUtils.update_values!!( + varinfo, initial_params, map(VarName, keys(initial_params)) + ) end function initialize_parameters!!( From bab789e9a3f631c014fd97ac75621d0dcd0e3c02 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 15 Jul 2024 17:25:51 +0100 Subject: [PATCH 05/10] fix type error for inits with `nothing` --- src/sampler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 3feedfd4e..cba4ece1a 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -143,7 +143,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). initialsampler(spl::Sampler) = SampleFromPrior() function set_values!!( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Real}, spl::AbstractSampler + varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}, spl::AbstractSampler ) theta = varinfo[spl] length(theta) == length(initial_params) || throw( @@ -165,7 +165,7 @@ function set_values!!( end # if initialize with scalar, convert to vector -function set_values!!(varinfo::AbstractVarInfo, initial_params, spl::AbstractSampler) +function set_values!!(varinfo::AbstractVarInfo, initial_params::Real, spl::AbstractSampler) return set_values!!(varinfo, [initial_params], spl) end @@ -174,7 +174,7 @@ function set_values!!( ) initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) return DynamicPPL.TestUtils.update_values!!( - varinfo, initial_params, map(VarName, keys(initial_params)) + varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) ) end From c3ce3fb8f7c876622ed2c344f7f91888b54d464a Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 16 Jul 2024 00:35:12 +0800 Subject: [PATCH 06/10] Update src/sampler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sampler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index cba4ece1a..3a67dc778 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -143,7 +143,9 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). initialsampler(spl::Sampler) = SampleFromPrior() function set_values!!( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}, spl::AbstractSampler + varinfo::AbstractVarInfo, + initial_params::AbstractVector{<:Union{Real,Missing}}, + spl::AbstractSampler, ) theta = varinfo[spl] length(theta) == length(initial_params) || throw( From 22942d498f5a414a743e67aa8f4619315b3ac852 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 08:47:20 +0100 Subject: [PATCH 07/10] use better variable names --- src/sampler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 3a67dc778..1e11a280b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -147,8 +147,8 @@ function set_values!!( initial_params::AbstractVector{<:Union{Real,Missing}}, spl::AbstractSampler, ) - theta = varinfo[spl] - length(theta) == length(initial_params) || throw( + flattened_param_vals = varinfo[spl] + length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(theta)))", ), @@ -158,12 +158,12 @@ function set_values!!( for i in eachindex(initial_params) x = initial_params[i] if x !== missing - theta[i] = x + flattened_param_vals[i] = x end end # Update in `varinfo`. - return setindex!!(varinfo, theta, spl) + return setindex!!(varinfo, flattened_param_vals, spl) end # if initialize with scalar, convert to vector From a3969d80b098bd5a239366a98f2498a38d7958be Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 08:49:57 +0100 Subject: [PATCH 08/10] remove init with scalar --- src/sampler.jl | 5 ----- test/sampler.jl | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 1e11a280b..0db63609b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -166,11 +166,6 @@ function set_values!!( return setindex!!(varinfo, flattened_param_vals, spl) end -# if initialize with scalar, convert to vector -function set_values!!(varinfo::AbstractVarInfo, initial_params::Real, spl::AbstractSampler) - return set_values!!(varinfo, [initial_params], spl) -end - function set_values!!( varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler ) diff --git a/test/sampler.jl b/test/sampler.jl index 714362195..b29d3caf1 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -84,7 +84,7 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - for inits in (0.2, (; p=0.2)) + let inits = (; p=0.2) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogp(chain[1]) == lptrue From 1b2ce4763254463582285bd91e8d7d1cdc0db1f9 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 08:53:36 +0100 Subject: [PATCH 09/10] move `update_values!!` out of `TestUtils` --- src/sampler.jl | 2 +- src/test_utils.jl | 12 ------------ src/utils.jl | 12 ++++++++++++ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 0db63609b..cfc58942e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -170,7 +170,7 @@ function set_values!!( varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler ) initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return DynamicPPL.TestUtils.update_values!!( + return update_values!!( varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) ) end diff --git a/src/test_utils.jl b/src/test_utils.jl index 72ccf6e4f..02501f510 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -13,18 +13,6 @@ using Accessors: Accessors # For backwards compat. using DynamicPPL: varname_leaves -""" - update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) - -Return instance similar to `vi` but with `vns` set to values from `vals`. -""" -function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) - for vn in vns - vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn) - end - return vi -end - """ test_values(vi::AbstractVarInfo, vals::NamedTuple, vns) diff --git a/src/utils.jl b/src/utils.jl index 9493e1bc9..4bf652363 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -796,6 +796,18 @@ function nested_getindex(values::AbstractDict, vn::VarName) return child(value) end +""" + update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) + +Return instance similar to `vi` but with `vns` set to values from `vals`. +""" +function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) + for vn in vns + vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn) + end + return vi +end + """ float_type_with_fallback(x) From 14bb2cf934a839669b4a089535cd7b6a025913fd Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 22 Jul 2024 09:47:47 +0100 Subject: [PATCH 10/10] fix error --- src/test_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 02501f510..bf7be0a9a 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,7 +11,7 @@ using Bijectors: Bijectors using Accessors: Accessors # For backwards compat. -using DynamicPPL: varname_leaves +using DynamicPPL: varname_leaves, update_values!! """ test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)