subset and merge for VarInfo (clean version) #544

Merged
merged 28 commits into from
Oct 19, 2023
Changes from 11 commits
Commits
028a81a
added `subset` which can extract a subset of the varinfo
torfjelde Oct 8, 2023
caa6e25
added testing of `subset` for `VarInfo`
torfjelde Oct 8, 2023
cac7fa8
formatting
torfjelde Oct 8, 2023
5e41c4f
added implementation of `merge` for `VarInfo` and tests for it
torfjelde Oct 8, 2023
d5a2631
more tests
torfjelde Oct 8, 2023
0ade696
formatting
torfjelde Oct 8, 2023
db21844
improved merge_metadata for NamedTuple inputs
torfjelde Oct 9, 2023
1dbca4c
added proper handling of the `vals` in `subset`
torfjelde Oct 9, 2023
b67288f
added docs for `subset` and `merge`
torfjelde Oct 9, 2023
e43029e
added `subset` and `merge` to documentation
torfjelde Oct 9, 2023
cd4033d
formatting
torfjelde Oct 9, 2023
8f47dfe
made merge and subset part of the AbstractVarInfo interface
torfjelde Oct 13, 2023
aba9008
added implementations `subset` and `merge` for `SimpleVarInfo`
torfjelde Oct 13, 2023
3b621ae
follow standard merge semantics where the right one takes precedence
torfjelde Oct 13, 2023
2c2c90b
added proper testing of merge and subset for SimpleVarInfo too
torfjelde Oct 13, 2023
5c1ece3
forgotten inclusion in previous commit
torfjelde Oct 13, 2023
cfff96c
Update src/simple_varinfo.jl
torfjelde Oct 13, 2023
ed5d948
remove two-argument impl of merge
torfjelde Oct 13, 2023
00c36cf
formatting
torfjelde Oct 13, 2023
cf02816
forgot to add more formatting
torfjelde Oct 13, 2023
d02cb61
Merge branch 'master' into torfjelde/subset-and-merge
torfjelde Oct 13, 2023
7f01ada
removed 2-arg version of merge for abstract varinfo in favour of 3-ar…
torfjelde Oct 13, 2023
14105e0
allow inclusion of threadsafe varinfo in setup_varinfos
torfjelde Oct 13, 2023
c164d32
more tests for thread safe varinfo
torfjelde Oct 13, 2023
743162a
bugfixes for link and invlink methods when using thread safe varinfo
torfjelde Oct 13, 2023
dc9ad94
attempt at fixing docs
torfjelde Oct 13, 2023
2f320e6
fixed missing test coverage
torfjelde Oct 14, 2023
d3a9b56
formatting
torfjelde Oct 14, 2023
2 changes: 2 additions & 0 deletions docs/src/api.md
@@ -255,6 +255,8 @@ DynamicPPL.reconstruct
#### Utils

```@docs
Base.merge(::VarInfo, ::VarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
@@ -47,6 +47,7 @@ export AbstractVarInfo,
SimpleVarInfo,
push!!,
empty!!,
subset,
getlogp,
setlogp!!,
acclogp!!,
352 changes: 351 additions & 1 deletion src/varinfo.jl
@@ -236,6 +236,344 @@
_tail(nt::NamedTuple) = Base.tail(nt)
end

# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
# might result in a `Vector{Any}`.
"""
subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})

Subset a `varinfo` to only contain the variables `vns`.

!!! warning
The ordering of the variables in the resulting `varinfo` will _not_
necessarily follow the ordering of the variables in `varinfo`.
Hence care must be taken, in particular when used in conjunction with
other methods which use the vector representation of `varinfo`, e.g.
`getindex(varinfo, sampler)`.

# Examples
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
julia> @model function demo()
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
x = Vector{Float64}(undef, 2)
x[1] ~ Normal(m, sqrt(s))
x[2] ~ Normal(m, sqrt(s))
end
demo (generic function with 2 methods)

julia> model = demo();

julia> varinfo = VarInfo(model);

julia> keys(varinfo)
4-element Vector{VarName}:
s
m
x[1]
x[2]

julia> for (i, vn) in enumerate(keys(varinfo))
varinfo[vn] = i
end

julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0

julia> # Extract one with only `m`.
varinfo_subset1 = subset(varinfo, [@varname(m),]);


julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
m

julia> varinfo_subset1[@varname(m)]
2.0

julia> # Extract one with both `s` and `x[2]`.
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);

julia> keys(varinfo_subset2)
2-element Vector{VarName}:
s
x[2]

julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
1.0
4.0
```

`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref):

```jldoctest varinfo-subset
julia> # Merge the two.
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);

julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
m
s
x[2]

julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
1.0
2.0
4.0

julia> # Merge the two with the original.
varinfo_merged = merge(varinfo, varinfo_subset_merged);

julia> keys(varinfo_merged)
4-element Vector{VarName}:
s
m
x[1]
x[2]

julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0
```

# Notes

## Type-stability

!!! warning
This function is only type-stable when `vns` contains only varnames
with the same symbol. For example, `[@varname(m[1]), @varname(m[2])]` will
be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
"""
function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName})
metadata = subset(varinfo.metadata, vns)
return VarInfo(metadata, varinfo.logp, varinfo.num_produce)
end

function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym}
# If all the variables are using the same symbol, then we can just extract that field from the metadata.
metadata = subset(getfield(varinfo.metadata, sym), vns)
return VarInfo(NamedTuple{(sym,)}(tuple(metadata)), varinfo.logp, varinfo.num_produce)
end

function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
syms = Tuple(unique(map(getsym, vns)))
metadatas = map(syms) do sym
subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns))
end

return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce)
end

function subset(metadata::Metadata, vns::AbstractVector{<:VarName})
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
Member
At least a warning?

indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
indices = Dict(vn => i for (i, vn) in enumerate(vns))
# Construct new `vals` and `ranges`.
vals_original = metadata.vals
ranges_original = metadata.ranges
# Allocate the new `vals` and `ranges`.
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
ranges = similar(ranges_original)
# The new range `r` for `vns[i]` is offset by `offset` and
# has the same length as the original range `r_original`.
# The new `indices` (from above) ensures ordering according to `vns`.
# NOTE: This means that the order of the variables in `vns` defines the order
# in the resulting `varinfo`! For example, if in the model we have something like
#
# for i = 1:N
# x[i] ~ Normal()
# end
#
# and then we do
#
# subset(varinfo, [@varname(x[i]) for i in shuffle(1:N)])
#
# the resulting `varinfo` will have `vals` ordered differently from the
# original `varinfo`, which can have performance implications.
offset = 0
for (idx, idx_original) in enumerate(indices_for_vns)
r_original = ranges_original[idx_original]
r = (offset + 1):(offset + length(r_original))
vals[r] = vals_original[r_original]
ranges[idx] = r
offset = r[end]
end

flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags)
return Metadata(
indices,
vns,
ranges,
vals,
metadata.dists[indices_for_vns],
metadata.gids,
metadata.orders[indices_for_vns],
flags,
)
end
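The re-packing loop in `subset(metadata, vns)` can be illustrated without DynamicPPL. The following is a minimal sketch in plain Julia, assuming a flat value vector plus one range per variable; the function name `repack` and its arguments are hypothetical stand-ins, not part of the PR. It shows how the order of the selection, rather than the original order, determines the layout of the new value vector.

```julia
# Hypothetical stand-in for the `vals`/`ranges` layout; this is not DynamicPPL's `Metadata`.
function repack(vals::Vector{T}, ranges::Vector{UnitRange{Int}}, selected::Vector{Int}) where {T}
    # Total number of values covered by the selected ranges.
    n = sum(i -> length(ranges[i]), selected; init=0)
    new_vals = Vector{T}(undef, n)
    new_ranges = Vector{UnitRange{Int}}(undef, length(selected))
    offset = 0
    for (idx, idx_original) in enumerate(selected)
        r_original = ranges[idx_original]
        # The new range has the same length but starts right after the previous one.
        r = (offset + 1):(offset + length(r_original))
        new_vals[r] = vals[r_original]
        new_ranges[idx] = r
        offset = last(r)
    end
    return new_vals, new_ranges
end

vals = [1.0, 2.0, 3.0, 4.0, 5.0]
ranges = [1:1, 2:3, 4:5]

# Select the third variable first, then the first one.
new_vals, new_ranges = repack(vals, ranges, [3, 1])
```

Note that this sketch sizes `new_ranges` by the number of selected variables, which is a choice made here for the illustration.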

"""
merge(varinfo_left::VarInfo, varinfo_right::VarInfo)

Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable.
Member Author
I will add some more docs for this a bit later today

"""
function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
return _merge(varinfo_left, varinfo_right)
end

function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
lp = getlogp(varinfo_left) + getlogp(varinfo_right)
# TODO: Is this really the way we want to combine `num_produce`?
num_produce = varinfo_left.num_produce[] + varinfo_right.num_produce[]
Member
Remind me the use of num_produce?

Member
Semantically, taking the sum of two num_produce produces no meaning. I think we should throw an error if both num_produce are non-zero.

Member Author
torfjelde Oct 12, 2023
num_produce is used by the particle samplers to indicate how many observations we've seen / what's the current "observation index".

> Semantically, taking the sum of two num_produce produces no meaning. I think we should throw an error if both num_produce are non-zero.

I don't think that's quite right. It makes sense if you go "execute model A and then separately execute model B, and now we want to merge the two"; to ensure that the observation index is correct, we need to add them, no?

Is this a scenario we want to consider? Probably not. But to me it wasn't obvious what else to do.

No matter, I don't think erroring if num_produce is non-zero is the right way to go. We should allow something like

varinfo_with_num_produce = last(evaluate!!(model, varinfo, context))
merge(varinfo_with_num_produce, varinfo)

If there's no meaning in adding them, then I suggest the alternative is to just give precedence to varinfo_right, as is the semantics of merge (if a field/property/key is present in both left and right, then the value in right takes precedence).

EDIT: It's fair to ask if we should also do that for logp. I went with sum because I was imagining scenarios where we want to do something like execute two (sub-)models in parallel and then merge the resulting varinfos. In such a scenario, we want to add the logp fields together. But we could maybe make this a separate function, e.g. merge_with_add_logp or something (preferably named in a better way).

Member Author
So what should we do here? Give precedence to varinfo_right or leave as it is?

return VarInfo(metadata, Ref(lp), Ref(num_produce))
end
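The two policies debated in the review thread above (summing versus letting the right-hand side win) correspond to two standard functions in Julia's `Base`. The sketch below uses plain `Dict`s as a hypothetical stand-in for the scalar fields of a varinfo; the keys `:logp` and `:num_produce` are only labels chosen for the illustration.

```julia
# Hypothetical stand-ins for the scalar fields of two varinfos.
left = Dict(:logp => -1.5, :num_produce => 2)
right = Dict(:logp => -0.5, :num_produce => 3)

# Standard `merge` semantics: for shared keys, the right-most value takes precedence.
right_wins = merge(left, right)

# `mergewith(+, ...)`: shared keys are combined by addition instead.
summed = mergewith(+, left, right)
```

Under the first policy the merged `num_produce` is that of `right`; under the second it is the sum of both, which is what `_merge` as shown in this diff does for both `logp` and `num_produce`.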

@generated function merge_metadata(
metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
) where {names_left,names_right}
names = Expr(:tuple)
vals = Expr(:tuple)
# Loop over `names_left` first because we want to preserve the order of the variables.
for sym in names_left
push!(names.args, QuoteNode(sym))
if sym in names_right
push!(vals.args, :(merge_metadata(metadata_left.$sym, metadata_right.$sym)))
else
push!(vals.args, :(metadata_left.$sym))
end
end
# Loop over remaining variables in `names_right`.
names_right_only = filter(∉(names_left), names_right)
for sym in names_right_only
push!(names.args, QuoteNode(sym))
push!(vals.args, :(metadata_right.$sym))
end

return :(NamedTuple{$names}($vals))
end
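For readers less familiar with `@generated` functions, the name ordering produced by `merge_metadata` on `NamedTuple`s can be mimicked at runtime. This is a minimal sketch, not the PR's implementation: `merge_with_combine` and its `combine` argument are hypothetical, and the values are plain vectors rather than `Metadata`. Names from the left come first, names only in the right are appended, and shared names are passed to `combine`.

```julia
# Runtime analogue of the name ordering; `merge_with_combine` is a hypothetical helper.
function merge_with_combine(combine, left::NamedTuple, right::NamedTuple)
    # Names that only occur on the right, in their original order.
    names_right_only = filter(n -> !(n in keys(left)), keys(right))
    names = (keys(left)..., names_right_only...)
    vals = map(names) do n
        if haskey(left, n) && haskey(right, n)
            combine(left[n], right[n])  # present in both: delegate to `combine`
        elseif haskey(left, n)
            left[n]
        else
            right[n]
        end
    end
    return NamedTuple{names}(vals)
end

left = (s=[1.0], m=[2.0])
right = (m=[20.0], x=[3.0, 4.0])

# Letting the right value win reproduces `Base.merge` on `NamedTuple`s.
merged = merge_with_combine((l, r) -> r, left, right)

# Any other `combine` changes only the entries under shared names.
concatenated = merge_with_combine(vcat, left, right)
```

The generated version in the diff resolves the same branching at compile time from the type parameters, so the result type can be inferred.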

function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
# Extract the varnames.
vns_left = metadata_left.vns
vns_right = metadata_right.vns
vns_both = union(vns_left, vns_right)

# Determine `eltype` of `vals`.
T_left = eltype(metadata_left.vals)
T_right = eltype(metadata_right.vals)
T = promote_type(T_left, T_right)
# TODO: Is this necessary?
if !(T <: Real)
T = Real

end

# Determine `eltype` of `dists`.
D_left = eltype(metadata_left.dists)
D_right = eltype(metadata_right.dists)
D = promote_type(D_left, D_right)
# TODO: Is this necessary?
if !(D <: Distribution)
D = Distribution

end

# Initialize required fields for `metadata`.
vns = VarName[]
idcs = Dict{VarName,Int}()
ranges = Vector{UnitRange{Int}}()
vals = T[]
dists = D[]
gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
orders = Int[]
flags = Dict{String,BitVector}()
# Initialize the `flags`.
for k in union(keys(metadata_left.flags), keys(metadata_right.flags))
flags[k] = BitVector()
end

# Range offset.
offset = 0

for (idx, vn) in enumerate(vns_both)
# `idcs`
idcs[vn] = idx
# `vns`
push!(vns, vn)
if vn in vns_left && vn in vns_right
# `vals`: only valid if they're the same length.
vals_left = getval(metadata_left, vn)
vals_right = getval(metadata_right, vn)
@assert length(vals_left) == length(vals_right)
append!(vals, vals_right)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`: only valid if they're the same.
dists_left = getdist(metadata_left, vn)
dists_right = getdist(metadata_right, vn)
@assert dists_left == dists_right
push!(dists, dists_left)
# `orders`: giving precedence to `metadata_right`
push!(orders, getorder(metadata_right, vn))
# `flags`
for k in keys(flags)
# Using `metadata_right`; should we?
push!(flags[k], is_flagged(metadata_right, vn, k))
end
elseif vn in vns_left
# Just extract the metadata from `metadata_left`.
# `vals`
vals_left = getval(metadata_left, vn)
append!(vals, vals_left)
# `ranges`
r = (offset + 1):(offset + length(vals_left))
push!(ranges, r)
offset = r[end]
# `dists`
dists_left = getdist(metadata_left, vn)
push!(dists, dists_left)
# `orders`
push!(orders, getorder(metadata_left, vn))
# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_left, vn, k))
end
else
# Just extract the metadata from `metadata_right`.
# `vals`
vals_right = getval(metadata_right, vn)
append!(vals, vals_right)

# `ranges`
r = (offset + 1):(offset + length(vals_right))
push!(ranges, r)
offset = r[end]

# `dists`
dists_right = getdist(metadata_right, vn)
push!(dists, dists_right)

# `orders`
push!(orders, getorder(metadata_right, vn))

# `flags`
for k in keys(flags)
push!(flags[k], is_flagged(metadata_right, vn, k))
end

end
end

return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags)
end
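The overall shape of `merge_metadata` for two `Metadata` objects (union of names, right-hand values for shared names, ranges rebuilt with a running offset) can be sketched on a simplified layout. Everything below is an assumption made for illustration: `merge_flat` is a hypothetical function, and each store is just a vector of names with one value vector per name.

```julia
# Hypothetical flat store: a vector of names plus one value vector per name.
function merge_flat(names_left, vals_left, names_right, vals_right)
    # `union` keeps the left names first and appends names only found on the right.
    names = union(names_left, names_right)
    vals = Float64[]
    ranges = UnitRange{Int}[]
    offset = 0
    for n in names
        i_right = findfirst(==(n), names_right)
        # Values from the right take precedence when the name is present on both sides.
        v = if i_right !== nothing
            vals_right[i_right]
        else
            vals_left[findfirst(==(n), names_left)]
        end
        append!(vals, v)
        push!(ranges, (offset + 1):(offset + length(v)))
        offset += length(v)
    end
    return names, vals, ranges
end

merged_names, merged_vals, merged_ranges = merge_flat(
    [:m, :s], [[2.0], [1.0]], [:s, :x], [[10.0], [3.0, 4.0]]
)
```

The left-first ordering from `union` is consistent with the `subset` doctest, where merging a varinfo containing `m` with one containing `s` and `x[2]` lists `m` first.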

const VarView = Union{Int,UnitRange,Vector{Int}}

"""
@@ -1331,6 +1669,15 @@
return vi
end

"""
getorder(vi::VarInfo, vn::VarName)
Member
I like this API -- we can consider deprecating num_produce in favour of getorder in the longer run.

Member Author
I'm confused. Aren't num_produce and order different things? At least they are two different fields in VarInfo. I just added a getorder method because I've been trying to make the interface of VarInfo simpler.

Member
num_produce is used to track the current observation index (for all vns), and its current value is inserted into VarInfo.metadata when a new vn is created.


Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
run before sampling `vn`.
"""
getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn)

getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)]

#######################################
# Rand & replaying method for VarInfo #
#######################################
@@ -1341,7 +1688,10 @@
Check whether `vn` has a true value for `flag` in `vi`.
"""
function is_flagged(vi::VarInfo, vn::VarName, flag::String)
return getmetadata(vi, vn).flags[flag][getidx(vi, vn)]
return is_flagged(getmetadata(vi, vn), vn, flag)
end
function is_flagged(metadata::Metadata, vn::VarName, flag::String)
return metadata.flags[flag][getidx(metadata, vn)]
end

"""