CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat) (#553)

* Fix for `rand` + replace overloads of `rand` with `rand_prior_true` for testing models (#541)

* preserve context from model in `rand`

* replace rand overloads in TestUtils with definitions of rand_prior_true so we can properly test rand

* removed NamedTuple from signature of TestUtils.rand_prior_true

* updated references to previous overloads of rand to now use rand_prior_true

* test rand for DEMO_MODELS

* formatting

* fixed tests for rand for DEMO_MODELS

* fixed linking tests

* added missing impl of rand_prior_true for demo_static_transformation

* formatting

* fixed rand_prior_true for demo_static_transformation

* bump minor version as this will be breaking

* bump patch version

* fixed old usage of rand

* Update test/varinfo.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed another usage of rand

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Remove `tonamedtuple` (#547)

* Remove dependencies to `tonamedtuple`

* Remove `tonamedtuple`s

* Minor version bump

---------

Co-authored-by: Hong Ge <[email protected]>

* CompatHelper: bump compat for AbstractMCMC to 5 for package test, (keep existing compat)

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: CompatHelper Julia <[email protected]>
5 people authored Nov 1, 2023
1 parent 9a2d2e5 commit a5bbe61
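For reference, a minimal sketch of the testing workflow these changes settle on — a hedged example based on the diff below, not code from the commit itself:

```julia
using DynamicPPL, Random

# `rand_prior_true` replaces the old `Random.rand` overloads for the demo
# models and returns a `NamedTuple` of draws from the prior.
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
x = DynamicPPL.TestUtils.rand_prior_true(Random.default_rng(), model)

# Plain `rand` still works on a model, and after this change it preserves
# `model.context` instead of substituting a hard-coded `DefaultContext()`.
nt = rand(NamedTuple, model)
```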
Showing 15 changed files with 81 additions and 134 deletions.
1 change: 0 additions & 1 deletion docs/src/api.md
@@ -258,7 +258,6 @@ DynamicPPL.reconstruct
 Base.merge(::AbstractVarInfo)
 DynamicPPL.subset
 DynamicPPL.unflatten
-DynamicPPL.tonamedtuple
 DynamicPPL.varname_leaves
 DynamicPPL.varname_and_value_leaves
 ```
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
@@ -71,7 +71,6 @@ export AbstractVarInfo,
     invlink,
     invlink!,
     invlink!!,
-    tonamedtuple,
     values_as,
     # VarName (reexport from AbstractPPL)
     VarName,
15 changes: 0 additions & 15 deletions src/abstract_varinfo.jl
@@ -738,21 +738,6 @@ function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::Abstrac
     return unflatten(varinfo, sampler, θ)
 end

-"""
-    tonamedtuple(vi::AbstractVarInfo)
-
-Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and
-indexing string of the variable.
-
-For example, a model that had a vector of vector-valued
-variables `x` would return
-
-```julia
-(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), )
-```
-"""
-function tonamedtuple end
-
 # TODO: Clean up all this linking stuff once and for all!
 """
     with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
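With `tonamedtuple` removed, a `NamedTuple` view of a varinfo is still available through `values_as` (which stays exported above). A minimal sketch, assuming one of the demo models; note that `values_as` maps symbols directly to values rather than to the `(values, strings)` pairs `tonamedtuple` produced:

```julia
using DynamicPPL

model = DynamicPPL.TestUtils.DEMO_MODELS[1]
varinfo = VarInfo(model)

# Convert the stored values into the requested container type.
nt = values_as(varinfo, NamedTuple)
```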
2 changes: 1 addition & 1 deletion src/model.jl
@@ -1037,7 +1037,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
         evaluate!!(
             model,
             SimpleVarInfo{Float64}(OrderedDict()),
-            SamplingContext(rng, SampleFromPrior(), DefaultContext()),
+            SamplingContext(rng, SampleFromPrior(), model.context),
         ),
     )
     return values_as(x, T)
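The one-line change above makes `rand` thread the model's own context through evaluation instead of hard-coding `DefaultContext()`. A hedged sketch of the behavior this enables (the model definition is illustrative, not from the commit):

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# A context attached to the model, e.g. by `condition`, is now respected:
conditioned = condition(demo(), s=1.0)
# `s` is treated as observed, so the draw should contain only `m`, sampled
# with `s` fixed at 1.0 (assumed behavior).
draw = rand(Random.default_rng(), NamedTuple, conditioned)
```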
38 changes: 0 additions & 38 deletions src/simple_varinfo.jl
@@ -532,44 +532,6 @@ function dot_assume(
     return value, lp, vi
 end

-# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl.
-# TODO: Move away from using these `tonamedtuple` methods.
-function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names}
-    nt_vals = map(keys(vi)) do vn
-        val = vi[vn]
-        vns = collect(TestUtils.varname_leaves(vn, val))
-        vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
-        (vals, map(string, vns))
-    end
-
-    return NamedTuple{names}(nt_vals)
-end
-
-function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
-    syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}()
-    for vn in keys(vi)
-        # Extract the leaf varnames and values.
-        val = vi[vn]
-        vns = collect(TestUtils.varname_leaves(vn, val))
-        vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
-
-        # Determine the corresponding symbol.
-        sym = only(unique(map(getsym, vns)))
-
-        # Initialize entry if not yet initialized.
-        if !haskey(syms_to_result, sym)
-            syms_to_result[sym] = (Real[], String[])
-        end
-
-        # Combine with old result.
-        old_vals, old_string_vns = syms_to_result[sym]
-        syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns)))
-    end
-
-    # Construct `NamedTuple`.
-    return NamedTuple(pairs(syms_to_result))
-end
-
 # NOTE: We don't implement `settrans!!(vi, trans, vn)`.
 function settrans!!(vi::SimpleVarInfo, trans)
     return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
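A side note on the `copy ∘ Base.Fix1(getindex, vi)` idiom in the methods removed above, sketched in isolation with illustrative data:

```julia
# A plain `Dict` stands in for the varinfo here.
vi = Dict(:x => [1.0, 2.0], :y => [3.0])

# `Base.Fix1(getindex, vi)` is the partial application `k -> vi[k]`;
# composing it with `copy` ensures the returned values don't alias the
# arrays stored in `vi`, so later mutation of `vi` can't affect them.
getter = copy ∘ Base.Fix1(getindex, vi)
vals = map(getter, [:x, :y])  # [[1.0, 2.0], [3.0]]
```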
57 changes: 29 additions & 28 deletions src/test_utils.jl
@@ -179,6 +179,15 @@ corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
 """
 function posterior_mean end

+"""
+    rand_prior_true([rng::AbstractRNG, ]model::DynamicPPL.Model)
+
+Return a `NamedTuple` of realizations from the prior of `model` compatible with `varnames(model)`.
+"""
+function rand_prior_true(model::DynamicPPL.Model)
+    return rand_prior_true(Random.default_rng(), model)
+end
+
 """
     demo_dynamic_constraint()
@@ -263,10 +272,8 @@ function logprior_true_with_logabsdet_jacobian(
     return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
 end

-function Random.rand(
-    rng::Random.AbstractRNG,
-    ::Type{NamedTuple},
-    model::Model{typeof(demo_one_variable_multiple_constraints)},
+function rand_prior_true(
+    rng::Random.AbstractRNG, model::Model{typeof(demo_one_variable_multiple_constraints)}
 )
     x = Vector{Float64}(undef, 5)
     x[1] = rand(rng, Normal())
@@ -310,9 +317,7 @@ function logprior_true_with_logabsdet_jacobian(model::Model{typeof(demo_lkjchol)
     return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
 end

-function Random.rand(
-    rng::Random.AbstractRNG, ::Type{NamedTuple}, model::Model{typeof(demo_lkjchol)}
-)
+function rand_prior_true(rng::Random.AbstractRNG, model::Model{typeof(demo_lkjchol)})
     x = rand(rng, LKJCholesky(model.args.d, 1.0))
     return (x=x,)
 end
@@ -724,12 +729,6 @@ const DemoModels = Union{
     Model{typeof(demo_assume_matrix_dot_observe_matrix)},
 }

-# We require demo models to have explicit implementations of `rand` since we want
-# these to be considered as ground truth.
-function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels)
-    return error("demo models requires explicit implementation of rand")
-end
-
 const UnivariateAssumeDemoModels = Union{
     Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
 }
@@ -743,9 +742,7 @@ function posterior_optima(::UnivariateAssumeDemoModels)
     # TODO: Figure out exact for `s`.
     return (s=0.907407, m=7 / 6)
 end
-function Random.rand(
-    rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels
-)
+function rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels)
     s = rand(rng, InverseGamma(2, 3))
     m = rand(rng, Normal(0, sqrt(s)))

@@ -766,7 +763,7 @@ const MultivariateAssumeDemoModels = Union{
 }
 function posterior_mean(model::MultivariateAssumeDemoModels)
     # Get some containers to fill.
-    vals = Random.rand(model)
+    vals = rand_prior_true(model)

     vals.s[1] = 19 / 8
     vals.m[1] = 3 / 4
@@ -778,7 +775,7 @@ function posterior_mean(model::MultivariateAssumeDemoModels)
 end
 function likelihood_optima(model::MultivariateAssumeDemoModels)
     # Get some containers to fill.
-    vals = Random.rand(model)
+    vals = rand_prior_true(model)

     # NOTE: These are "as close to zero as we can get".
     vals.s[1] = 1e-32
@@ -791,7 +788,7 @@ function likelihood_optima(model::MultivariateAssumeDemoModels)
 end
 function posterior_optima(model::MultivariateAssumeDemoModels)
     # Get some containers to fill.
-    vals = Random.rand(model)
+    vals = rand_prior_true(model)

     # TODO: Figure out exact for `s[1]`.
     vals.s[1] = 0.890625
@@ -801,9 +798,7 @@ function posterior_optima(model::MultivariateAssumeDemoModels)

     return vals
 end
-function Random.rand(
-    rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels
-)
+function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels)
     # Get template values from `model`.
     retval = model(rng)
     vals = (s=retval.s, m=retval.m)
@@ -821,7 +816,7 @@ const MatrixvariateAssumeDemoModels = Union{
 }
 function posterior_mean(model::MatrixvariateAssumeDemoModels)
     # Get some containers to fill.
-    vals = Random.rand(model)
+    vals = rand_prior_true(model)

     vals.s[1, 1] = 19 / 8
     vals.m[1] = 3 / 4
@@ -833,7 +828,7 @@ function posterior_mean(model::MatrixvariateAssumeDemoModels)
 end
 function likelihood_optima(model::MatrixvariateAssumeDemoModels)
     # Get some containers to fill.
-    vals = Random.rand(model)
+    vals = rand_prior_true(model)

     # NOTE: These are "as close to zero as we can get".
     vals.s[1, 1] = 1e-32
@@ -846,7 +841,7 @@ function likelihood_optima(model::MatrixvariateAssumeDemoModels)
 end
 function posterior_optima(model::MatrixvariateAssumeDemoModels)
     # Get some containers to fill.
-    vals = Random.rand(model)
+    vals = rand_prior_true(model)

     # TODO: Figure out exact for `s[1]`.
     vals.s[1, 1] = 0.890625
@@ -856,9 +851,7 @@ function posterior_optima(model::MatrixvariateAssumeDemoModels)

     return vals
 end
-function Base.rand(
-    rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MatrixvariateAssumeDemoModels
-)
+function rand_prior_true(rng::Random.AbstractRNG, model::MatrixvariateAssumeDemoModels)
     # Get template values from `model`.
     retval = model(rng)
     vals = (s=retval.s, m=retval.m)
@@ -954,6 +947,14 @@ function logprior_true_with_logabsdet_jacobian(
     return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
 end

+function rand_prior_true(
+    rng::Random.AbstractRNG, model::Model{typeof(demo_static_transformation)}
+)
+    s = rand(rng, InverseGamma(2, 3))
+    m = rand(rng, Normal(0, sqrt(s)))
+    return (s=s, m=m)
+end
+
 """
     marginal_mean_of_samples(chain, varname)
2 changes: 0 additions & 2 deletions src/threadsafe.jl
@@ -209,8 +209,6 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
     return is_flagged(vi.varinfo, vn, flag)
 end

-tonamedtuple(vi::ThreadSafeVarInfo) = tonamedtuple(vi.varinfo)
-
 # Transformations.
 function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
     return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
16 changes: 0 additions & 16 deletions src/varinfo.jl
@@ -1506,22 +1506,6 @@ end
     return expr
 end

-# TODO: Remove this completely.
-tonamedtuple(varinfo::VarInfo) = tonamedtuple(varinfo.metadata, varinfo)
-function tonamedtuple(metadata::NamedTuple{names}, varinfo::VarInfo) where {names}
-    length(names) === 0 && return NamedTuple()
-
-    vals_tuple = map(values(metadata)) do x
-        # NOTE: `tonamedtuple` is really only used in Turing.jl to convert to
-        # a "transition". This means that we really don't want mutations of the
-        # values in `varinfo` to propagate to the previous samples. Hence we `copy`.
-        vals = map(copy ∘ Base.Fix1(getindex, varinfo), x.vns)
-        return vals, map(string, x.vns)
-    end
-
-    return NamedTuple{names}(vals_tuple)
-end
-
 @inline function findvns(vi, f_vns)
     if length(f_vns) == 0
         throw("Unidentified error, please report this error in an issue.")
8 changes: 5 additions & 3 deletions test/linking.jl
@@ -72,7 +72,8 @@ end
         @model demo() = m ~ dist
         model = demo()

-        vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),))
+        example_values = rand(NamedTuple, model)
+        vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
         @testset "$(short_varinfo_name(vi))" for vi in vis
             # Evaluate once to ensure we have `logp` value.
             vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
@@ -105,7 +106,7 @@ end
         @testset "d=$d" for d in [2, 3, 5]
             model = demo_lkj(d)
             dist = LKJCholesky(d, 1.0, uplo)
-            values_original = rand(model)
+            values_original = rand(NamedTuple, model)
             vis = DynamicPPL.TestUtils.setup_varinfos(
                 model, values_original, (@varname(x),)
             )
@@ -146,7 +147,8 @@ end
     @model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
     @testset "d=$d" for d in [2, 3, 5]
         model = demo_dirichlet(d)
-        vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
+        example_values = rand(NamedTuple, model)
+        vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
         @testset "$(short_varinfo_name(vi))" for vi in vis
             lp = logpdf(Dirichlet(d, 1.0), vi[:])
             @test length(vi[:]) == d
2 changes: 1 addition & 1 deletion test/logdensityfunction.jl
@@ -2,7 +2,7 @@ using Test, DynamicPPL, LogDensityProblems

 @testset "LogDensityFunction" begin
     @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
-        example_values = rand(NamedTuple, model)
+        example_values = DynamicPPL.TestUtils.rand_prior_true(model)
         vns = DynamicPPL.TestUtils.varnames(model)
         varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)

2 changes: 1 addition & 1 deletion test/loglikelihoods.jl
@@ -1,6 +1,6 @@
 @testset "loglikelihoods.jl" begin
     @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
-        example_values = rand(NamedTuple, m)
+        example_values = DynamicPPL.TestUtils.rand_prior_true(m)

         # Instantiate a `VarInfo` with the example values.
         vi = VarInfo(m)
20 changes: 17 additions & 3 deletions test/model.jl
@@ -222,7 +222,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
         Random.seed!(1776)
         s, m = model()
         sample_namedtuple = (; s=s, m=m)
-        sample_dict = Dict(@varname(s) => s, @varname(m) => m)
+        sample_dict = OrderedDict(@varname(s) => s, @varname(m) => m)

         # With explicit RNG
         @test rand(Random.seed!(1776), model) == sample_namedtuple
@@ -235,7 +235,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
         Random.seed!(1776)
         @test rand(NamedTuple, model) == sample_namedtuple
         Random.seed!(1776)
-        @test rand(Dict, model) == sample_dict
+        @test rand(OrderedDict, model) == sample_dict
     end

     @testset "default arguments" begin
@@ -263,7 +263,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true

     @testset "TestUtils" begin
         @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
-            x = rand(model)
+            x = DynamicPPL.TestUtils.rand_prior_true(model)
+            # `rand_prior_true` should return a `NamedTuple`.
+            @test x isa NamedTuple
+
+            # `rand` with an `AbstractDict` should have `varnames` as keys.
+            x_rand_dict = rand(OrderedDict, model)
+            for vn in DynamicPPL.TestUtils.varnames(model)
+                @test haskey(x_rand_dict, vn)
+            end
+            # `rand` with a `NamedTuple` should have `map(Symbol, varnames)` as keys.
+            x_rand_nt = rand(NamedTuple, model)
+            for vn in DynamicPPL.TestUtils.varnames(model)
+                @test haskey(x_rand_nt, Symbol(vn))
+            end
+
             # Ensure log-probability computations are implemented.
             @test logprior(model, x) ≈ DynamicPPL.TestUtils.logprior_true(model, x...)
             @test loglikelihood(model, x)