Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

subset and merge for VarInfo (clean version) #544

Merged
merged 28 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
028a81a
added `subset` which can extract a subset of the varinfo
torfjelde Oct 8, 2023
caa6e25
added testing of `subset` for `VarInfo`
torfjelde Oct 8, 2023
cac7fa8
formatting
torfjelde Oct 8, 2023
5e41c4f
added implementation of `merge` for `VarInfo` and tests for it
torfjelde Oct 8, 2023
d5a2631
more tests
torfjelde Oct 8, 2023
0ade696
formatting
torfjelde Oct 8, 2023
db21844
improved merge_metadata for NamedTuple inputs
torfjelde Oct 9, 2023
1dbca4c
added proper handling of the `vals` in `subset`
torfjelde Oct 9, 2023
b67288f
added docs for `subset` and `merge`
torfjelde Oct 9, 2023
e43029e
added `subset` and `merge` to documentation
torfjelde Oct 9, 2023
cd4033d
formatting
torfjelde Oct 9, 2023
8f47dfe
made merge and subset part of the AbstractVarInfo interface
torfjelde Oct 13, 2023
aba9008
added implementations `subset` and `merge` for `SimpleVarInfo`
torfjelde Oct 13, 2023
3b621ae
follow standard merge semantics where the right one takes precedence
torfjelde Oct 13, 2023
2c2c90b
added proper testing of merge and subset for SimpleVarInfo too
torfjelde Oct 13, 2023
5c1ece3
forgotten inclusion in previous commit
torfjelde Oct 13, 2023
cfff96c
Update src/simple_varinfo.jl
torfjelde Oct 13, 2023
ed5d948
remove two-argument impl of merge
torfjelde Oct 13, 2023
00c36cf
formatting
torfjelde Oct 13, 2023
cf02816
forgot to add more formatting
torfjelde Oct 13, 2023
d02cb61
Merge branch 'master' into torfjelde/subset-and-merge
torfjelde Oct 13, 2023
7f01ada
removed 2-arg version of merge for abstract varinfo in favour of 3-ar…
torfjelde Oct 13, 2023
14105e0
allow inclusion of threadsafe varinfo in setup_varinfos
torfjelde Oct 13, 2023
c164d32
more tests for thread safe varinfo
torfjelde Oct 13, 2023
743162a
bugfixes for link and invlink methods when using thread safe varinfo
torfjelde Oct 13, 2023
dc9ad94
attempt at fixing docs
torfjelde Oct 13, 2023
2f320e6
fixed missing test coverage
torfjelde Oct 14, 2023
d3a9b56
formatting
torfjelde Oct 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ DynamicPPL.reconstruct
#### Utils

```@docs
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export AbstractVarInfo,
SimpleVarInfo,
push!!,
empty!!,
subset,
getlogp,
setlogp!!,
acclogp!!,
Expand Down
161 changes: 161 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,27 @@
bijector::F
end

"""
merge_transformations(transformation_left, transformation_right)

Merge two transformations.

The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref).
"""
function merge_transformations(::NoTransformation, ::NoTransformation)
return NoTransformation()
end
function merge_transformations(::DynamicTransformation, ::DynamicTransformation)
return DynamicTransformation()

Check warning on line 67 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L66-L67

Added lines #L66 - L67 were not covered by tests
end
function merge_transformations(left::StaticTransformation, right::StaticTransformation)
return StaticTransformation(merge_bijectors(left.bijector, right.bijector))

Check warning on line 70 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
end

function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform)
return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs))

Check warning on line 74 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L73-L74

Added lines #L73 - L74 were not covered by tests
end

"""
default_transformation(model::Model[, vi::AbstractVarInfo])

Expand Down Expand Up @@ -337,6 +358,146 @@
return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
end

# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
# might result in a `Vector{Any}`.
"""
subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we would not already have so many getindex methods, I would have thought that getindex would be a natural name for this function. But maybe it's still an option?

Then we could have getindex(::AbstractVarInfo, ::AbstractVector{<:VarName}) -> AbstractVarInfo and getindex(::T, ::VarName) -> typeof_varname_variate, similar to [1,2,3][[1,3]] = [1, 3] and [1,2,3][2] = 2.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd really like this yes, but I also really don't want to touch getindex in this codebase 😅

Happy to make this a long-term goal or something though!


Subset a `varinfo` to only contain the variables `vns`.

!!! warning
The ordering of the variables in the resulting `varinfo` is _not_
guaranteed to follow the ordering of the variables in `varinfo`.
Hence care must be taken, in particular when used in conjunction with
other methods which uses the vector-representation of the `varinfo`,
e.g. `getindex(varinfo, sampler)`.

# Examples
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
julia> @model function demo()
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
x = Vector{Float64}(undef, 2)
x[1] ~ Normal(m, sqrt(s))
x[2] ~ Normal(m, sqrt(s))
end
demo (generic function with 2 methods)

julia> model = demo();

julia> varinfo = VarInfo(model);

julia> keys(varinfo)
4-element Vector{VarName}:
s
m
x[1]
x[2]

julia> for (i, vn) in enumerate(keys(varinfo))
varinfo[vn] = i
end

julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0

julia> # Extract one with only `m`.
varinfo_subset1 = subset(varinfo, [@varname(m),]);


julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
m

julia> varinfo_subset1[@varname(m)]
2.0

julia> # Extract one with both `s` and `x[2]`.
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);

julia> keys(varinfo_subset2)
2-element Vector{VarName}:
s
x[2]

julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
1.0
4.0
```

`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref)

```jldoctest varinfo-subset
julia> # Merge the two.
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);

julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
m
s
x[2]

julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
1.0
2.0
4.0

julia> # Merge the two with the original.
varinfo_merged = merge(varinfo, varinfo_subset_merged);

julia> keys(varinfo_merged)
4-element Vector{VarName}:
s
m
x[1]
x[2]

julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0
```

# Notes

## Type-stability

!!! warning
This function is only type-stable when `vns` contains only varnames
with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will
be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
"""
function subset end

"""
merge(varinfo, other_varinfos...)

Merge varinfos into one, giving precedence to the right-most varinfo when sensible.

This is particularly useful when combined with [`subset(varinfo, vns)`](@ref).

See docstring of [`subset(varinfo, vns)`](@ref) for examples.
"""
Base.merge(varinfo::AbstractVarInfo) = varinfo

Check warning on line 490 in src/abstract_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/abstract_varinfo.jl#L490

Added line #L490 was not covered by tests
# Define 3-argument version so 2-argument version will error if not implemented.
function Base.merge(
varinfo1::AbstractVarInfo,
varinfo2::AbstractVarInfo,
varinfo3::AbstractVarInfo,
varinfo_others::AbstractVarInfo...,
)
return merge(Base.merge(varinfo1, varinfo2), varinfo3, varinfo_others...)
end

# Transformations
"""
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
Expand Down
45 changes: 45 additions & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,51 @@
return V
end

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
return Setfield.@set varinfo.values = _subset(varinfo.values, vns)
end

function _subset(x::AbstractDict, vns)
# NOTE: This requires `vns` to be explicitly present in `x`.
if any(!Base.Fix1(haskey, x), vns)
throw(
ArgumentError(
"Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
"For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
"`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.",
),
)
end
C = ConstructionBase.constructorof(typeof(x))
return C(vn => x[vn] for vn in vns)

Check warning on line 439 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L438-L439

Added lines #L438 - L439 were not covered by tests
end

function _subset(x::NamedTuple, vns)
# NOTE: Here we can only handle `vns` that contain the `IdentityLens`.
if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns)
throw(
ArgumentError(
"Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " *
"For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
),
)
end

syms = map(getsym, vns)
return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms)))

Check warning on line 454 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L453-L454

Added lines #L453 - L454 were not covered by tests
end

# `merge`
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
values = merge(varinfo_left.values, varinfo_right.values)
logp = getlogp(varinfo_right)
transformation = merge_transformations(
varinfo_left.transformation, varinfo_right.transformation
)
return SimpleVarInfo(values, logp, transformation)
end

# Context implementations
# NOTE: Evaluations, i.e. those without `rng` are shared with other
# implementations of `AbstractVarInfo`.
Expand Down
17 changes: 14 additions & 3 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,17 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal
end

"""
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)

Return a tuple of instances for different implementations of `AbstractVarInfo` with
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.

If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
of the varinfo instances.
"""
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
function setup_varinfos(
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
)
# VarInfo
vi_untyped = VarInfo()
model(vi_untyped)
Expand All @@ -56,12 +61,18 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))

lp = getlogp(vi_typed)
return map((
varinfos = map((
vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref
)) do vi
# Set them all to the same values.
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end

if include_threadsafe
varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...)
end

return varinfos
end

"""
Expand Down
56 changes: 52 additions & 4 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,56 @@
function link!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return link!!(t, vi.varinfo, spl, model)
return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)

Check warning on line 87 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L87

Added line #L87 was not covered by tests
end

function invlink!!(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return invlink!!(t, vi.varinfo, spl, model)
return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)

Check warning on line 93 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L93

Added line #L93 was not covered by tests
end

function link(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return link(t, vi.varinfo, spl, model)
return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model)

Check warning on line 99 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L99

Added line #L99 was not covered by tests
end

function invlink(
t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return invlink(t, vi.varinfo, spl, model)
return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)

Check warning on line 105 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L105

Added line #L105 was not covered by tests
end

# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
# to define `getlogp(vi)`.
function link!!(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
end

function invlink!!(
::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return settrans!!(
last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
NoTransformation(),
)
end

function link(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return link!!(t, deepcopy(vi), spl, model)
end

function invlink(
t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
return invlink!!(t, deepcopy(vi), spl, model)
end

function maybe_invlink_before_eval!!(
Expand Down Expand Up @@ -192,3 +223,20 @@
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)

getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)

function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x)
end
function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector)
return Setfield.@set vi.varinfo = unflatten(vi.varinfo, spl, x)

Check warning on line 231 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L230-L231

Added lines #L230 - L231 were not covered by tests
end

function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
return Setfield.@set varinfo.varinfo = subset(varinfo.varinfo, vns)

Check warning on line 235 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L234-L235

Added lines #L234 - L235 were not covered by tests
end

function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo)
return Setfield.@set varinfo_left.varinfo = merge(
varinfo_left.varinfo, varinfo_right.varinfo
)
end
Loading
Loading