Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for LKJCholesky #521

Merged
merged 10 commits into from
Aug 29, 2023
Merged

Fix for LKJCholesky #521

merged 10 commits into from
Aug 29, 2023

Conversation

torfjelde
Copy link
Member

LKJCholesky is currently not handled correctly in DynamicPPL, as it is missing implementations of reconstruct and vectorize (+ additional bugs in varinfos were exposed now that we're not simply working with AbstractArray{<:Real}).

On #master we have:

julia> using DynamicPPL, Distributions

julia> @model demo_lkj() = x ~ LKJCholesky(2, 1.0)
demo_lkj (generic function with 2 methods)

julia> model = demo_lkj();

julia> VarInfo(model)[@varname(x)]
2×2 Matrix{Float64}:
 1.0       0.0
 0.172835  0.984951

Notice that indexing into the VarInfo returns a Matrix{Float64} (in particular, it returns the lower-triangular used in the Cholesky). This then changes downstream computation paths, for example resulting in cholesky being called in link!! and causing issues (https://discourse.julialang.org/t/singular-exception-with-lkjcholesky/85020).

On this branch, we now have the correct behavior:

julia> using DynamicPPL, Distributions

julia> @model demo_lkj() = x ~ LKJCholesky(2, 1.0)
demo_lkj (generic function with 4 methods)

julia> model = demo_lkj();

julia> VarInfo(model)[@varname(x)]
LinearAlgebra.Cholesky{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}
L factor:
2×2 LinearAlgebra.LowerTriangular{Float64, Base.ReshapedArray{Float64, 2, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}:
  1.0         
 -0.651814  0.758379

src/varinfo.jl Outdated Show resolved Hide resolved
src/varinfo.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@yebai
Copy link
Member

yebai commented Aug 25, 2023

@torfjelde can you take a look at the test error? It seems not related to this PR as I also encountered it in #520.

Maybe also consider fixing the chain conversion issue found by @sethaxen

@torfjelde
Copy link
Member Author

The CI failed because we were comparing a 1x2 matrix to a 2-length vector using ==. Doing elementwise comparisons now, which should fix it :)

test/test_util.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@coveralls
Copy link

Pull Request Test Coverage Report for Build 6001043985

  • 14 of 14 (100.0%) changed or added relevant lines in 3 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+4.0%) to 80.403%

Totals Coverage Status
Change from base Build 5805633221: 4.0%
Covered Lines: 2232
Relevant Lines: 2776

💛 - Coveralls

@codecov
Copy link

codecov bot commented Aug 28, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +4.04% 🎉

Comparison is base (7ef5da7) 76.36% compared to head (7e9b25d) 80.40%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #521      +/-   ##
==========================================
+ Coverage   76.36%   80.40%   +4.04%     
==========================================
  Files          24       24              
  Lines        2771     2776       +5     
==========================================
+ Hits         2116     2232     +116     
+ Misses        655      544     -111     
Files Changed Coverage Δ
src/simple_varinfo.jl 67.77% <100.00%> (ø)
src/utils.jl 78.73% <100.00%> (+1.31%) ⬆️
src/varinfo.jl 92.38% <100.00%> (+10.93%) ⬆️

... and 8 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@yebai
Copy link
Member

yebai commented Aug 28, 2023

The CI failed because we were comparing a 1x2 matrix to a 2-length vector using ==. Doing elementwise comparisons now, which should fix it :)

Do you know what caused it? The same tests work fine on the master branch so it is a bit weird.

@yebai yebai requested a review from sunxd3 August 28, 2023 15:31
@torfjelde
Copy link
Member Author

Good point; I believe it's the change to SimpleVarInfo impl of reconstruct; I'll have a look at it tomorrow morning 👍

@yebai yebai enabled auto-merge August 29, 2023 19:47
@yebai yebai added this pull request to the merge queue Aug 29, 2023
Merged via the queue into master with commit 549d9b1 Aug 29, 2023
12 of 13 checks passed
@yebai yebai deleted the torfjelde/lkjchol-fix branch August 29, 2023 23:19
@torfjelde
Copy link
Member Author

torfjelde commented Aug 30, 2023

Good point; I believe it's the change to SimpleVarInfo impl of reconstruct; I'll have a look at it tomorrow morning 👍

So it definitively has nothing to do with SimpleVarInfo; I don't know what this guy is going on about 🙄

But it's quite strange as v in the offending code

@test v == chain_val

also results in a 1x2 matrix on [email protected] but tests were passing.

@yebai
Copy link
Member

yebai commented Aug 30, 2023

Maybe it's caused by a change to @test somewhere in the Julia standard library?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants