Skip to content

Commit

Permalink
Fix for rand + replace overloads of rand with rand_prior_true f…
Browse files Browse the repository at this point in the history
…or testing models (#541)

* preserve context from model in `rand`

* replace rand overloads in TestUtils with definitions of
rand_prior_true so we can properly test rand

* removed NamedTuple from signature of TestUtils.rand_prior_true

* updated references to previous overloads of rand to now use rand_prior_true

* test rand for DEMO_MODELS

* formatting

* fixed tests for rand for DEMO_MODELS

* fixed linkning tests

* added missing impl of rand_prior_true for demo_static_transformation

* formatting

* fixed rand_prior_true for demo_static_transformation

* bump minor version as this will be breaking

* bump patch version

* fixed old usage of rand

* Update test/varinfo.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed another usage of rand

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
torfjelde and github-actions[bot] authored Oct 25, 2023
1 parent 12e7c27 commit 2e8adf4
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.20"
version = "0.23.21"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
evaluate!!(
model,
SimpleVarInfo{Float64}(OrderedDict()),
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
SamplingContext(rng, SampleFromPrior(), model.context),
),
)
return values_as(x, T)
Expand Down
57 changes: 29 additions & 28 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`.
"""
function posterior_mean end

"""
rand_prior_true([rng::AbstractRNG, ]model::DynamicPPL.Model)
Return a `NamedTuple` of realizations from the prior of `model` compatible with `varnames(model)`.
"""
function rand_prior_true(model::DynamicPPL.Model)
return rand_prior_true(Random.default_rng(), model)
end

"""
demo_dynamic_constraint()
Expand Down Expand Up @@ -263,10 +272,8 @@ function logprior_true_with_logabsdet_jacobian(
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
end

function Random.rand(
rng::Random.AbstractRNG,
::Type{NamedTuple},
model::Model{typeof(demo_one_variable_multiple_constraints)},
function rand_prior_true(
rng::Random.AbstractRNG, model::Model{typeof(demo_one_variable_multiple_constraints)}
)
x = Vector{Float64}(undef, 5)
x[1] = rand(rng, Normal())
Expand Down Expand Up @@ -310,9 +317,7 @@ function logprior_true_with_logabsdet_jacobian(model::Model{typeof(demo_lkjchol)
return (x=x_unconstrained,), logprior_true(model, x) - Δlogp
end

function Random.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::Model{typeof(demo_lkjchol)}
)
function rand_prior_true(rng::Random.AbstractRNG, model::Model{typeof(demo_lkjchol)})
x = rand(rng, LKJCholesky(model.args.d, 1.0))
return (x=x,)
end
Expand Down Expand Up @@ -724,12 +729,6 @@ const DemoModels = Union{
Model{typeof(demo_assume_matrix_dot_observe_matrix)},
}

# We require demo models to have explict impleentations of `rand` since we want
# these to be considered as ground truth.
function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels)
return error("demo models requires explicit implementation of rand")
end

const UnivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
}
Expand All @@ -743,9 +742,7 @@ function posterior_optima(::UnivariateAssumeDemoModels)
# TODO: Figure out exact for `s`.
return (s=0.907407, m=7 / 6)
end
function Random.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels
)
function rand_prior_true(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels)
s = rand(rng, InverseGamma(2, 3))
m = rand(rng, Normal(0, sqrt(s)))

Expand All @@ -766,7 +763,7 @@ const MultivariateAssumeDemoModels = Union{
}
function posterior_mean(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

vals.s[1] = 19 / 8
vals.m[1] = 3 / 4
Expand All @@ -778,7 +775,7 @@ function posterior_mean(model::MultivariateAssumeDemoModels)
end
function likelihood_optima(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# NOTE: These are "as close to zero as we can get".
vals.s[1] = 1e-32
Expand All @@ -791,7 +788,7 @@ function likelihood_optima(model::MultivariateAssumeDemoModels)
end
function posterior_optima(model::MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# TODO: Figure out exact for `s[1]`.
vals.s[1] = 0.890625
Expand All @@ -801,9 +798,7 @@ function posterior_optima(model::MultivariateAssumeDemoModels)

return vals
end
function Random.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels
)
function rand_prior_true(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels)
# Get template values from `model`.
retval = model(rng)
vals = (s=retval.s, m=retval.m)
Expand All @@ -821,7 +816,7 @@ const MatrixvariateAssumeDemoModels = Union{
}
function posterior_mean(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

vals.s[1, 1] = 19 / 8
vals.m[1] = 3 / 4
Expand All @@ -833,7 +828,7 @@ function posterior_mean(model::MatrixvariateAssumeDemoModels)
end
function likelihood_optima(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# NOTE: These are "as close to zero as we can get".
vals.s[1, 1] = 1e-32
Expand All @@ -846,7 +841,7 @@ function likelihood_optima(model::MatrixvariateAssumeDemoModels)
end
function posterior_optima(model::MatrixvariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)
vals = rand_prior_true(model)

# TODO: Figure out exact for `s[1]`.
vals.s[1, 1] = 0.890625
Expand All @@ -856,9 +851,7 @@ function posterior_optima(model::MatrixvariateAssumeDemoModels)

return vals
end
function Base.rand(
rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MatrixvariateAssumeDemoModels
)
function rand_prior_true(rng::Random.AbstractRNG, model::MatrixvariateAssumeDemoModels)
# Get template values from `model`.
retval = model(rng)
vals = (s=retval.s, m=retval.m)
Expand Down Expand Up @@ -954,6 +947,14 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end

function rand_prior_true(
rng::Random.AbstractRNG, model::Model{typeof(demo_static_transformation)}
)
s = rand(rng, InverseGamma(2, 3))
m = rand(rng, Normal(0, sqrt(s)))
return (s=s, m=m)
end

"""
marginal_mean_of_samples(chain, varname)
Expand Down
8 changes: 5 additions & 3 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ end
@model demo() = m ~ dist
model = demo()

vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),))
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),))
@testset "$(short_varinfo_name(vi))" for vi in vis
# Evaluate once to ensure we have `logp` value.
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
Expand Down Expand Up @@ -105,7 +106,7 @@ end
@testset "d=$d" for d in [2, 3, 5]
model = demo_lkj(d)
dist = LKJCholesky(d, 1.0, uplo)
values_original = rand(model)
values_original = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(
model, values_original, (@varname(x),)
)
Expand Down Expand Up @@ -146,7 +147,8 @@ end
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
example_values = rand(NamedTuple, model)
vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
Expand Down
2 changes: 1 addition & 1 deletion test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Test, DynamicPPL, LogDensityProblems

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = rand(NamedTuple, model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
vns = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)

Expand Down
2 changes: 1 addition & 1 deletion test/loglikelihoods.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testset "loglikelihoods.jl" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
example_values = rand(NamedTuple, m)
example_values = DynamicPPL.TestUtils.rand_prior_true(m)

# Instantiate a `VarInfo` with the example values.
vi = VarInfo(m)
Expand Down
20 changes: 17 additions & 3 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
Random.seed!(1776)
s, m = model()
sample_namedtuple = (; s=s, m=m)
sample_dict = Dict(@varname(s) => s, @varname(m) => m)
sample_dict = OrderedDict(@varname(s) => s, @varname(m) => m)

# With explicit RNG
@test rand(Random.seed!(1776), model) == sample_namedtuple
Expand All @@ -235,7 +235,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
Random.seed!(1776)
@test rand(NamedTuple, model) == sample_namedtuple
Random.seed!(1776)
@test rand(Dict, model) == sample_dict
@test rand(OrderedDict, model) == sample_dict
end

@testset "default arguments" begin
Expand Down Expand Up @@ -263,7 +263,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true

@testset "TestUtils" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
x = rand(model)
x = DynamicPPL.TestUtils.rand_prior_true(model)
# `rand_prior_true` should return a `NamedTuple`.
@test x isa NamedTuple

# `rand` with a `AbstractDict` should have `varnames` as keys.
x_rand_dict = rand(OrderedDict, model)
for vn in DynamicPPL.TestUtils.varnames(model)
@test haskey(x_rand_dict, vn)
end
# `rand` with a `NamedTuple` should have `map(Symbol, varnames)` as keys.
x_rand_nt = rand(NamedTuple, model)
for vn in DynamicPPL.TestUtils.varnames(model)
@test haskey(x_rand_nt, Symbol(vn))
end

# Ensure log-probability computations are implemented.
@test logprior(model, x) DynamicPPL.TestUtils.logprior_true(model, x...)
@test loglikelihood(model, x)
Expand Down
10 changes: 5 additions & 5 deletions test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@

@testset "link!! & invlink!! on $(nameof(model))" for model in
DynamicPPL.TestUtils.DEMO_MODELS
values_constrained = rand(NamedTuple, model)
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
@testset "$(typeof(vi))" for vi in (
SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model)
)
Expand Down Expand Up @@ -112,7 +112,7 @@

# We might need to pre-allocate for the variable `m`, so we need
# to see whether this is the case.
svi_nt = SimpleVarInfo(rand(NamedTuple, model))
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model))
svi_dict = SimpleVarInfo(VarInfo(model), Dict)

@testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in (
Expand All @@ -121,7 +121,7 @@
DynamicPPL.settrans!!(svi_nt, true),
DynamicPPL.settrans!!(svi_dict, true),
)
# Random seed is set in each `@testset`, so we need to sample
# RandOM seed is set in each `@testset`, so we need to sample
# a new realization for `m` here.
retval = model()

Expand All @@ -138,7 +138,7 @@
@test getlogp(svi_new) != 0

### Evaluation ###
values_eval_constrained = rand(NamedTuple, model)
values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
if DynamicPPL.istrans(svi)
_values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian(
model, values_eval_constrained...
Expand Down Expand Up @@ -225,7 +225,7 @@
model = DynamicPPL.TestUtils.demo_static_transformation()

varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, rand(NamedTuple, model), [@varname(s), @varname(m)]
model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)]
)
@testset "$(short_varinfo_name(vi))" for vi in varinfos
# Initialize varinfo and link.
Expand Down
11 changes: 7 additions & 4 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)

@testset "values_as" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = rand(NamedTuple, model)
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
vns = DynamicPPL.TestUtils.varnames(model)

# Set up the different instances of `AbstractVarInfo` with the desired values.
Expand Down Expand Up @@ -385,7 +385,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
DynamicPPL.TestUtils.demo_lkjchol(),
]
@testset "mutating=$mutating" for mutating in [false, true]
value_true = rand(model)
value_true = DynamicPPL.TestUtils.rand_prior_true(model)
varnames = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, value_true, varnames; include_threadsafe=true
Expand Down Expand Up @@ -541,7 +541,10 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, rand(model), vns; include_threadsafe=true
model,
DynamicPPL.TestUtils.rand_prior_true(model),
vns;
include_threadsafe=true,
)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
@testset "with itself" begin
Expand Down Expand Up @@ -581,7 +584,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
end

@testset "with different value" begin
x = DynamicPPL.TestUtils.rand(model)
x = DynamicPPL.TestUtils.rand_prior_true(model)
varinfo_changed = DynamicPPL.TestUtils.update_values!!(
deepcopy(varinfo), x, vns
)
Expand Down

2 comments on commit 2e8adf4

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/94137

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.23.21 -m "<description of version>" 2e8adf4069c3b34efa8e3ffa12f5737df4e5d40f
git push origin v0.23.21

Please sign in to comment.