Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for LKJCholesky #521

Merged
merged 10 commits into from
Aug 29, 2023
2 changes: 1 addition & 1 deletion src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ values_as(vi::SimpleVarInfo) = vi.values
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T}
isempty(vi) && return T[]
return mapreduce(v -> vec([v;]), vcat, values(vi.values))
return mapreduce(vectorize, vcat, values(vi.values))
end
function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values)))
Expand Down
18 changes: 13 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,10 @@ invlink_transform(dist) = inverse(link_transform(dist))
# Helper functions for vectorize/reconstruct values #
#####################################################

vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL))
vectorize(d, r) = vectorize(r)
vectorize(r::Real) = [r]
vectorize(r::AbstractArray{<:Real}) = copy(vec(r))
vectorize(r::Cholesky) = copy(vec(r.UL))

# NOTE:
# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real.
Expand All @@ -237,6 +236,15 @@ reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)

function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
return reconstruct(dist, reshape(val, size(dist)))
end
function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real})
return Cholesky(val, dist.uplo, 0)
end
reconstruct(::LKJCholesky, val::Cholesky) = val

function reconstruct(
::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector
)
Expand Down
7 changes: 6 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,12 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
The values may or may not be transformed to Euclidean space.
"""
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
setval!(md::Metadata, val, vn::VarName) = md.vals[getrange(md, vn)] = [val;]
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
md.vals[getrange(md, vn)] = val
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
function setval!(md::Metadata, val, vn::VarName)
md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

"""
getval(vi::VarInfo, vns::Vector{<:VarName})
Expand Down
39 changes: 37 additions & 2 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,51 @@ end
end
end

@testset "LKJCholesky" begin
@testset "uplo=$uplo" for uplo in ['L', 'U']
@model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo)
@testset "d=$d" for d in [2, 3, 5]
model = demo_lkj(d)
dist = LKJCholesky(d, 1.0, uplo)
values_original = rand(model)
vis = DynamicPPL.TestUtils.setup_varinfos(
model, values_original, (@varname(x),)
)
@testset "$(short_varinfo_name(vi))" for vi in vis
val = vi[@varname(x), dist]
# Ensure that `reconstruct` works as intended.
@test val isa Cholesky
@test val.uplo == uplo

@test length(vi[:]) == d^2
lp = logpdf(dist, val)
lp_model = logjoint(model, vi)
@test lp_model ≈ lp
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == d * (d - 1) ÷ 2
# Should now include the log-absdet-jacobian correction.
@test !(getlogp(vi_linked) ≈ lp)
# Invlinked.
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@test length(vi_invlinked[:]) == d^2
@test getlogp(vi_invlinked) ≈ lp
end
end
end
end

# Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504
@testset "dirichlet" begin
@testset "Dirichlet" begin
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
@test getlogp(vi) ≈ lp
lp_model = logjoint(model, vi)
@test lp_model ≈ lp
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == d - 1
Expand Down