diff --git a/docs/src/api.md b/docs/src/api.md
index ddd119816..a729ee754 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -255,6 +255,8 @@ DynamicPPL.reconstruct
 #### Utils
 
 ```@docs
+Base.merge(::AbstractVarInfo)
+DynamicPPL.subset
 DynamicPPL.unflatten
 DynamicPPL.tonamedtuple
 DynamicPPL.varname_leaves
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 8e3a778ad..9853d8140 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -48,6 +48,7 @@ export AbstractVarInfo,
     SimpleVarInfo,
     push!!,
     empty!!,
+    subset,
     getlogp,
     setlogp!!,
     acclogp!!,
diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl
index de1efe4c1..0218a1882 100644
--- a/src/abstract_varinfo.jl
+++ b/src/abstract_varinfo.jl
@@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation
     bijector::F
 end
 
+"""
+    merge_transformations(transformation_left, transformation_right)
+
+Merge two transformations.
+
+The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref).
+"""
+function merge_transformations(::NoTransformation, ::NoTransformation)
+    return NoTransformation()
+end
+function merge_transformations(::DynamicTransformation, ::DynamicTransformation)
+    return DynamicTransformation()
+end
+function merge_transformations(left::StaticTransformation, right::StaticTransformation)
+    return StaticTransformation(merge_bijectors(left.bijector, right.bijector))
+end
+
+function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform)
+    return Bijectors.NamedTransform(merge(left.bs, right.bs))
+end
+
 """
     default_transformation(model::Model[, vi::AbstractVarInfo])
 
@@ -337,6 +358,146 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
     return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
 end
 
+# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
+# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
+# might result in a `Vector{Any}`.
+"""
+    subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})
+
+Subset a `varinfo` to only contain the variables `vns`.
+
+!!! warning
+    The ordering of the variables in the resulting `varinfo` is _not_
+    guaranteed to follow the ordering of the variables in `varinfo`.
+    Hence care must be taken, in particular when used in conjunction with
+    other methods which use the vector-representation of the `varinfo`,
+    e.g. `getindex(varinfo, sampler)`.
+
+# Examples
+```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
+julia> @model function demo()
+           s ~ InverseGamma(2, 3)
+           m ~ Normal(0, sqrt(s))
+           x = Vector{Float64}(undef, 2)
+           x[1] ~ Normal(m, sqrt(s))
+           x[2] ~ Normal(m, sqrt(s))
+       end
+demo (generic function with 2 methods)
+
+julia> model = demo();
+
+julia> varinfo = VarInfo(model);
+
+julia> keys(varinfo)
+4-element Vector{VarName}:
+ s
+ m
+ x[1]
+ x[2]
+
+julia> for (i, vn) in enumerate(keys(varinfo))
+           varinfo[vn] = i
+       end
+
+julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
+4-element Vector{Float64}:
+ 1.0
+ 2.0
+ 3.0
+ 4.0
+
+julia> # Extract one with only `m`.
+       varinfo_subset1 = subset(varinfo, [@varname(m),]);
+
+
+julia> keys(varinfo_subset1)
+1-element Vector{VarName{:m, Setfield.IdentityLens}}:
+ m
+
+julia> varinfo_subset1[@varname(m)]
+2.0
+
+julia> # Extract one with both `s` and `x[2]`.
+       varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);
+
+julia> keys(varinfo_subset2)
+2-element Vector{VarName}:
+ s
+ x[2]
+
+julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
+2-element Vector{Float64}:
+ 1.0
+ 4.0
+```
+
+`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref)
+
+```jldoctest varinfo-subset
+julia> # Merge the two.
+       varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);
+
+julia> keys(varinfo_subset_merged)
+3-element Vector{VarName}:
+ m
+ s
+ x[2]
+
+julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
+3-element Vector{Float64}:
+ 1.0
+ 2.0
+ 4.0
+
+julia> # Merge the two with the original.
+       varinfo_merged = merge(varinfo, varinfo_subset_merged);
+
+julia> keys(varinfo_merged)
+4-element Vector{VarName}:
+ s
+ m
+ x[1]
+ x[2]
+
+julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
+4-element Vector{Float64}:
+ 1.0
+ 2.0
+ 3.0
+ 4.0
+```
+
+# Notes
+
+## Type-stability
+
+!!! warning
+    This function is only type-stable when `vns` contains only varnames
+    with the same symbol. For example, `[@varname(m[1]), @varname(m[2])]` will
+    be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
+"""
+function subset end
+
+"""
+    merge(varinfo, other_varinfos...)
+
+Merge varinfos into one, giving precedence to the right-most varinfo when sensible.
+
+This is particularly useful when combined with [`subset(varinfo, vns)`](@ref).
+
+See docstring of [`subset(varinfo, vns)`](@ref) for examples.
+"""
+Base.merge(varinfo::AbstractVarInfo) = varinfo
+# Define 3-argument version so 2-argument version will error if not implemented.
+function Base.merge(
+    varinfo1::AbstractVarInfo,
+    varinfo2::AbstractVarInfo,
+    varinfo3::AbstractVarInfo,
+    varinfo_others::AbstractVarInfo...,
+)
+    return merge(Base.merge(varinfo1, varinfo2), varinfo3, varinfo_others...)
+end
+
 # Transformations
 """
     istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
index a9d38fb07..400dd93fe 100644
--- a/src/simple_varinfo.jl
+++ b/src/simple_varinfo.jl
@@ -419,6 +419,51 @@ function Base.eltype(
     return V
 end
 
+# `subset`
+function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
+    return Setfield.@set varinfo.values = _subset(varinfo.values, vns)
+end
+
+function _subset(x::AbstractDict, vns)
+    # NOTE: This requires `vns` to be explicitly present in `x`.
+    if any(!Base.Fix1(haskey, x), vns)
+        throw(
+            ArgumentError(
+                "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
+                "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
+                "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.",
+            ),
+        )
+    end
+    C = ConstructionBase.constructorof(typeof(x))
+    return C(vn => x[vn] for vn in vns)
+end
+
+function _subset(x::NamedTuple, vns)
+    # NOTE: Here we can only handle `vns` that contain the `IdentityLens`.
+    if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns)
+        throw(
+            ArgumentError(
+                "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " *
+                "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
+            ),
+        )
+    end
+
+    syms = map(getsym, vns)
+    return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms)))
+end
+
+# `merge`
+function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
+    values = merge(varinfo_left.values, varinfo_right.values)
+    logp = getlogp(varinfo_right)
+    transformation = merge_transformations(
+        varinfo_left.transformation, varinfo_right.transformation
+    )
+    return SimpleVarInfo(values, logp, transformation)
+end
+
 # Context implementations
 # NOTE: Evaluations, i.e. those without `rng` are shared with other
 # implementations of `AbstractVarInfo`.
diff --git a/src/test_utils.jl b/src/test_utils.jl index 14da79afa..5b911e77a 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -37,12 +37,17 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal end """ - setup_varinfos(model::Model, example_values::NamedTuple, varnames) + setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false) Return a tuple of instances for different implementations of `AbstractVarInfo` with each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`. + +If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions +of the varinfo instances. """ -function setup_varinfos(model::Model, example_values::NamedTuple, varnames) +function setup_varinfos( + model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false +) # VarInfo vi_untyped = VarInfo() model(vi_untyped) @@ -56,12 +61,18 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames) svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) lp = getlogp(vi_typed) - return map(( + varinfos = map(( vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref )) do vi # Set them all to the same values. DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) end + + if include_threadsafe + varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) 
+ end + + return varinfos end """ diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f7ab3fa85..ab504de23 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -84,25 +84,56 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl function link!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return link!!(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) end function invlink!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return invlink!!(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) end function link( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return link(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model) end function invlink( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return invlink(t, vi.varinfo, spl, model) + return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) +end + +# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. +# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure +# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates +# to define `getlogp(vi)`. 
+function link!!(
+    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
+)
+    return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
+end
+
+function invlink!!(
+    ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
+)
+    return settrans!!(
+        last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
+        NoTransformation(),
+    )
+end
+
+function link(
+    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
+)
+    return link!!(t, deepcopy(vi), spl, model)
+end
+
+function invlink(
+    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
+)
+    return invlink!!(t, deepcopy(vi), spl, model)
 end
 
 function maybe_invlink_before_eval!!(
@@ -192,3 +223,20 @@ istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
 istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
 
 getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
+
+function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
+    return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x)
+end
+function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector)
+    return Setfield.@set vi.varinfo = unflatten(vi.varinfo, spl, x)
+end
+
+function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
+    return Setfield.@set varinfo.varinfo = subset(varinfo.varinfo, vns)
+end
+
+function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo)
+    return Setfield.@set varinfo_left.varinfo = merge(
+        varinfo_left.varinfo, varinfo_right.varinfo
+    )
+end
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 3e7dc119f..0d5dce7aa 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -236,6 +236,220 @@ else
     _tail(nt::NamedTuple) = Base.tail(nt)
 end
 
+function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName})
+    metadata = subset(varinfo.metadata, vns)
+    return VarInfo(metadata, Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
+end
+
+function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym}
+    # All `vns` share `sym`: extract only that field. Fresh `Ref`s avoid aliasing `varinfo`.
+    metadata = NamedTuple{(sym,)}((subset(getfield(varinfo.metadata, sym), vns),))
+    return VarInfo(metadata, Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
+end
+
+function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
+    syms = Tuple(unique(map(getsym, vns)))
+    metadatas = map(syms) do sym
+        subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns))
+    end
+    metadata = NamedTuple{syms}(metadatas)
+    return VarInfo(metadata, Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
+end
+
+function subset(metadata::Metadata, vns::AbstractVector{<:VarName})
+    # TODO: Should we error if `vns` contains a variable that is not in `metadata`?
+    indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
+    indices = Dict(vn => i for (i, vn) in enumerate(vns))
+    # Construct new `vals` and `ranges`.
+    vals_original = metadata.vals
+    ranges_original = metadata.ranges
+    # Allocate the new `vals` and `ranges`, both sized for the subset only.
+    vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
+    ranges = similar(ranges_original, length(vns))
+    # The new range `r` for `vns[i]` is offset by `offset` and
+    # has the same length as the original range `r_original`.
+    # The new `indices` (from above) ensures ordering according to `vns`.
+    # NOTE: This means that the order of the variables in `vns` defines the order
+    # in the resulting `varinfo`! This can have performance implications, e.g.
+    # if in the model we have something like
+    #
+    #     for i = 1:N
+    #         x[i] ~ Normal()
+    #     end
+    #
+    # and then we do
+    #
+    #    subset(varinfo, [@varname(x[i]) for i in shuffle(keys(varinfo))])
+    #
+    # the resulting `varinfo` will have `vals` ordered differently from the
+    # original `varinfo`, which can have performance implications.
+    offset = 0
+    for (idx, idx_original) in enumerate(indices_for_vns)
+        r_original = ranges_original[idx_original]
+        r = (offset + 1):(offset + length(r_original))
+        vals[r] = vals_original[r_original]
+        ranges[idx] = r
+        offset = r[end]
+    end
+
+    flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags)
+    return Metadata(
+        indices,
+        vns,
+        ranges,
+        vals,
+        metadata.dists[indices_for_vns],
+        metadata.gids,
+        metadata.orders[indices_for_vns],
+        flags,
+    )
+end
+
+function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
+    return _merge(varinfo_left, varinfo_right)
+end
+
+function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
+    metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
+    return VarInfo(
+        metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right))
+    )
+end
+
+@generated function merge_metadata(
+    metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
+) where {names_left,names_right}
+    names = Expr(:tuple)
+    vals = Expr(:tuple)
+    # Loop over `names_left` first because we want to preserve the order of the variables.
+    for sym in names_left
+        push!(names.args, QuoteNode(sym))
+        if sym in names_right
+            push!(vals.args, :(merge_metadata(metadata_left.$sym, metadata_right.$sym)))
+        else
+            push!(vals.args, :(metadata_left.$sym))
+        end
+    end
+    # Loop over remaining variables in `names_right`.
+    names_right_only = filter(∉(names_left), names_right)
+    for sym in names_right_only
+        push!(names.args, QuoteNode(sym))
+        push!(vals.args, :(metadata_right.$sym))
+    end
+
+    return :(NamedTuple{$names}($vals))
+end
+
+function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
+    # Extract the varnames.
+    vns_left = metadata_left.vns
+    vns_right = metadata_right.vns
+    vns_both = union(vns_left, vns_right)
+
+    # Determine `eltype` of `vals`.
+ T_left = eltype(metadata_left.vals) + T_right = eltype(metadata_right.vals) + T = promote_type(T_left, T_right) + # TODO: Is this necessary? + if !(T <: Real) + T = Real + end + + # Determine `eltype` of `dists`. + D_left = eltype(metadata_left.dists) + D_right = eltype(metadata_right.dists) + D = promote_type(D_left, D_right) + # TODO: Is this necessary? + if !(D <: Distribution) + D = Distribution + end + + # Initialize required fields for `metadata`. + vns = VarName[] + idcs = Dict{VarName,Int}() + ranges = Vector{UnitRange{Int}}() + vals = T[] + dists = D[] + gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + orders = Int[] + flags = Dict{String,BitVector}() + # Initialize the `flags`. + for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) + flags[k] = BitVector() + end + + # Range offset. + offset = 0 + + for (idx, vn) in enumerate(vns_both) + # `idcs` + idcs[vn] = idx + # `vns` + push!(vns, vn) + if vn in vns_left && vn in vns_right + # `vals`: only valid if they're the same length. + vals_left = getval(metadata_left, vn) + vals_right = getval(metadata_right, vn) + @assert length(vals_left) == length(vals_right) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists`: only valid if they're the same. + dists_left = getdist(metadata_left, vn) + dists_right = getdist(metadata_right, vn) + @assert dists_left == dists_right + push!(dists, dists_left) + # `orders`: giving precedence to `metadata_right` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + # Using `metadata_right`; should we? + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + elseif vn in vns_left + # Just extract the metadata from `metadata_left`. 
+ # `vals` + vals_left = getval(metadata_left, vn) + append!(vals, vals_left) + # `ranges` + r = (offset + 1):(offset + length(vals_left)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_left = getdist(metadata_left, vn) + push!(dists, dists_left) + # `orders` + push!(orders, getorder(metadata_left, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_left, vn, k)) + end + else + # Just extract the metadata from `metadata_right`. + # `vals` + vals_right = getval(metadata_right, vn) + append!(vals, vals_right) + # `ranges` + r = (offset + 1):(offset + length(vals_right)) + push!(ranges, r) + offset = r[end] + # `dists` + dists_right = getdist(metadata_right, vn) + push!(dists, dists_right) + # `orders` + push!(orders, getorder(metadata_right, vn)) + # `flags` + for k in keys(flags) + push!(flags[k], is_flagged(metadata_right, vn, k)) + end + end + end + + return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags) +end + const VarView = Union{Int,UnitRange,Vector{Int}} """ @@ -736,6 +950,17 @@ function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, mod return vi end +function link!!( + t::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Setfield.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) +end + """ link!(vi::VarInfo, spl::Sampler) @@ -811,6 +1036,17 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, m return vi end +function invlink!!( + ::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. 
+ return Setfield.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) +end + function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do @@ -915,6 +1151,16 @@ end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) return _link(varinfo, spl) end +function link( + ::DynamicTransformation, + varinfo::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Setfield.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) +end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) @@ -1000,6 +1246,16 @@ function invlink( ) return _invlink(varinfo, spl) end +function invlink( + ::DynamicTransformation, + varinfo::ThreadSafeVarInfo{<:VarInfo}, + spl::AbstractSampler, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. + return Setfield.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) +end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) @@ -1374,6 +1630,15 @@ function setorder!(vi::VarInfo, vn::VarName, index::Int) return vi end +""" + getorder(vi::VarInfo, vn::VarName) + +Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements +run before sampling `vn`. 
+""" +getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) +getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] + ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1384,7 +1649,10 @@ end Check whether `vn` has a true value for `flag` in `vi`. """ function is_flagged(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] + return is_flagged(getmetadata(vi, vn), vn, flag) +end +function is_flagged(metadata::Metadata, vn::VarName, flag::String) + return metadata.flags[flag][getidx(metadata, vn)] end """ diff --git a/test/test_util.jl b/test/test_util.jl index 31296f79a..7a7028536 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -82,7 +82,8 @@ end Return string representing a short description of `vi`. """ -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = short_varinfo_name(vi.varinfo) +short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = + "threadsafe($(short_varinfo_name(vi.varinfo)))" short_varinfo_name(::TypedVarInfo) = "TypedVarInfo" short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" diff --git a/test/varinfo.jl b/test/varinfo.jl index 7f96c071e..e306a3df2 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,19 @@ +function check_varinfo_keys(varinfo, vns) + if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} + # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, + # since `keys(varinfo_merged)` only contains `VarName` with `IdentityLens`. + # So we just check that the original keys are present. + for vn in vns + # Should have all the original keys. + @test haskey(varinfo, vn) + end + else + vns_varinfo = keys(varinfo) + # Should be equivalent. 
+ @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) + end +end + # A simple "algorithm" which only has `s` variables in its space. struct MySAlg end DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @@ -326,7 +342,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) vns = DynamicPPL.TestUtils.varnames(model) # Set up the different instances of `AbstractVarInfo` with the desired values. - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, example_values, vns; include_threadsafe=true + ) @testset "$(short_varinfo_name(vi))" for vi in varinfos # Just making sure. DynamicPPL.TestUtils.test_values(vi, example_values, vns) @@ -369,9 +387,11 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "mutating=$mutating" for mutating in [false, true] value_true = rand(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, value_true, varnames; include_threadsafe=true + ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa SimpleVarInfo{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: this is broken since we'll end up trying to set # # varinfo[@varname(x[4:5])] = [x[4],] @@ -426,6 +446,153 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end + @testset "subset" begin + @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x = TV(undef, 2) + x[1] ~ Normal(m, sqrt(s)) + x[2] ~ Normal(m, sqrt(s)) + return (; s, m, x) + end + model = demo_subsetting_varinfo() + vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] + + # `VarInfo` supports, effectively, arbitrary subsetting. 
+ varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, model(), vns; include_threadsafe=true + ) + varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) + varinfos_simple = filter( + Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos + ) + + # `VarInfo` supports subsetting using, basically, arbitrary varnames. + vns_supported_standard = [ + [@varname(s)], + [@varname(m)], + [@varname(x[1])], + [@varname(x[2])], + [@varname(s), @varname(m)], + [@varname(s), @varname(x[1])], + [@varname(s), @varname(x[2])], + [@varname(m), @varname(x[1])], + [@varname(m), @varname(x[2])], + [@varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1])], + [@varname(s), @varname(m), @varname(x[2])], + [@varname(s), @varname(x[1]), @varname(x[2])], + [@varname(m), @varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + ] + + # `SimpleVarInfo` only supports subsetting using the varnames as they appear + # in the model. + vns_supported_simple = filter(∈(vns), vns_supported_standard) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos_standard + # All variables. + check_varinfo_keys(varinfo, vns) + + # Added a `convert` to make the naming of the testsets a bit more readable. + vns_supported = if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple + vns_supported_simple + else + vns_supported_standard + end + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in + vns_supported + varinfo_subset = subset(varinfo, vns_subset) + # Should now only contain the variables in `vns_subset`. + check_varinfo_keys(varinfo_subset, vns_subset) + # Values should be the same. + @test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset] + + # `merge` with the original. + varinfo_merged = merge(varinfo, varinfo_subset) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. 
+ @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + end + + # For certain varinfos we should have errors. + # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `IdentityLens`. + varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] + @testset "$(short_varinfo_name(varinfo)): failure cases" begin + @test_throws ArgumentError subset( + varinfo, [@varname(s), @varname(m), @varname(x[1])] + ) + end + # `SimpleVarInfo{<:AbstractDict}` can only handle varnames as they appear in the model. + varinfo = varinfos[findfirst( + Base.Fix2(isa, SimpleVarInfo{<:AbstractDict}), varinfos + )] + @testset "$(short_varinfo_name(varinfo)): failure cases" begin + @test_throws ArgumentError subset( + varinfo, [@varname(s), @varname(m), @varname(x)] + ) + end + end + + @testset "merge" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, rand(model), vns; include_threadsafe=true + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @testset "with itself" begin + # Merging itself should be a no-op. + varinfo_merged = merge(varinfo, varinfo) + # Varnames should be unchanged. + check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with itself (3-argument version)" begin + # Merging itself should be a no-op. + varinfo_merged = merge(varinfo, varinfo, varinfo) + # Varnames should be unchanged. + check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with empty" begin + # Empty is 1st argument. + # Merging with an empty `VarInfo` should be a no-op. + varinfo_merged = merge(empty!!(deepcopy(varinfo)), varinfo) + # Varnames should be unchanged. 
+ check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + + # Empty is 2nd argument. + # Merging with an empty `VarInfo` should be a no-op. + varinfo_merged = merge(varinfo, empty!!(deepcopy(varinfo))) + # Varnames should be unchanged. + check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end + + @testset "with different value" begin + x = DynamicPPL.TestUtils.rand(model) + varinfo_changed = DynamicPPL.TestUtils.update_values!!( + deepcopy(varinfo), x, vns + ) + # After `merge`, we should have the same values as `x`. + varinfo_merged = merge(varinfo, varinfo_changed) + DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) + end + end + end + end + @testset "VarInfo with selectors" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS varinfo = VarInfo(model)