Skip to content

Commit

Permalink
Merge branch 'master' into py/cherry-pick-0.28.5
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm authored Oct 18, 2024
2 parents 4fbae44 + 54691bf commit e994196
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 25 deletions.
17 changes: 6 additions & 11 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,22 +439,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
end

function _subset(x::AbstractDict, vns)
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
vns_present = collect(keys(x))
vns_found = mapreduce(vcat, vns) do vn
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
return filter(Base.Fix1(subsumes, vn), vns_present)
end

# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
C = ConstructionBase.constructorof(typeof(x))
if isempty(vns_found)
throw(
ArgumentError(
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
),
)
return C()
else
return C(vn => x[vn] for vn in vns_found)
end
C = ConstructionBase.constructorof(typeof(x))
return C(vn => x[vn] for vn in vns_found)
end

function _subset(x::NamedTuple, vns)
Expand Down
52 changes: 39 additions & 13 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,20 +368,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
)
end

function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
# For each `vn` in `vns`, get the variables subsumed by `vn`.
vns = mapreduce(vcat, vns_given) do vn
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
filter(Base.Fix1(subsumes, vn), metadata.vns)
end
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
indices = Dict(vn => i for (i, vn) in enumerate(vns))
indices = if isempty(vns)
Dict{VarName,Int}()
else
Dict(vn => i for (i, vn) in enumerate(vns))
end
# Construct new `vals` and `ranges`.
vals_original = metadata.vals
ranges_original = metadata.ranges
# Allocate the new `vals`. and `ranges`.
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
ranges = similar(ranges_original)
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
ranges = similar(ranges_original, length(vns))
# The new range `r` for `vns[i]` is offset by `offset` and
# has the same length as the original range `r_original`.
# The new `indices` (from above) ensures ordering according to `vns`.
Expand Down Expand Up @@ -415,7 +419,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
ranges,
vals,
metadata.dists[indices_for_vns],
metadata.gids,
metadata.gids[indices_for_vns],
metadata.orders[indices_for_vns],
flags,
)
Expand Down Expand Up @@ -490,7 +494,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
ranges = Vector{UnitRange{Int}}()
vals = T[]
dists = D[]
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
gids = Set{Selector}[]
orders = Int[]
flags = Dict{String,BitVector}()
# Initialize the `flags`.
Expand Down Expand Up @@ -520,6 +524,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
dist_right = getdist(metadata_right, vn)
# Give precedence to `metadata_right`.
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`: giving precedence to `metadata_right`
push!(orders, getorder(metadata_right, vn))
# `flags`
Expand All @@ -539,6 +545,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
# `dists`
dist_left = getdist(metadata_left, vn)
push!(dists, dist_left)
gid = metadata_left.gids[getidx(metadata_left, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_left, vn))
# `flags`
Expand All @@ -557,6 +565,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
# `dists`
dist_right = getdist(metadata_right, vn)
push!(dists, dist_right)
gid = metadata_right.gids[getidx(metadata_right, vn)]
push!(gids, gid)
# `orders`
push!(orders, getorder(metadata_right, vn))
# `flags`
Expand Down Expand Up @@ -1826,14 +1836,31 @@ function BangBang.push!!(
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
)
if vi isa UntypedVarInfo
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
elseif vi isa TypedVarInfo
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
end

sym = getsym(vn)
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
# The NamedTuple doesn't have an entry for this variable, let's add one.
val = tovec(r)
md = Metadata(
Dict(vn => 1),
[vn],
[1:length(val)],
val,
[dist],
[gidset],
[get_num_produce(vi)],
Dict{String,BitVector}("trans" => [false], "del" => [false]),
)
vi = Accessors.@set vi.metadata[sym] = md
else
meta = getmetadata(vi, vn)
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
end

meta = getmetadata(vi, vn)
push!(meta, vn, r, dist, gidset, get_num_produce(vi))

return vi
end

Expand Down Expand Up @@ -1864,7 +1891,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
push!(meta.orders, num_produce)
push!(meta.flags["del"], false)
push!(meta.flags["trans"], false)

return meta
end

Expand Down
2 changes: 1 addition & 1 deletion test/turing/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
Distributions = "0.25"
DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29"
DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30"
HypothesisTests = "0.11"
MCMCChains = "6"
ReverseDiff = "1.15"
Expand Down
32 changes: 32 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
test_varinfo!(vi)
test_varinfo!(empty!!(TypedVarInfo(vi)))
end

@testset "push!! to TypedVarInfo" begin
vn_x = @varname x
vn_y = @varname y
untyped_vi = VarInfo()
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
typed_vi = TypedVarInfo(untyped_vi)
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
@test typed_vi[vn_x] == 1.0
@test typed_vi[vn_y] == 2.0
end

@testset "setgid!" begin
vi = VarInfo(DynamicPPL.Metadata())
meta = vi.metadata
Expand Down Expand Up @@ -566,6 +578,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
else
vns_supported_standard
end

@testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in
vns_supported
varinfo_subset = subset(varinfo, VarName[])
@test isempty(varinfo_subset)
end

@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
vns_supported
varinfo_subset = subset(varinfo, vns_subset)
Expand Down Expand Up @@ -694,6 +713,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)]
@test DynamicPPL.istrans(varinfo_merged, @varname(x))
end

# The below used to error, testing to avoid regression.
@testset "merge gids" begin
gidset_left = Set([Selector(1)])
vi_left = VarInfo()
vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left)
gidset_right = Set([Selector(2)])
vi_right = VarInfo()
vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right)
varinfo_merged = merge(vi_left, vi_right)
@test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left
@test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right
end
end

@testset "VarInfo with selectors" begin
Expand Down

0 comments on commit e994196

Please sign in to comment.