From 4fee35cb0344aacc5e0748e66838f801d4f1d282 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 24 Aug 2023 15:34:02 +0100 Subject: [PATCH 1/9] simplification of vectorize and make use of non-dist version in SimpleVarInfo --- src/simple_varinfo.jl | 2 +- src/utils.jl | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b3ffcec8d..025b4aad7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -544,7 +544,7 @@ values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} isempty(vi) && return T[] - return mapreduce(v -> vec([v;]), vcat, values(vi.values)) + return mapreduce(vectorize, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values))) diff --git a/src/utils.jl b/src/utils.jl index 9a0c9c2b2..d7a28a83e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -209,11 +209,10 @@ invlink_transform(dist) = inverse(link_transform(dist)) # Helper functions for vectorize/reconstruct values # ##################################################### -vectorize(d, r) = vec(r) -vectorize(d::UnivariateDistribution, r::Real) = [r] -vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r) -vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) -vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL)) +vectorize(d, r) = vectorize(r) +vectorize(r::Real) = [r] +vectorize(r::AbstractArray{<:Real}) = copy(vec(r)) +vectorize(r::Cholesky) = copy(vec(r.UL)) # NOTE: # We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real. From 43bc65955e56b9339d4a35898921169f428a8604 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 24 Aug 2023 15:34:27 +0100 Subject: [PATCH 2/9] added special reconstruct for LKJCholeksy --- src/utils.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index d7a28a83e..d28697127 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -236,6 +236,15 @@ reconstruct(::UnivariateDistribution, val::Real) = val reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val) + +function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real}) + return reconstruct(dist, reshape(val, size(dist))) +end +function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real}) + return Cholesky(val, dist.uplo, 0) +end +reconstruct(::LKJCholesky, val::Cholesky) = val + function reconstruct( ::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector ) From 192474673f42b2f3ca6266147de5f48bb297011e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 24 Aug 2023 15:34:39 +0100 Subject: [PATCH 3/9] make use of vectorize in setval! for VarInfo --- src/varinfo.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index c1ccc34b9..e257f0212 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -325,7 +325,12 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) -setval!(md::Metadata, val, vn::VarName) = md.vals[getrange(md, vn)] = [val;] +function setval!(md::Metadata, val::AbstractVector, vn::VarName) + md.vals[getrange(md, vn)] = val +end +function setval!(md::Metadata, val, vn::VarName) + md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val) +end """ getval(vi::VarInfo, vns::Vector{<:VarName}) From 9b7504eb8607679ec9343bede4ace0787d06117d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 24 Aug 2023 15:34:51 +0100 Subject: [PATCH 4/9] added tests --- test/linking.jl | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/test/linking.jl b/test/linking.jl index bb0081780..c9c0c318f 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -91,8 +91,42 @@ end end end + @testset "LKJCholesky" begin + @testset "uplo=$uplo" for uplo in ['L', 'U'] + @model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo) + @testset "d=$d" for d in [2, 3, 5] + model = demo_lkj(d) + dist = LKJCholesky(d, 1.0, uplo) + values_original = rand(model) + vis = DynamicPPL.TestUtils.setup_varinfos( + model, values_original, (@varname(x),) + ) + @testset "$(short_varinfo_name(vi))" for vi in vis + val = vi[@varname(x), dist] + # Ensure that `reconstruct` works as intended. + @test val isa Cholesky + @test val.uplo == uplo + + @test length(vi[:]) == d^2 + lp = logpdf(dist, val) + lp_model = logjoint(model, vi) + @test lp_model ≈ lp + # Linked. + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + @test length(vi_linked[:]) == d * (d - 1) ÷ 2 + # Should now include the log-absdet-jacobian correction. + @test !(getlogp(vi_linked) ≈ lp) + # Invlinked. + vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + @test length(vi_invlinked[:]) == d^2 + @test getlogp(vi_invlinked) ≈ lp + end + end + end + end + # Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504 - @testset "dirichlet" begin + @testset "Dirichlet" begin @model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0) @testset "d=$d" for d in [2, 3, 5] model = demo_dirichlet(d) @@ -100,7 +134,8 @@ end @testset "$(short_varinfo_name(vi))" for vi in vis lp = logpdf(Dirichlet(d, 1.0), vi[:]) @test length(vi[:]) == d - @test getlogp(vi) ≈ lp + lp_model = logjoint(model, vi) + @test lp_model ≈ lp # Linked. vi_linked = DynamicPPL.link!!(deepcopy(vi), model) @test length(vi_linked[:]) == d - 1 From af227a44106a5f6c3d9f74219994d7ca042dc642 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 24 Aug 2023 15:49:11 +0100 Subject: [PATCH 5/9] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index e257f0212..60b6e93c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -326,10 +326,10 @@ The values may or may not be transformed to Euclidean space. """ setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) function setval!(md::Metadata, val::AbstractVector, vn::VarName) - md.vals[getrange(md, vn)] = val + return md.vals[getrange(md, vn)] = val end function setval!(md::Metadata, val, vn::VarName) - md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val) + return md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val) end """ From 41f26c177c93cf10404e729187d2d2c142d6f888 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Aug 2023 14:17:05 +0100 Subject: [PATCH 6/9] fixed test_setval! not working when the true value is not a vector --- test/test_util.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_util.jl b/test/test_util.jl index 892f7221a..8fad5aa46 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -67,7 +67,7 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1) else chain[sample_idx, n, chain_idx] end - @test v == chain_val + @test all(v .== chain_val) end end end From 60865d1ed57c6207b1b2449977c9bc98ebd5d2af Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Aug 2023 15:17:53 +0100 Subject: [PATCH 7/9] okay now we actually fixed the test_setval! --- test/test_util.jl | 11 +++++++---- test/turing/runtests.jl | 1 + 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_util.jl b/test/test_util.jl index 8fad5aa46..994086070 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -61,13 +61,16 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1) nt = DynamicPPL.tonamedtuple(var_info) for (k, (vals, names)) in pairs(nt) for (n, v) in zip(names, vals) - chain_val = if Symbol(n) ∉ keys(chain) + if Symbol(n) ∉ keys(chain) # Assume it's a group - vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) + chain_val = vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) + v_true = vec(v) else - chain[sample_idx, n, chain_idx] + chain_val = chain[sample_idx, n, chain_idx] + v_true = v end - @test all(v .== chain_val) + + @test v_true == chain_val end end end diff --git a/test/turing/runtests.jl b/test/turing/runtests.jl index 7d53cb4db..2c1d5085d 100644 --- a/test/turing/runtests.jl +++ b/test/turing/runtests.jl @@ -10,6 +10,7 @@ setprogress!(false) Random.seed!(100) # load test utilities +include(joinpath(pathof(DynamicPPL), "..", "..", "test", "test_util.jl")) include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_tests.jl")) @testset "Turing" begin From 9e0e0e70543985a81c5d0c9e4033a36d349d24ec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Aug 2023 15:22:43 +0100 Subject: [PATCH 8/9] Update test/test_util.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/test_util.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_util.jl b/test/test_util.jl index 994086070..31296f79a 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -63,7 +63,9 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1) for (n, v) in zip(names, vals) if Symbol(n) ∉ keys(chain) # Assume it's a group - chain_val = vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]) + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) v_true = vec(v) else chain_val = chain[sample_idx, n, chain_idx] From 7e9b25d3fc53fc000d2f36fb605b2fca44618d38 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 28 Aug 2023 17:28:26 +0100 Subject: [PATCH 9/9] Update Project.toml (#522) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c5b7a0241..a20e3546a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.13" +version = "0.23.14" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"