Skip to content

Commit

Permalink
Allowing pushing new symbols to TypedVarInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Oct 14, 2024
1 parent bd4baf1 commit d804ef1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1832,14 +1832,31 @@ function BangBang.push!!(
vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
)
if vi isa UntypedVarInfo
@assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
@assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset"
elseif vi isa TypedVarInfo
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
@assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset"
end

sym = getsym(vn)
if vi isa TypedVarInfo && ~haskey(vi.metadata, sym)
# The NamedTuple doesn't have an entry for this variable, let's add one.
val = tovec(r)
md = Metadata(
Dict(vn => 1),
[vn],
[1:length(val)],
val,
[dist],
[gidset],
[get_num_produce(vi)],
Dict{String,BitVector}("trans" => [false], "del" => [false]),
)
vi = Accessors.@set vi.metadata[sym] = md
else
meta = getmetadata(vi, vn)
push!(meta, vn, r, dist, gidset, get_num_produce(vi))
end

meta = getmetadata(vi, vn)
push!(meta, vn, r, dist, gidset, get_num_produce(vi))

return vi
end

Expand Down Expand Up @@ -1870,7 +1887,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce)
push!(meta.orders, num_produce)
push!(meta.flags["del"], false)
push!(meta.flags["trans"], false)

return meta
end

Expand Down
12 changes: 12 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
test_varinfo!(vi)
test_varinfo!(empty!!(TypedVarInfo(vi)))
end

@testset "push!! to TypedVarInfo" begin
vn_x = @varname x
vn_y = @varname y
untyped_vi = VarInfo()
untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector())
typed_vi = TypedVarInfo(untyped_vi)
typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector())
@test typed_vi[vn_x] == 1.0
@test typed_vi[vn_y] == 2.0
end

@testset "setgid!" begin
vi = VarInfo(DynamicPPL.Metadata())
meta = vi.metadata
Expand Down

0 comments on commit d804ef1

Please sign in to comment.