From 028a81ab4503c823d55d36510cf3ce11e97c27bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 17:38:54 +0100 Subject: [PATCH 01/27] added `subset` which can extract a subset of the varinfo --- src/DynamicPPL.jl | 1 + src/varinfo.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 042931ebb..4a326a7e8 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -47,6 +47,7 @@ export AbstractVarInfo, SimpleVarInfo, push!!, empty!!, + subset, getlogp, setlogp!!, acclogp!!, diff --git a/src/varinfo.jl b/src/varinfo.jl index ddb4caffb..3bef4cc56 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -236,6 +236,66 @@ else _tail(nt::NamedTuple) = Base.tail(nt) end +# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert +# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which +# might result in a `Vector{Any}`. +""" + subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) + +Subset a `varinfo` to only contain the variables `vns`. +""" +function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) + metadata = subset(varinfo.metadata, vns) + return VarInfo(metadata, varinfo.logp, varinfo.num_produce) +end + +function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym} + # If all the variables are using the same symbol, then we can just extract that field from the metadata. + metadata = subset(getfield(varinfo.metadata, sym), vns) + return VarInfo(NamedTuple{(sym,)}(tuple(metadata)), varinfo.logp, varinfo.num_produce) +end + +function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) + syms = Tuple(unique(map(getsym, vns))) + metadatas = map(syms) do sym + subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns)) + end + + return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) +end + +""" + subset(metadata::Metadata, vns::AbstractVector{<:VarName}) + +Subset a `metadata` to only contain the variables `vns`. +""" +function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) + # TODO: Should we error if `vns` contains a variable that is not in `metadata`? + indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) + indices = Dict(vn => i for (i, vn) in enumerate(vns)) + # HACK: maintaining consistency between `vals` and `ranges` in scenarios where + # `vns = [@varname(x[2])]` and `metadata` contains `x[1]` and `x[2]` is difficult. + # There are two options: + # 1. Keep ranges as they are and simply `copy` the full `vals`. + # 2. Adjust the ranges to be consistent with the `vals`. + # We choose option 1 for now, though this feels quite hacky. + ranges = metadata.ranges[indices_for_vns] + # vals = mapreduce(Base.Fix1(getindex, metadata.vals), vcat, ranges) + vals = copy(metadata.vals) + + flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) + return Metadata( + indices, + vns, + ranges, + vals, + metadata.dists[indices_for_vns], + metadata.gids, + metadata.orders[indices_for_vns], + flags, + ) +end + const VarView = Union{Int,UnitRange,Vector{Int}} """ From caa6e25a02e1badd9eba1ba0d0adc6fa7a62d319 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 17:40:18 +0100 Subject: [PATCH 02/27] added testing of `subset` for `VarInfo` --- test/varinfo.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 598ea7814..83d585786 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -421,4 +421,48 @@ end end end + + @testset "subset" begin + @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = TV(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + end + model = demo_subsetting_varinfo() + + @testset "$(short_varinfo_name(varinfo))" for varinfo in [ + VarInfo(model), + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + ] + + # All variables. + @test isempty(setdiff(keys(varinfo), [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])])) + + @testset "$(convert(Vector{VarName}, vns))" for vns in [ + [@varname(s)], + [@varname(m)], + [@varname(x[1])], + [@varname(x[2])], + [@varname(s), @varname(m)], + [@varname(s), @varname(x[1])], + [@varname(s), @varname(x[2])], + [@varname(m), @varname(x[1])], + [@varname(m), @varname(x[2])], + [@varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1])], + [@varname(s), @varname(m), @varname(x[2])], + [@varname(s), @varname(x[1]), @varname(x[2])], + [@varname(m), @varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + ] + varinfo_subset = subset(varinfo, vns) + # Should now only contain the variables in `vns`. + @test isempty(setdiff(keys(varinfo_subset), vns)) + # Values should be the same. + @test [varinfo_subset[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + end + end end From cac7fa884def637e74c2be067e3264ea95e653d6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 17:43:06 +0100 Subject: [PATCH 03/27] formatting --- test/varinfo.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 83d585786..57a5ac37e 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -429,16 +429,21 @@ x = TV(undef, 2) x[1] ~ Normal(m, sqrt(s)) x[2] ~ Normal(m, sqrt(s)) + return nothing end model = demo_subsetting_varinfo() @testset "$(short_varinfo_name(varinfo))" for varinfo in [ - VarInfo(model), - last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) - ] + VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + ] # All variables. - @test isempty(setdiff(keys(varinfo), [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])])) + @test isempty( + setdiff( + keys(varinfo), + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + ), + ) @testset "$(convert(Vector{VarName}, vns))" for vns in [ [@varname(s)], @@ -456,7 +461,7 @@ [@varname(s), @varname(x[1]), @varname(x[2])], [@varname(m), @varname(x[1]), @varname(x[2])], [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], - ] + ] varinfo_subset = subset(varinfo, vns) # Should now only contain the variables in `vns`. @test isempty(setdiff(keys(varinfo_subset), vns)) From 5e41c4fdedd095640bd717854be507c485455392 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 20:16:43 +0100 Subject: [PATCH 04/27] added implementation of `merge` for `VarInfo` and tests for it --- src/varinfo.jl | 173 +++++++++++++++++++++++++++++++++++++++++++++++- test/varinfo.jl | 40 +++++++++++ 2 files changed, 211 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3bef4cc56..5b26c1fc9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -269,7 +269,7 @@ end Subset a `metadata` to only contain the variables `vns`. """ -function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) +function subset(metadata::Metadata, vns::AbstractVector{<:VarName}) # TODO: Should we error if `vns` contains a variable that is not in `metadata`? indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) indices = Dict(vn => i for (i, vn) in enumerate(vns)) @@ -279,6 +279,7 @@ function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) # 1. Keep ranges as they are and simply `copy` the full `vals`. # 2. Adjust the ranges to be consistent with the `vals`. # We choose option 1 for now, though this feels quite hacky. + # TODO: Only pick the subset of `vals` needed. ranges = metadata.ranges[indices_for_vns] # vals = mapreduce(Base.Fix1(getindex, metadata.vals), vcat, ranges) vals = copy(metadata.vals) @@ -296,6 +297,163 @@ function subset(metadata::DynamicPPL.Metadata, vns::AbstractVector{<:VarName}) ) end +""" + merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + +Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. +""" +Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =_merge(varinfo_left, varinfo_right) +Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) =_merge(varinfo_left, varinfo_right) + +function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) + lp = getlogp(varinfo_left) + getlogp(varinfo_right) + # TODO: Is this really the way we want to combine `num_produce`? + num_produce = varinfo_left.num_produce[] + varinfo_right.num_produce[] + return VarInfo(metadata, Ref(lp), Ref(num_produce)) +end + +function merge_metadata( + metadata_left::NamedTuple{names_left}, + metadata_right::NamedTuple{names_right} +) where {names_left, names_right} + # TODO: Improve this. Maybe make `@generated`? + metadata = map(names_left) do sym + if sym in names_right + merge_metadata(getfield(metadata_left, sym), getfield(metadata_right, sym)) + else + getfield(metadata_left, sym) + end + end + names_right_only = filter(∉(names_left), names_right) + metadata_right_only = map(Tuple(names_right_only)) do sym + if !(sym in names_left) + getfield(metadata_right, sym) + end + end + + return NamedTuple{(names_left..., names_right_only...)}(tuple(metadata..., metadata_right_only...)) +end + +function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) + # Extract the varnames. + vns_left = metadata_left.vns + vns_right = metadata_right.vns + vns_both = union(vns_left, vns_right) + + # Determine `eltype` of `vals`. + T_left = eltype(metadata_left.vals) + T_right = eltype(metadata_right.vals) + T = promote_type(T_left, T_right) + # TODO: Is this necessary? + if !(T <: Real) + T = Real + end + + # Determine `eltype` of `dists`. + D_left = eltype(metadata_left.dists) + D_right = eltype(metadata_right.dists) + D = promote_type(D_left, D_right) + # TODO: Is this necessary? + if !(D <: Distribution) + D = Distribution + end + + # Initialize required fields for `metadata`. + vns = VarName[] + idcs = Dict{VarName, Int}() + ranges = Vector{UnitRange{Int}}() + vals = T[] + dists = D[] + gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + orders = Int[] + flags = Dict{String, BitVector}() + # Initialize the `flags`. + for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) + flags[k] = BitVector() + end + + # Range offset. + offset = 0 + + for (idx, vn) in enumerate(vns_both) + # `idcs` + idcs[vn] = idx + # `vns` + push!(vns, vn) + if vn in vns_left && vn in vns_right + # `vals`: only valid if they're the length. + vals_left = getval(metadata_left, vn) + vals_right = getval(metadata_right, vn) + @assert length(vals_left) == length(vals_right) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists`: only valid if they're the same. + dists_left = getdist(metadata_left, vn) + dists_right = getdist(metadata_right, vn) + @assert dists_left == dists_right + push!(dists, dists_left) + # `orders`: giving precedence to `metadata_right` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + # Using `metadata_right`; should we? + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + elseif vn in vns_left + # Just extract the metadata from `metadata_left`. + # `vals` + vals_left = getval(metadata_left, vn) + append!(vals, vals_left) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_left = getdist(metadata_left, vn) + push!(dists, dists_left) + # `orders` + push!(orders, getorder(metadata_left, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_left, vn, k)) + end + else + # Just extract the metadata from `metadata_right`. + # `vals` + vals_right = getvals(metadata_right, vn) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_right)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_right = getdist(metadata_right, vn) + push!(dists, dists_right) + # `orders` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + end + end + + return Metadata( + idcs, + vns, + ranges, + vals, + dists, + gids, + orders, + flags, + ) +end + const VarView = Union{Int,UnitRange,Vector{Int}} """ @@ -1391,6 +1549,16 @@ function setorder!(vi::VarInfo, vn::VarName, index::Int) return vi end +""" + getorder(vi::VarInfo, vn::VarName) + +Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements +run before sampling `vn`. +""" +getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) +getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] + + ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1401,8 +1569,9 @@ end Check whether `vn` has a true value for `flag` in `vi`. """ function is_flagged(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] + return is_flagged(getmetadata(vi, vn), vn, flag) end +is_flagged(metadata::Metadata, vn::VarName, flag::String) = metadata.flags[flag][getidx(metadata, vn)] """ unset_flag!(vi::VarInfo, vn::VarName, flag::String) diff --git a/test/varinfo.jl b/test/varinfo.jl index 57a5ac37e..1273b5ab5 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -470,4 +470,44 @@ end end end + + @testset "merge" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(short_varinfo_name(varinfo))" for varinfo in [ + VarInfo(model), + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + ] + + vns = DynamicPPL.TestUtils.varnames(model) + + @testset "with itself" begin + # Merging itself should be a no-op. + varinfo_merged = merge(varinfo, varinfo) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with empty" begin + # Merging with an empty `VarInfo` should be a no-op. + varinfo_merged = merge(varinfo, empty!!(deepcopy(varinfo))) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with different value" begin + x = DynamicPPL.TestUtils.rand(model) + varinfo_changed = DynamicPPL.TestUtils.update_values!!(deepcopy(varinfo), x, vns) + # After `merge`, we should have the same values as `x`. + varinfo_merged = merge(varinfo, varinfo_changed) + DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) + end + end + end + end end From d5a26314ef23541d3204f5c3cbbcc0441c8f8d8f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 20:19:53 +0100 Subject: [PATCH 05/27] more tests --- test/varinfo.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 1273b5ab5..f14e5ca60 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -432,6 +432,7 @@ return nothing end model = demo_subsetting_varinfo() + vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] @testset "$(short_varinfo_name(varinfo))" for varinfo in [ VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) @@ -441,11 +442,11 @@ @test isempty( setdiff( keys(varinfo), - [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + vns, ), ) - @testset "$(convert(Vector{VarName}, vns))" for vns in [ + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [ [@varname(s)], [@varname(m)], [@varname(x[1])], @@ -462,11 +463,19 @@ [@varname(m), @varname(x[1]), @varname(x[2])], [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - varinfo_subset = subset(varinfo, vns) - # Should now only contain the variables in `vns`. - @test isempty(setdiff(keys(varinfo_subset), vns)) + varinfo_subset = subset(varinfo, vns_subset) + # Should now only contain the variables in `vns_subset`. + @test isempty(setdiff(keys(varinfo_subset), vns_subset)) # Values should be the same. - @test [varinfo_subset[vn] for vn in vns] == [varinfo[vn] for vn in vns] + @test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset] + + # `merge` with the original. + varinfo_merged = merge(varinfo, varinfo_subset) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end end end @@ -477,9 +486,7 @@ VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) ] - vns = DynamicPPL.TestUtils.varnames(model) - @testset "with itself" begin # Merging itself should be a no-op. varinfo_merged = merge(varinfo, varinfo) From 0ade696b80a2e153ccf679dd88259830c1dafd00 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 20:20:57 +0100 Subject: [PATCH 06/27] formatting --- src/varinfo.jl | 36 ++++++++++++++++-------------------- test/varinfo.jl | 13 +++++-------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 5b26c1fc9..59e5ca9e3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -302,8 +302,11 @@ end Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. """ -Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) =_merge(varinfo_left, varinfo_right) -Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) =_merge(varinfo_left, varinfo_right) +Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) = + _merge(varinfo_left, varinfo_right) +function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) + return _merge(varinfo_left, varinfo_right) +end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) @@ -314,9 +317,8 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) end function merge_metadata( - metadata_left::NamedTuple{names_left}, - metadata_right::NamedTuple{names_right} -) where {names_left, names_right} + metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} +) where {names_left,names_right} # TODO: Improve this. Maybe make `@generated`? metadata = map(names_left) do sym if sym in names_right @@ -332,7 +334,9 @@ function merge_metadata( end end - return NamedTuple{(names_left..., names_right_only...)}(tuple(metadata..., metadata_right_only...)) + return NamedTuple{(names_left..., names_right_only...)}( + tuple(metadata..., metadata_right_only...) + ) end function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) @@ -361,13 +365,13 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # Initialize required fields for `metadata`. vns = VarName[] - idcs = Dict{VarName, Int}() + idcs = Dict{VarName,Int}() ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` orders = Int[] - flags = Dict{String, BitVector}() + flags = Dict{String,BitVector}() # Initialize the `flags`. for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) flags[k] = BitVector() @@ -442,16 +446,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) end end - return Metadata( - idcs, - vns, - ranges, - vals, - dists, - gids, - orders, - flags, - ) + return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -1558,7 +1553,6 @@ run before sampling `vn`. getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] - ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1571,7 +1565,9 @@ Check whether `vn` has a true value for `flag` in `vi`. function is_flagged(vi::VarInfo, vn::VarName, flag::String) return is_flagged(getmetadata(vi, vn), vn, flag) end -is_flagged(metadata::Metadata, vn::VarName, flag::String) = metadata.flags[flag][getidx(metadata, vn)] +function is_flagged(metadata::Metadata, vn::VarName, flag::String) + return metadata.flags[flag][getidx(metadata, vn)] +end """ unset_flag!(vi::VarInfo, vn::VarName, flag::String) diff --git a/test/varinfo.jl b/test/varinfo.jl index f14e5ca60..bb72e0a50 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -439,12 +439,7 @@ ] # All variables. - @test isempty( - setdiff( - keys(varinfo), - vns, - ), - ) + @test isempty(setdiff(keys(varinfo), vns)) @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [ [@varname(s)], @@ -484,7 +479,7 @@ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS @testset "$(short_varinfo_name(varinfo))" for varinfo in [ VarInfo(model), - last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())), ] vns = DynamicPPL.TestUtils.varnames(model) @testset "with itself" begin @@ -509,7 +504,9 @@ @testset "with different value" begin x = DynamicPPL.TestUtils.rand(model) - varinfo_changed = DynamicPPL.TestUtils.update_values!!(deepcopy(varinfo), x, vns) + varinfo_changed = DynamicPPL.TestUtils.update_values!!( + deepcopy(varinfo), x, vns + ) # After `merge`, we should have the same values as `x`. varinfo_merged = merge(varinfo, varinfo_changed) DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) From db218449013035b51f231cd6bc38cb09682ddadf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 11:27:13 +0100 Subject: [PATCH 07/27] improved merge_metadata for NamedTuple inputs --- src/varinfo.jl | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 59e5ca9e3..945eda0eb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -302,8 +302,9 @@ end Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. """ -Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) = - _merge(varinfo_left, varinfo_right) +function Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) + return _merge(varinfo_left, varinfo_right) +end function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) return _merge(varinfo_left, varinfo_right) end @@ -316,27 +317,31 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) return VarInfo(metadata, Ref(lp), Ref(num_produce)) end -function merge_metadata( +@generated function merge_metadata( metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right} ) where {names_left,names_right} - # TODO: Improve this. Maybe make `@generated`? - metadata = map(names_left) do sym + names = Expr(:tuple) + vals = Expr(:tuple) + # Loop over `names_left` first because we want to preserve the order of the variables. + for sym in names_left + push!(names.args, QuoteNode(sym)) if sym in names_right - merge_metadata(getfield(metadata_left, sym), getfield(metadata_right, sym)) + push!( + vals.args, + :(merge_metadata(metadata_left.$sym, metadata_right.$sym)) + ) else - getfield(metadata_left, sym) + push!(vals.args, :(metadata_left.$sym)) end end + # Loop over remaining variables in `names_right`. names_right_only = filter(∉(names_left), names_right) - metadata_right_only = map(Tuple(names_right_only)) do sym - if !(sym in names_left) - getfield(metadata_right, sym) - end + for sym in names_right_only + push!(names.args, QuoteNode(sym)) + push!(vals.args, :(metadata_right.$sym)) end - return NamedTuple{(names_left..., names_right_only...)}( - tuple(metadata..., metadata_right_only...) - ) + return :(NamedTuple{$names}($vals)) end function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) From 1dbca4c4c64591cb9a96c6f2baa0031bad10950b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 12:35:28 +0100 Subject: [PATCH 08/27] added proper handling of the `vals` in `subset` --- src/varinfo.jl | 51 +++++++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 945eda0eb..07c0df98a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -264,25 +264,41 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) end -""" - subset(metadata::Metadata, vns::AbstractVector{<:VarName}) - -Subset a `metadata` to only contain the variables `vns`. -""" function subset(metadata::Metadata, vns::AbstractVector{<:VarName}) # TODO: Should we error if `vns` contains a variable that is not in `metadata`? indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) indices = Dict(vn => i for (i, vn) in enumerate(vns)) - # HACK: maintaining consistency between `vals` and `ranges` in scenarios where - # `vns = [@varname(x[2])]` and `metadata` contains `x[1]` and `x[2]` is difficult. - # There are two options: - # 1. Keep ranges as they are and simply `copy` the full `vals`. - # 2. Adjust the ranges to be consistent with the `vals`. - # We choose option 1 for now, though this feels quite hacky. - # TODO: Only pick the subset of `vals` needed. - ranges = metadata.ranges[indices_for_vns] - # vals = mapreduce(Base.Fix1(getindex, metadata.vals), vcat, ranges) - vals = copy(metadata.vals) + # Construct new `vals` and `ranges`. + vals_original = metadata.vals + ranges_original = metadata.ranges + # Allocate the new `vals`. and `ranges`. + vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns])) + ranges = similar(ranges_original) + # The new range `r` for `vns[i]` is offset by `offset` and + # has the same length as the original range `r_original`. + # The new `indices` (from above) ensures ordering according to `vns`. + # NOTE: This means that the order of the variables in `vns` defines the order + # in the resulting `varinfo`! This can have performance implications, e.g. + # if in the model we have something like + # + # for i = 1:N + # x[i] ~ Normal() + # end + # + # and we then we do + # + # subset(varinfo, [@varname(x[i]) for i in shuffle(keys(varinfo))]) + # + # the resulting `varinfo` will have `vals` ordered differently from the + # original `varinfo`, which can have performance implications. + offset = 0 + for (idx, idx_original) in enumerate(indices_for_vns) + r_original = ranges_original[idx_original] + r = (offset + 1):(offset + length(r_original)) + vals[r] = vals_original[r_original] + ranges[idx] = r + offset = r[end] + end flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) return Metadata( @@ -302,10 +318,7 @@ end Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. """ -function Base.merge(varinfo_left::UntypedVarInfo, varinfo_right::UntypedVarInfo) - return _merge(varinfo_left, varinfo_right) -end -function Base.merge(varinfo_left::TypedVarInfo, varinfo_right::TypedVarInfo) +function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) return _merge(varinfo_left, varinfo_right) end From b67288f415db1578695810a754f48353feb40b1f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 12:35:50 +0100 Subject: [PATCH 09/27] added docs for `subset` and `merge` --- src/varinfo.jl | 110 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 07c0df98a..352708db5 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -243,6 +243,116 @@ end subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) Subset a `varinfo` to only contain the variables `vns`. + +!!! warning + The ordering of the variables in the resulting `varinfo` will _not_ + necessarily follow the ordering of the variables in `varinfo`. + Hence care must be taken, in particular when used in conjunction with + other methods which uses the vector-representation of `varinfo`, e.g. + `getindex(varinfo, sampler)` + +# Examples +```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) +julia> @model function demo() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = Vector{Float64}(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + end +demo (generic function with 2 methods) + +julia> model = demo(); + +julia> varinfo = VarInfo(model); + +julia> keys(varinfo) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> for (i, vn) in enumerate(keys(varinfo)) + varinfo[vn] = i + end + +julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 + +julia> # Extract one with only `m`. + varinfo_subset1 = subset(varinfo, [@varname(m),]); + + +julia> keys(varinfo_subset1) +1-element Vector{VarName{:m, Setfield.IdentityLens}}: + m + +julia> varinfo_subset1[@varname(m)] +2.0 + +julia> # Extract one with both `s` and `x[2]`. + varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); + +julia> keys(varinfo_subset2) +2-element Vector{VarName}: + s + x[2] + +julia> varinfo_subset2[[@varname(s), @varname(x[2])]] +2-element Vector{Float64}: + 1.0 + 4.0 +``` + +`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) + +```jldoctest varinfo-subset +julia> # Merge the two. + varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); + +julia> keys(varinfo_subset_merged) +3-element Vector{VarName}: + m + s + x[2] + +julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] +3-element Vector{Float64}: + 1.0 + 2.0 + 4.0 + +julia> # Merge the two with the original. + varinfo_merged = merge(varinfo, varinfo_subset_merged); + +julia> keys(varinfo_merged) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 +``` + +# Notes + +## Type-stability + +!!! warning + This function is only type-stable when `vns` contains only varnames + with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will + be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. """ function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) From e43029e0172629b8f1a8c4368ff7eab1b3345cca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 12:36:05 +0100 Subject: [PATCH 10/27] added `subset` and `merge` to documentation --- docs/src/api.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index ddd119816..47a92c07b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -255,6 +255,8 @@ DynamicPPL.reconstruct #### Utils ```@docs +Base.merge(::VarInfo, ::VarInfo) +DynamicPPL.subset DynamicPPL.unflatten DynamicPPL.tonamedtuple DynamicPPL.varname_leaves From cd4033d3ee7bf8ac6f2fc48a704f9af340cfe0fc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 12:47:23 +0100 Subject: [PATCH 11/27] formatting --- src/varinfo.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 352708db5..f3298518b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -449,10 +449,7 @@ end for sym in names_left push!(names.args, QuoteNode(sym)) if sym in names_right - push!( - vals.args, - :(merge_metadata(metadata_left.$sym, metadata_right.$sym)) - ) + push!(vals.args, :(merge_metadata(metadata_left.$sym, metadata_right.$sym))) else push!(vals.args, :(metadata_left.$sym)) end From 8f47dfe73a59019e61020f8f7fc2d77279ab6270 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 15:02:01 +0100 Subject: [PATCH 12/27] made merge and subset part of the AbstractVarInfo interface --- src/abstract_varinfo.jl | 152 ++++++++++++++++++++++++++++++++++++++++ src/varinfo.jl | 118 ------------------------------- 2 files changed, 152 insertions(+), 118 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index de1efe4c1..13e0ec753 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation bijector::F end +""" + merge_transformations(transformation_left, transformation_right) + +Merge two transformations. + +The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref). +""" +function merge_transformations(::NoTransformation, ::NoTransformation) + return NoTransformation() +end +function merge_transformations(::DynamicTransformation, ::DynamicTransformation) + return DynamicTransformation() +end +function merge_transformations(left::StaticTransformation, right::StaticTransformation) + return StaticTransformation(merge_bijectors(left.bijector, right.bijector)) +end + +function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform) + return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs)) +end + """ default_transformation(model::Model[, vi::AbstractVarInfo]) @@ -337,6 +358,137 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)})) end +# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert +# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which +# might result in a `Vector{Any}`. +""" + subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName}) + +Subset a `varinfo` to only contain the variables `vns`. + +!!! warning + The ordering of the variables in the resulting `varinfo` is _not_ + guaranteed to follow the ordering of the variables in `varinfo`. + Hence care must be taken, in particular when used in conjunction with + other methods which uses the vector-representation of the `varinfo`, + e.g. `getindex(varinfo, sampler)`. + +# Examples +```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) +julia> @model function demo() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = Vector{Float64}(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + end +demo (generic function with 2 methods) + +julia> model = demo(); + +julia> varinfo = VarInfo(model); + +julia> keys(varinfo) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> for (i, vn) in enumerate(keys(varinfo)) + varinfo[vn] = i + end + +julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 + +julia> # Extract one with only `m`. + varinfo_subset1 = subset(varinfo, [@varname(m),]); + + +julia> keys(varinfo_subset1) +1-element Vector{VarName{:m, Setfield.IdentityLens}}: + m + +julia> varinfo_subset1[@varname(m)] +2.0 + +julia> # Extract one with both `s` and `x[2]`. + varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); + +julia> keys(varinfo_subset2) +2-element Vector{VarName}: + s + x[2] + +julia> varinfo_subset2[[@varname(s), @varname(x[2])]] +2-element Vector{Float64}: + 1.0 + 4.0 +``` + +`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) + +```jldoctest varinfo-subset +julia> # Merge the two. + varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); + +julia> keys(varinfo_subset_merged) +3-element Vector{VarName}: + m + s + x[2] + +julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] +3-element Vector{Float64}: + 1.0 + 2.0 + 4.0 + +julia> # Merge the two with the original. + varinfo_merged = merge(varinfo, varinfo_subset_merged); + +julia> keys(varinfo_merged) +4-element Vector{VarName}: + s + m + x[1] + x[2] + +julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +4-element Vector{Float64}: + 1.0 + 2.0 + 3.0 + 4.0 +``` + +# Notes + +## Type-stability + +!!! warning + This function is only type-stable when `vns` contains only varnames + with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will + be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. +""" +function subset end + +@doc """ + merge(varinfo_left, varinfo_right) + +Merge two varinfos into one, giving precedence to `varinfo_right` when reasonable. + +This is particularly useful when combined with [`subset(varinfo, vns)`](@ref). + +See docstring of [`subset(varinfo, vns)`](@ref) for examples. +""" +Base.merge + # Transformations """ istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) diff --git a/src/varinfo.jl b/src/varinfo.jl index f3298518b..4de4718d1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -236,124 +236,6 @@ else _tail(nt::NamedTuple) = Base.tail(nt) end -# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert -# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which -# might result in a `Vector{Any}`. -""" - subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) - -Subset a `varinfo` to only contain the variables `vns`. - -!!! warning - The ordering of the variables in the resulting `varinfo` will _not_ - necessarily follow the ordering of the variables in `varinfo`. - Hence care must be taken, in particular when used in conjunction with - other methods which uses the vector-representation of `varinfo`, e.g. - `getindex(varinfo, sampler)` - -# Examples -```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) -julia> @model function demo() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - x = Vector{Float64}(undef, 2) - x[1] ~ Normal(m, sqrt(s)) - x[2] ~ Normal(m, sqrt(s)) - end -demo (generic function with 2 methods) - -julia> model = demo(); - -julia> varinfo = VarInfo(model); - -julia> keys(varinfo) -4-element Vector{VarName}: - s - m - x[1] - x[2] - -julia> for (i, vn) in enumerate(keys(varinfo)) - varinfo[vn] = i - end - -julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] -4-element Vector{Float64}: - 1.0 - 2.0 - 3.0 - 4.0 - -julia> # Extract one with only `m`. - varinfo_subset1 = subset(varinfo, [@varname(m),]); - - -julia> keys(varinfo_subset1) -1-element Vector{VarName{:m, Setfield.IdentityLens}}: - m - -julia> varinfo_subset1[@varname(m)] -2.0 - -julia> # Extract one with both `s` and `x[2]`. - varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); - -julia> keys(varinfo_subset2) -2-element Vector{VarName}: - s - x[2] - -julia> varinfo_subset2[[@varname(s), @varname(x[2])]] -2-element Vector{Float64}: - 1.0 - 4.0 -``` - -`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) - -```jldoctest varinfo-subset -julia> # Merge the two. - varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); - -julia> keys(varinfo_subset_merged) -3-element Vector{VarName}: - m - s - x[2] - -julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] -3-element Vector{Float64}: - 1.0 - 2.0 - 4.0 - -julia> # Merge the two with the original. - varinfo_merged = merge(varinfo, varinfo_subset_merged); - -julia> keys(varinfo_merged) -4-element Vector{VarName}: - s - m - x[1] - x[2] - -julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] -4-element Vector{Float64}: - 1.0 - 2.0 - 3.0 - 4.0 -``` - -# Notes - -## Type-stability - -!!! warning - This function is only type-stable when `vns` contains only varnames - with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will - be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. -""" function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) return VarInfo(metadata, varinfo.logp, varinfo.num_produce) From aba9008d40a4086a819d06b07aaeed0ad15d7b92 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 15:02:25 +0100 Subject: [PATCH 13/27] added implementations `subset` and `merge` for `SimpleVarInfo` --- src/simple_varinfo.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ src/varinfo.jl | 5 ----- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a9d38fb07..6602920ac 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -419,6 +419,48 @@ function Base.eltype( return V end +# `subset` +function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) + return Setfield.@set varinfo.values = _subset(varinfo.values, vns) +end + +function _subset(x::AbstractDict, vns) + # NOTE: This requires `vns` to be explicitly present in `x`. + if any(!Base.Fix1(haskey, x), vns) + error( + "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " * + "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " * + "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not." + ) + end + C = ConstructionBase.constructorof(typeof(x)) + return C(vn => x[vn] for vn in vns) +end + +function _subset(x::NamedTuple, vns) + # NOTE: Here we can only handle `vns` that contain the `IdentityLens`. + if any(!==(Setfield.IdentityLens()) ∘ getlens, vns) + error( + "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * + "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not." + ) + end + + syms = map(getsym, vns) + return NamedTuple{(syms...,)}((map(Base.Fix2(getindex, x), syms)...,)) +end + +# `merge` +function merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) + values = merge(varinfo_left.values, varinfo_right.values) + logp = getlogp(varinfo_left) + getlogp(varinfo_right) + transformation = merge_transformations( + varinfo_left.transformation, + varinfo_right.transformation, + ) + return SimpleVarInfo(values, logp, transformation) +end + # Context implementations # NOTE: Evaluations, i.e. those without `rng` are shared with other # implementations of `AbstractVarInfo`. diff --git a/src/varinfo.jl b/src/varinfo.jl index 4de4718d1..aeca100f2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -305,11 +305,6 @@ function subset(metadata::Metadata, vns::AbstractVector{<:VarName}) ) end -""" - merge(varinfo_left::VarInfo, varinfo_right::VarInfo) - -Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. -""" function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) return _merge(varinfo_left, varinfo_right) end From 3b621ae227aa7dbc4a30e42f62a34c7b06175b5e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 15:11:12 +0100 Subject: [PATCH 14/27] follow standard merge semantics where the right one takes precedence --- src/simple_varinfo.jl | 2 +- src/varinfo.jl | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 6602920ac..156bba21b 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -453,7 +453,7 @@ end # `merge` function merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - logp = getlogp(varinfo_left) + getlogp(varinfo_right) + logp = getlogp(varinfo_right) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation, diff --git a/src/varinfo.jl b/src/varinfo.jl index aeca100f2..752451087 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -311,10 +311,7 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - lp = getlogp(varinfo_left) + getlogp(varinfo_right) - # TODO: Is this really the way we want to combine `num_produce`? - num_produce = varinfo_left.num_produce[] + varinfo_right.num_produce[] - return VarInfo(metadata, Ref(lp), Ref(num_produce)) + return VarInfo(metadata, Ref(getlogp(varinfo_right)), Ref(varinfo_right.num_produce[])) end @generated function merge_metadata( From 2c2c90b5b87e7ebbacb31f72e7f49744182677f3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 15:45:02 +0100 Subject: [PATCH 15/27] added proper testing of merge and subset for SimpleVarInfo too --- src/abstract_varinfo.jl | 15 +++++++++++---- src/simple_varinfo.jl | 2 +- test/varinfo.jl | 16 ++++++++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 13e0ec753..d74f11ec5 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -478,16 +478,23 @@ julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] """ function subset end -@doc """ - merge(varinfo_left, varinfo_right) +""" + merge(varinfo, other_varinfos...) -Merge two varinfos into one, giving precedence to `varinfo_right` when reasonable. +Merge varinfos into one, giving precedence to the right-most varinfo when sensible. This is particularly useful when combined with [`subset(varinfo, vns)`](@ref). See docstring of [`subset(varinfo, vns)`](@ref) for examples. """ -Base.merge +function Base.merge(varinfo::AbstractVarInfo, varinfo_others::AbstractVarInfo...) + return merge(Base.merge(varinfo, first(varinfo_others)), Base.tail(varinfo_others)...) +end + +# Avoid `StackoverFlowError` if implementation is missing. +function Base.merge(varinfo::AbstractVarInfo, varinfo_other::AbstractVarInfo) + throw(MethodError(Base.merge, (varinfo, varinfo_other))) +end # Transformations """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 156bba21b..fa46d2db9 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -451,7 +451,7 @@ function _subset(x::NamedTuple, vns) end # `merge` -function merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) +function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) logp = getlogp(varinfo_right) transformation = merge_transformations( diff --git a/test/varinfo.jl b/test/varinfo.jl index bb72e0a50..0c397e222 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,19 @@ +function check_varinfo_keys(varinfo, vns) + if varinfo isa SimpleVarInfo{<:NamedTuple} + # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, + # since `keys(varinfo_merged)` only contains `VarName` with `IdentityLens`. + # So we just check that the original keys are present. + for vn in vns + # Should have all the original keys. + @test haskey(varinfo, vn) + end + else + vns_varinfo = keys(varinfo) + # Should be equivalent. + @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) + end +end + @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin From 5c1ece3059e7f49c4d4e12555ef7fe59e0bf052a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 15:45:16 +0100 Subject: [PATCH 16/27] forgotten inclusion in previous commit --- test/varinfo.jl | 77 +++++++++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 0c397e222..cfc9f2ebd 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -445,38 +445,49 @@ end x = TV(undef, 2) x[1] ~ Normal(m, sqrt(s)) x[2] ~ Normal(m, sqrt(s)) - return nothing + return (; s, m, x) end model = demo_subsetting_varinfo() vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] - @testset "$(short_varinfo_name(varinfo))" for varinfo in [ - VarInfo(model), last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())) + # `VarInfo` supports, effectively, arbitrary subsetting. + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, model(), vns) + varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) + varinfos_simple = filter(Base.Fix2(isa, SimpleVarInfo), varinfos) + + # `VarInfo` supports subsetting using, basically, arbitrary varnames. + vns_supported_standard = [ + [@varname(s)], + [@varname(m)], + [@varname(x[1])], + [@varname(x[2])], + [@varname(s), @varname(m)], + [@varname(s), @varname(x[1])], + [@varname(s), @varname(x[2])], + [@varname(m), @varname(x[1])], + [@varname(m), @varname(x[2])], + [@varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1])], + [@varname(s), @varname(m), @varname(x[2])], + [@varname(s), @varname(x[1]), @varname(x[2])], + [@varname(m), @varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] + # `SimpleaVarInfo` only supports subsetting using the varnames as they appear + # in the model. + vns_supported_simple = filter(∈(vns), vns_supported_standard) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos_standard # All variables. - @test isempty(setdiff(keys(varinfo), vns)) - - @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in [ - [@varname(s)], - [@varname(m)], - [@varname(x[1])], - [@varname(x[2])], - [@varname(s), @varname(m)], - [@varname(s), @varname(x[1])], - [@varname(s), @varname(x[2])], - [@varname(m), @varname(x[1])], - [@varname(m), @varname(x[2])], - [@varname(x[1]), @varname(x[2])], - [@varname(s), @varname(m), @varname(x[1])], - [@varname(s), @varname(m), @varname(x[2])], - [@varname(s), @varname(x[1]), @varname(x[2])], - [@varname(m), @varname(x[1]), @varname(x[2])], - [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], - ] + check_varinfo_keys(varinfo, vns) + + # Added a `convert` to make the naming of the testsets a bit more readable. + vns_supported = varinfo isa SimpleVarInfo ? vns_supported_simple : vns_supported_standard + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in vns_supported varinfo_subset = subset(varinfo, vns_subset) # Should now only contain the variables in `vns_subset`. - @test isempty(setdiff(keys(varinfo_subset), vns_subset)) + check_varinfo_keys(varinfo_subset, vns_subset) # Values should be the same. @test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset] @@ -484,7 +495,7 @@ end varinfo_merged = merge(varinfo, varinfo_subset) vns_merged = keys(varinfo_merged) # Should be equivalent. - @test union(vns_merged, vns) == intersect(vns_merged, vns) + check_varinfo_keys(varinfo_merged, vns) # Values should be the same. @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end @@ -493,17 +504,14 @@ end @testset "merge" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - @testset "$(short_varinfo_name(varinfo))" for varinfo in [ - VarInfo(model), - last(DynamicPPL.evaluate!!(model, VarInfo(), SamplingContext())), - ] - vns = DynamicPPL.TestUtils.varnames(model) + vns = DynamicPPL.TestUtils.varnames(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), vns) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @testset "with itself" begin # Merging itself should be a no-op. varinfo_merged = merge(varinfo, varinfo) - vns_merged = keys(varinfo_merged) - # Should be equivalent. - @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Varnames should be unchanged. + check_varinfo_keys(varinfo_merged, vns) # Values should be the same. @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end @@ -511,9 +519,8 @@ end @testset "with empty" begin # Merging with an empty `VarInfo` should be a no-op. varinfo_merged = merge(varinfo, empty!!(deepcopy(varinfo))) - vns_merged = keys(varinfo_merged) - # Should be equivalent. - @test union(vns_merged, vns) == intersect(vns_merged, vns) + # Varnames should be unchanged. + check_varinfo_keys(varinfo_merged, vns) # Values should be the same. @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end From cfff96c3d1ea0515267fb1168f637ac049255796 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 17:01:17 +0100 Subject: [PATCH 17/27] Update src/simple_varinfo.jl Co-authored-by: David Widmann --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index fa46d2db9..cbd55c86d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -447,7 +447,7 @@ function _subset(x::NamedTuple, vns) end syms = map(getsym, vns) - return NamedTuple{(syms...,)}((map(Base.Fix2(getindex, x), syms)...,)) + return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms))) end # `merge` From ed5d9488734401e0ab1aefde9c525edf492b1102 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 17:08:58 +0100 Subject: [PATCH 18/27] remove two-argument impl of merge --- src/abstract_varinfo.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index d74f11ec5..9a6b31da0 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -487,13 +487,13 @@ This is particularly useful when combined with [`subset(varinfo, vns)`](@ref). See docstring of [`subset(varinfo, vns)`](@ref) for examples. """ -function Base.merge(varinfo::AbstractVarInfo, varinfo_others::AbstractVarInfo...) - return merge(Base.merge(varinfo, first(varinfo_others)), Base.tail(varinfo_others)...) -end - -# Avoid `StackoverFlowError` if implementation is missing. -function Base.merge(varinfo::AbstractVarInfo, varinfo_other::AbstractVarInfo) - throw(MethodError(Base.merge, (varinfo, varinfo_other))) +Base.merge(varinfo::AbstractVarInfo) = varinfo +function Base.merge( + varinfo_left::AbstractVarInfo, + varinfo_right::AbstractVarInfo, + varinfo_others::AbstractVarInfo... +) + return merge(Base.merge(varinfo_left, varinfo_right), varinfo_others...) end # Transformations From 00c36cf12bf7461b4f08788174c58fdb6b761a25 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 17:10:43 +0100 Subject: [PATCH 19/27] formatting --- src/abstract_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 9a6b31da0..470040ed0 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -491,7 +491,7 @@ Base.merge(varinfo::AbstractVarInfo) = varinfo function Base.merge( varinfo_left::AbstractVarInfo, varinfo_right::AbstractVarInfo, - varinfo_others::AbstractVarInfo... + varinfo_others::AbstractVarInfo..., ) return merge(Base.merge(varinfo_left, varinfo_right), varinfo_others...) end From cf02816a3304ec399fbfbdcc8a921a62cf7e5f91 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 17:11:17 +0100 Subject: [PATCH 20/27] forgot to add more formatting --- src/simple_varinfo.jl | 7 +++---- src/varinfo.jl | 4 +++- test/varinfo.jl | 6 ++++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index cbd55c86d..3be8d2947 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -430,7 +430,7 @@ function _subset(x::AbstractDict, vns) error( "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " * "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " * - "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not." + "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.", ) end C = ConstructionBase.constructorof(typeof(x)) @@ -442,7 +442,7 @@ function _subset(x::NamedTuple, vns) if any(!==(Setfield.IdentityLens()) ∘ getlens, vns) error( "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * - "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not." + "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", ) end @@ -455,8 +455,7 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) logp = getlogp(varinfo_right) transformation = merge_transformations( - varinfo_left.transformation, - varinfo_right.transformation, + varinfo_left.transformation, varinfo_right.transformation ) return SimpleVarInfo(values, logp, transformation) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 752451087..a222a6e8c 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -311,7 +311,9 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo(metadata, Ref(getlogp(varinfo_right)), Ref(varinfo_right.num_produce[])) + return VarInfo( + metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right)) + ) end @generated function merge_metadata( diff --git a/test/varinfo.jl b/test/varinfo.jl index cfc9f2ebd..8b961082f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -483,8 +483,10 @@ end check_varinfo_keys(varinfo, vns) # Added a `convert` to make the naming of the testsets a bit more readable. - vns_supported = varinfo isa SimpleVarInfo ? vns_supported_simple : vns_supported_standard - @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in vns_supported + vns_supported = + varinfo isa SimpleVarInfo ? vns_supported_simple : vns_supported_standard + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in + vns_supported varinfo_subset = subset(varinfo, vns_subset) # Should now only contain the variables in `vns_subset`. check_varinfo_keys(varinfo_subset, vns_subset) From 7f01ada7515a22a64e7550a443e6065b2803c392 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 17:33:36 +0100 Subject: [PATCH 21/27] removed 2-arg version of merge for abstract varinfo in favour of 3-arg version --- src/abstract_varinfo.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 470040ed0..3f528d544 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -488,12 +488,14 @@ This is particularly useful when combined with [`subset(varinfo, vns)`](@ref). See docstring of [`subset(varinfo, vns)`](@ref) for examples. """ Base.merge(varinfo::AbstractVarInfo) = varinfo +# Define 3-argument version so 2-argument version will error if not implemented. function Base.merge( - varinfo_left::AbstractVarInfo, - varinfo_right::AbstractVarInfo, + varinfo1::AbstractVarInfo, + varinfo2::AbstractVarInfo, + varinfo3::AbstractVarInfo, varinfo_others::AbstractVarInfo..., ) - return merge(Base.merge(varinfo_left, varinfo_right), varinfo_others...) + return merge(Base.merge(varinfo1, varinfo2), varinfo3, varinfo_others...) end # Transformations From 14105e06c10d1d03f8bf39b7dba8f1d374ab7c85 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 22:42:16 +0100 Subject: [PATCH 22/27] allow inclusion of threadsafe varinfo in setup_varinfos --- src/test_utils.jl | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 14da79afa..5b911e77a 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -37,12 +37,17 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal end """ - setup_varinfos(model::Model, example_values::NamedTuple, varnames) + setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false) Return a tuple of instances for different implementations of `AbstractVarInfo` with each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`. + +If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions +of the varinfo instances. """ -function setup_varinfos(model::Model, example_values::NamedTuple, varnames) +function setup_varinfos( + model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false +) # VarInfo vi_untyped = VarInfo() model(vi_untyped) @@ -56,12 +61,18 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames) svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) lp = getlogp(vi_typed) - return map(( + varinfos = map(( vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref )) do vi # Set them all to the same values. DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) end + + if include_threadsafe + varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) + end + + return varinfos end """ From c164d32bf80cdef1718a3a22e77c36ae50261009 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 22:42:43 +0100 Subject: [PATCH 23/27] more tests for thread safe varinfo --- test/test_util.jl | 3 ++- test/varinfo.jl | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/test/test_util.jl b/test/test_util.jl index 31296f79a..7a7028536 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -82,7 +82,8 @@ end Return string representing a short description of `vi`. """ -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = short_varinfo_name(vi.varinfo) +short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = + "threadsafe($(short_varinfo_name(vi.varinfo)))" short_varinfo_name(::TypedVarInfo) = "TypedVarInfo" short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" diff --git a/test/varinfo.jl b/test/varinfo.jl index e25a56200..21b38c9c2 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,5 +1,5 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa SimpleVarInfo{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, # since `keys(varinfo_merged)` only contains `VarName` with `IdentityLens`. # So we just check that the original keys are present. @@ -342,7 +342,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) vns = DynamicPPL.TestUtils.varnames(model) # Set up the different instances of `AbstractVarInfo` with the desired values. - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, example_values, vns; include_threadsafe=true + ) @testset "$(short_varinfo_name(vi))" for vi in varinfos # Just making sure. DynamicPPL.TestUtils.test_values(vi, example_values, vns) @@ -385,9 +387,11 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "mutating=$mutating" for mutating in [false, true] value_true = rand(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, value_true, varnames; include_threadsafe=true + ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa SimpleVarInfo{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: this is broken since we'll end up trying to set # # varinfo[@varname(x[4:5])] = [x[4],] @@ -455,7 +459,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] # `VarInfo` supports, effectively, arbitrary subsetting. - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, model(), vns) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, model(), vns; include_threadsafe=true + ) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) varinfos_simple = filter(Base.Fix2(isa, SimpleVarInfo), varinfos) @@ -511,7 +517,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "merge" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), vns) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, rand(model), vns; include_threadsafe=true + ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @testset "with itself" begin # Merging itself should be a no-op. From 743162a1959546d6d8e76987864fde451e69218d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 22:43:06 +0100 Subject: [PATCH 24/27] bugfixes for link and invlink methods when using thread safe varinfo --- src/threadsafe.jl | 56 +++++++++++++++++++++++++++++++++++++++++++---- src/varinfo.jl | 42 +++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f7ab3fa85..ab504de23 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -84,25 +84,56 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl function link!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return link!!(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) end function invlink!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return invlink!!(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) end function link( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return link(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model) end function invlink( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return invlink(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) +end + +# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. +# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure +# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates +# to define `getlogp(vi)`. +function link!!( + t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) +end + +function invlink!!( + ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return settrans!!( + last(evaluate!!(model, vi, DynamicTransformationContext{true}())), + NoTransformation(), + ) +end + +function link( + t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return link!!(t, deepcopy(vi), spl, model) +end + +function invlink( + t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return invlink!!(t, deepcopy(vi), spl, model) end function maybe_invlink_before_eval!!( @@ -192,3 +223,20 @@ istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn) + +function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) + return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x) +end +function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector) + return Setfield.@set vi.varinfo = unflatten(vi.varinfo, spl, x) +end + +function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) + return Setfield.@set varinfo.varinfo = subset(varinfo.varinfo, vns) +end + +function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo) + return Setfield.@set varinfo_left.varinfo = merge( + varinfo_left.varinfo, varinfo_right.varinfo + ) +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 4e4fd4a3f..a1a154133 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -950,6 +950,17 @@ function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, mod return vi end +function link!!( + t::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Setfield.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) +end + """ link!(vi::VarInfo, spl::Sampler) @@ -1025,6 +1036,17 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, m return vi end +function invlink!!( + ::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Setfield.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) +end + function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do @@ -1129,6 +1151,16 @@ end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) return _link(varinfo, spl) end +function link( + ::DynamicTransformation, + varinfo::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Setfield.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) +end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) @@ -1214,6 +1246,16 @@ function invlink( ) return _invlink(varinfo, spl) end +function invlink( + ::DynamicTransformation, + varinfo::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Setfield.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) +end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) From dc9ad94a57be95abfd9190404a2bb3c97fcf8559 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Oct 2023 22:54:12 +0100 Subject: [PATCH 25/27] attempt at fixing docs --- docs/src/api.md | 2 +- src/abstract_varinfo.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 47a92c07b..a729ee754 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -255,7 +255,7 @@ DynamicPPL.reconstruct #### Utils ```@docs -Base.merge(::VarInfo, ::VarInfo) +Base.merge(::AbstractVarInfo) DynamicPPL.subset DynamicPPL.unflatten DynamicPPL.tonamedtuple diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 3f528d544..0218a1882 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -431,7 +431,7 @@ julia> varinfo_subset2[[@varname(s), @varname(x[2])]] 4.0 ``` -`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) +`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref) ```jldoctest varinfo-subset julia> # Merge the two. From 2f320e6566efd9cdef979f280dfa555011cc091b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Oct 2023 17:38:24 +0100 Subject: [PATCH 26/27] fixed missing test coverage --- src/simple_varinfo.jl | 10 +++++----- src/varinfo.jl | 2 +- test/varinfo.jl | 41 ++++++++++++++++++++++++++++++++++++++--- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3be8d2947..294ac8f58 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -427,11 +427,11 @@ end function _subset(x::AbstractDict, vns) # NOTE: This requires `vns` to be explicitly present in `x`. if any(!Base.Fix1(haskey, x), vns) - error( + throw(ArgumentError( "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " * "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " * "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.", - ) + )) end C = ConstructionBase.constructorof(typeof(x)) return C(vn => x[vn] for vn in vns) @@ -439,11 +439,11 @@ end function _subset(x::NamedTuple, vns) # NOTE: Here we can only handle `vns` that contain the `IdentityLens`. - if any(!==(Setfield.IdentityLens()) ∘ getlens, vns) - error( + if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns) + throw(ArgumentError( "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", - ) + )) end syms = map(getsym, vns) diff --git a/src/varinfo.jl b/src/varinfo.jl index a1a154133..0d5dce7aa 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -429,7 +429,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) else # Just extract the metadata from `metadata_right`. # `vals` - vals_right = getvals(metadata_right, vn) + vals_right = getval(metadata_right, vn) append!(vals, vals_right) # `ranges` r = (offset + 1):(offset + length(vals_right)) diff --git a/test/varinfo.jl b/test/varinfo.jl index 21b38c9c2..045c7f8a0 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -463,7 +463,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) model, model(), vns; include_threadsafe=true ) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter(Base.Fix2(isa, SimpleVarInfo), varinfos) + varinfos_simple = filter( + Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos + ) # `VarInfo` supports subsetting using, basically, arbitrary varnames. vns_supported_standard = [ @@ -493,8 +495,11 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) check_varinfo_keys(varinfo, vns) # Added a `convert` to make the naming of the testsets a bit more readable. - vns_supported = - varinfo isa SimpleVarInfo ? vns_supported_simple : vns_supported_standard + vns_supported = if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple + vns_supported_simple + else + vns_supported_standard + end @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in vns_supported varinfo_subset = subset(varinfo, vns_subset) @@ -512,6 +517,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end end + + # For certain varinfos we should have errors. + # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `IdentityLens`. + varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] + @testset "$(short_varinfo_name(varinfo)): failure cases" begin + @test_throws ArgumentError subset(varinfo, [@varname(s), @varname(m), @varname(x[1])]) + end + # `SimpleVarInfo{<:AbstractDict}` can only handle varnames as they appear in the model. + varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:AbstractDict}), varinfos)] + @testset "$(short_varinfo_name(varinfo)): failure cases" begin + @test_throws ArgumentError subset(varinfo, [@varname(s), @varname(m), @varname(x)]) + end end @testset "merge" begin @@ -530,7 +547,25 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end + @testset "with itself (3-argument version)" begin + # Merging itself should be a no-op. + varinfo_merged = merge(varinfo, varinfo, varinfo) + # Varnames should be unchanged. + check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + @testset "with empty" begin + # Empty is 1st argument. + # Merging with an empty `VarInfo` should be a no-op. + varinfo_merged = merge(empty!!(deepcopy(varinfo)), varinfo) + # Varnames should be unchanged. + check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + + # Empty is 2nd argument. # Merging with an empty `VarInfo` should be a no-op. varinfo_merged = merge(varinfo, empty!!(deepcopy(varinfo))) # Varnames should be unchanged. From d3a9b56920c7f4e6864a43eaa42db00c54499ece Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Oct 2023 17:38:45 +0100 Subject: [PATCH 27/27] formatting --- src/simple_varinfo.jl | 22 +++++++++++++--------- test/varinfo.jl | 12 +++++++++--- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 294ac8f58..400dd93fe 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -427,11 +427,13 @@ end function _subset(x::AbstractDict, vns) # NOTE: This requires `vns` to be explicitly present in `x`. if any(!Base.Fix1(haskey, x), vns) - throw(ArgumentError( - "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " * - "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " * - "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.", - )) + throw( + ArgumentError( + "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " * + "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " * + "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.", + ), + ) end C = ConstructionBase.constructorof(typeof(x)) return C(vn => x[vn] for vn in vns) @@ -440,10 +442,12 @@ end function _subset(x::NamedTuple, vns) # NOTE: Here we can only handle `vns` that contain the `IdentityLens`. if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns) - throw(ArgumentError( - "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * - "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", - )) + throw( + ArgumentError( + "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * + "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", + ), + ) end syms = map(getsym, vns) diff --git a/test/varinfo.jl b/test/varinfo.jl index 045c7f8a0..e306a3df2 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -522,12 +522,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `IdentityLens`. varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset(varinfo, [@varname(s), @varname(m), @varname(x[1])]) + @test_throws ArgumentError subset( + varinfo, [@varname(s), @varname(m), @varname(x[1])] + ) end # `SimpleVarInfo{<:AbstractDict}` can only handle varnames as they appear in the model. - varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:AbstractDict}), varinfos)] + varinfo = varinfos[findfirst( + Base.Fix2(isa, SimpleVarInfo{<:AbstractDict}), varinfos + )] @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset(varinfo, [@varname(s), @varname(m), @varname(x)]) + @test_throws ArgumentError subset( + varinfo, [@varname(s), @varname(m), @varname(x)] + ) end end