Skip to content

Commit

Permalink
Merge branch 'master' into sunxd/remove_tonamedtuple
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Oct 26, 2023
2 parents cea3ed9 + 2e8adf4 commit 1bdc4ea
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
evaluate!!(
model,
SimpleVarInfo{Float64}(OrderedDict()),
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
SamplingContext(rng, SampleFromPrior(), model.context),
),
)
return values_as(x, T)
Expand Down
57 changes: 29 additions & 28 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
"""
function posterior_mean end

"""
rand_prior_true([rng::AbstractRNG, ]model::DynamicPPL.Model)
Return a `NamedTuple` of realizations from the prior of `model` compatible with `varnames(model)`.
"""
function rand_prior_true(model::DynamicPPL.Model)
return rand_prior_true(Random.default_rng(), model)
end

"""
demo_dynamic_constraint()
Expand Down Expand Up @@ -263,10 +272,8 @@ function logprior_true_with_logabsdet_jacobian(
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
end

function Random.rand(
rng::Random.AbstractRNG,
::Type{NamedTuple},
model::Model{typeof(demo_one_variable_multiple_constraints)},
function rand_prior_true(
rng::Random.AbstractRNG, model::Model{typeof(demo_one_variable_multiple_constraints)}
)
x = Vector{Float64}(undef, 5)
x[1] = rand(rng, Normal())
Expand Down Expand Up @@ -310,9 +317,7 @@ function logprior_true_with_logabsdet_jacobian(model::Model{typeof(demo_lkjchol)
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
end

function Random.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::Model{typeof(demo_lkjchol)}
)
function rand_prior_true(rng::Random.AbstractRNG, model::Model{typeof(demo_lkjchol)})
x = rand(rng, LKJCholesky(model.args.d, 1.0))
return (x=x,)
end
Expand Down Expand Up @@ -724,12 +729,6 @@ const DemoModels = Union{
Model{typeof(demo_assume_matrix_dot_observe_matrix)},
}

# We require demo models to have explict impleentations of `rand` since we want
# these to be considered as ground truth.
function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels)
return error("demo models requires explicit implementation of rand")
end

const UnivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
}
Expand All @@ -743,9 +742,7 @@ function posterior_optima(::UnivariateAssumeDemoModels)
# TODO: Figure out exact for `s`.
return (s=0.907407, m=7 / 6)
end
function Random.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels
)
function rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels)
s = rand(rng, InverseGamma(2, 3))
m = rand(rng, Normal(0, sqrt(s)))

Expand All @@ -766,7 +763,7 @@ const MultivariateAssumeDemoModels = Union{
}
function posterior_mean(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

vals.s[1] = 19 / 8
vals.m[1] = 3 / 4
Expand All @@ -778,7 +775,7 @@ function posterior_mean(model::MultivariateAssumeDemoModels)
end
function likelihood_optima(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# NOTE: These are "as close to zero as we can get".
vals.s[1] = 1e-32
Expand All @@ -791,7 +788,7 @@ function likelihood_optima(model::MultivariateAssumeDemoModels)
end
function posterior_optima(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# TODO: Figure out exact for `s[1]`.
vals.s[1] = 0.890625
Expand All @@ -801,9 +798,7 @@ function posterior_optima(model::MultivariateAssumeDemoModels)

return vals
end
function Random.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels
)
function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels)
# Get template values from `model`.
retval = model(rng)
vals = (s=retval.s, m=retval.m)
Expand All @@ -821,7 +816,7 @@ const MatrixvariateAssumeDemoModels = Union{
}
function posterior_mean(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

vals.s[1, 1] = 19 / 8
vals.m[1] = 3 / 4
Expand All @@ -833,7 +828,7 @@ function posterior_mean(model::MatrixvariateAssumeDemoModels)
end
function likelihood_optima(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# NOTE: These are "as close to zero as we can get".
vals.s[1, 1] = 1e-32
Expand All @@ -846,7 +841,7 @@ function likelihood_optima(model::MatrixvariateAssumeDemoModels)
end
function posterior_optima(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# TODO: Figure out exact for `s[1]`.
vals.s[1, 1] = 0.890625
Expand All @@ -856,9 +851,7 @@ function posterior_optima(model::MatrixvariateAssumeDemoModels)

return vals
end
function Base.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MatrixvariateAssumeDemoModels
)
function rand_prior_true(rng::Random.AbstractRNG, model::MatrixvariateAssumeDemoModels)
# Get template values from `model`.
retval = model(rng)
vals = (s=retval.s, m=retval.m)
Expand Down Expand Up @@ -954,6 +947,14 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end

function rand_prior_true(
rng::Random.AbstractRNG, model::Model{typeof(demo_static_transformation)}
)
s = rand(rng, InverseGamma(2, 3))
m = rand(rng, Normal(0, sqrt(s)))
return (s=s, m=m)
end

"""
marginal_mean_of_samples(chain, varname)
Expand Down
8 changes: 5 additions & 3 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ end
@model demo() = m ~ dist
model = demo()

vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),))
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
@testset "$(short_varinfo_name(vi))" for vi in vis
# Evaluate once to ensure we have `logp` value.
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
Expand Down Expand Up @@ -105,7 +106,7 @@ end
@testset "d=$d" for d in [2, 3, 5]
model = demo_lkj(d)
dist = LKJCholesky(d, 1.0, uplo)
values_original = rand(model)
values_original = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(
model, values_original, (@varname(x),)
)
Expand Down Expand Up @@ -146,7 +147,8 @@ end
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
Expand Down
2 changes: 1 addition & 1 deletion test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Test, DynamicPPL, LogDensityProblems

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = rand(NamedTuple, model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
vns = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)

Expand Down
2 changes: 1 addition & 1 deletion test/loglikelihoods.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testset "loglikelihoods.jl" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
example_values = rand(NamedTuple, m)
example_values = DynamicPPL.TestUtils.rand_prior_true(m)

# Instantiate a `VarInfo` with the example values.
vi = VarInfo(m)
Expand Down
20 changes: 17 additions & 3 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
Random.seed!(1776)
s, m = model()
sample_namedtuple = (; s=s, m=m)
sample_dict = Dict(@varname(s) => s, @varname(m) => m)
sample_dict = OrderedDict(@varname(s) => s, @varname(m) => m)

# With explicit RNG
@test rand(Random.seed!(1776), model) == sample_namedtuple
Expand All @@ -235,7 +235,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
Random.seed!(1776)
@test rand(NamedTuple, model) == sample_namedtuple
Random.seed!(1776)
@test rand(Dict, model) == sample_dict
@test rand(OrderedDict, model) == sample_dict
end

@testset "default arguments" begin
Expand Down Expand Up @@ -263,7 +263,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true

@testset "TestUtils" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
x = rand(model)
x = DynamicPPL.TestUtils.rand_prior_true(model)
# `rand_prior_true` should return a `NamedTuple`.
@test x isa NamedTuple

# `rand` with a `AbstractDict` should have `varnames` as keys.
x_rand_dict = rand(OrderedDict, model)
for vn in DynamicPPL.TestUtils.varnames(model)
@test haskey(x_rand_dict, vn)
end
# `rand` with a `NamedTuple` should have `map(Symbol, varnames)` as keys.
x_rand_nt = rand(NamedTuple, model)
for vn in DynamicPPL.TestUtils.varnames(model)
@test haskey(x_rand_nt, Symbol(vn))
end

# Ensure log-probability computations are implemented.
@test logprior(model, x) DynamicPPL.TestUtils.logprior_true(model, x...)
@test loglikelihood(model, x)
Expand Down
10 changes: 5 additions & 5 deletions test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@

@testset "link!! & invlink!! on $(nameof(model))" for model in
DynamicPPL.TestUtils.DEMO_MODELS
values_constrained = rand(NamedTuple, model)
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
@testset "$(typeof(vi))" for vi in (
SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model)
)
Expand Down Expand Up @@ -112,7 +112,7 @@

# We might need to pre-allocate for the variable `m`, so we need
# to see whether this is the case.
svi_nt = SimpleVarInfo(rand(NamedTuple, model))
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model))
svi_dict = SimpleVarInfo(VarInfo(model), Dict)

@testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in (
Expand All @@ -121,7 +121,7 @@
DynamicPPL.settrans!!(svi_nt, true),
DynamicPPL.settrans!!(svi_dict, true),
)
# Random seed is set in each `@testset`, so we need to sample
# RandOM seed is set in each `@testset`, so we need to sample
# a new realization for `m` here.
retval = model()

Expand All @@ -138,7 +138,7 @@
@test getlogp(svi_new) != 0

### Evaluation ###
values_eval_constrained = rand(NamedTuple, model)
values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
if DynamicPPL.istrans(svi)
_values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian(
model, values_eval_constrained...
Expand Down Expand Up @@ -225,7 +225,7 @@
model = DynamicPPL.TestUtils.demo_static_transformation()

varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, rand(NamedTuple, model), [@varname(s), @varname(m)]
model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)]
)
@testset "$(short_varinfo_name(vi))" for vi in varinfos
# Initialize varinfo and link.
Expand Down
11 changes: 7 additions & 4 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)

@testset "values_as" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = rand(NamedTuple, model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
vns = DynamicPPL.TestUtils.varnames(model)

# Set up the different instances of `AbstractVarInfo` with the desired values.
Expand Down Expand Up @@ -385,7 +385,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
DynamicPPL.TestUtils.demo_lkjchol(),
]
@testset "mutating=$mutating" for mutating in [false, true]
value_true = rand(model)
value_true = DynamicPPL.TestUtils.rand_prior_true(model)
varnames = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, value_true, varnames; include_threadsafe=true
Expand Down Expand Up @@ -541,7 +541,10 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, rand(model), vns; include_threadsafe=true
model,
DynamicPPL.TestUtils.rand_prior_true(model),
vns;
include_threadsafe=true,
)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
@testset "with itself" begin
Expand Down Expand Up @@ -581,7 +584,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
end

@testset "with different value" begin
x = DynamicPPL.TestUtils.rand(model)
x = DynamicPPL.TestUtils.rand_prior_true(model)
varinfo_changed = DynamicPPL.TestUtils.update_values!!(
deepcopy(varinfo), x, vns
)
Expand Down

0 comments on commit 1bdc4ea

Please sign in to comment.