From f28cdfb7309c69575d081ad13caec11e8060fe35 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 13 Oct 2023 14:49:50 +0200 Subject: [PATCH] Define functions for `Cholesky` (#168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Define functions for `Cholesky` * Only test `Cholesky` * Fix numerical issues in tests * Update chol.jl * Fixes * Fix tests (hopefully) * Remove old but now unnecessary tests --------- Co-authored-by: Mathieu Besançon --- README.md | 11 ++++++++ src/chol.jl | 64 +++++++++++++++++++++++++++++++++++++++++++ test/chol.jl | 13 +++++++-- test/specialarrays.jl | 15 ++++++---- test/testutils.jl | 2 +- 5 files changed, 96 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ffd8dc0..f48baab 100644 --- a/README.md +++ b/README.md @@ -214,6 +214,17 @@ While in theory all of them can be defined, at present only the following subset PRs to implement more generic fallbacks are welcome. +### Fallbacks for `LinearAlgebra.Cholesky` + +For Cholesky decompositions of type `Cholesky` the following functions are defined as well: + + - `dim` + - `whiten`, `whiten!` + - `unwhiten`, `unwhiten!` + - `quad`, `quad!` + - `invquad`, `invquad!` + - `X_A_Xt`, `Xt_A_X`, `X_invA_Xt`, `Xt_invA_X` + ## Define Customized Subtypes In some situation, it is useful to define a customized subtype of `AbstractPDMat` to capture positive definite matrices with special structures. For this purpose, one has to define a subset of methods (as listed below), and other methods will be automatically provided. diff --git a/src/chol.jl b/src/chol.jl index 56c5c7c..313694d 100644 --- a/src/chol.jl +++ b/src/chol.jl @@ -24,3 +24,67 @@ if HAVE_CHOLMOD chol_lower(cf::CholTypeSparse) = cf.PtL chol_upper(cf::CholTypeSparse) = cf.UP end + +# Interface for `Cholesky` + +dim(A::Cholesky) = LinearAlgebra.checksquare(A) + +# whiten +whiten(A::Cholesky, x::AbstractVecOrMat) = chol_lower(A) \ x +whiten!(A::Cholesky, x::AbstractVecOrMat) = ldiv!(chol_lower(A), x) + +# unwhiten +unwhiten(A::Cholesky, x::AbstractVecOrMat) = chol_lower(A) * x +unwhiten!(A::Cholesky, x::AbstractVecOrMat) = lmul!(chol_lower(A), x) + +# 3-argument whiten/unwhiten +for T in (:AbstractVector, :AbstractMatrix) + @eval begin + whiten!(r::$T, A::Cholesky, x::$T) = whiten!(A, copyto!(r, x)) + unwhiten!(r::$T, A::Cholesky, x::$T) = unwhiten!(A, copyto!(r, x)) + end +end + +# quad +quad(A::Cholesky, x::AbstractVector) = sum(abs2, chol_upper(A) * x) +function quad(A::Cholesky, X::AbstractMatrix) + Z = chol_upper(A) * X + return vec(sum(abs2, Z; dims=1)) +end +function quad!(r::AbstractArray, A::Cholesky, X::AbstractMatrix) + Z = chol_upper(A) * X + return map!(Base.Fix1(sum, abs2), r, eachcol(Z)) +end + +# invquad +invquad(A::Cholesky, x::AbstractVector) = sum(abs2, chol_lower(A) \ x) +function invquad(A::Cholesky, X::AbstractMatrix) + Z = chol_lower(A) \ X + return vec(sum(abs2, Z; dims=1)) +end +function invquad!(r::AbstractArray, A::Cholesky, X::AbstractMatrix) + Z = chol_lower(A) * X + return map!(Base.Fix1(sum, abs2), r, eachcol(Z)) +end + +# tri products + +function X_A_Xt(A::Cholesky, X::AbstractMatrix) + Z = X * chol_lower(A) + return Z * transpose(Z) +end + +function Xt_A_X(A::Cholesky, X::AbstractMatrix) + Z = chol_upper(A) * X + return transpose(Z) * Z +end + +function X_invA_Xt(A::Cholesky, X::AbstractMatrix) + Z = X / chol_upper(A) + return Z * transpose(Z) +end + +function Xt_invA_X(A::Cholesky, X::AbstractMatrix) + Z = chol_lower(A) \ X + return transpose(Z) * Z +end diff --git a/test/chol.jl b/test/chol.jl index 06b43cf..ebdabb4 100644 --- a/test/chol.jl +++ b/test/chol.jl @@ -3,20 +3,29 @@ using PDMats: chol_lower, chol_upper @testset "chol_lower and chol_upper" begin @testset "allocations" begin - A = rand(100, 100) + d = 100 + A = rand(d, d) C = A'A + invC = inv(C) size_of_one_copy = sizeof(C) - @assert size_of_one_copy > 100 # ensure the matrix is large enough that few-byte allocations don't matter + @assert size_of_one_copy > d # ensure the matrix is large enough that few-byte allocations don't matter @test chol_lower(C) ≈ chol_upper(C)' @test (@allocated chol_lower(C)) < 1.05 * size_of_one_copy # allow 5% overhead @test (@allocated chol_upper(C)) < 1.05 * size_of_one_copy + X = randn(d, 10) for uplo in (:L, :U) ch = cholesky(Symmetric(C, uplo)) @test chol_lower(ch) ≈ chol_upper(ch)' @test (@allocated chol_lower(ch)) < 33 # allow small overhead for wrapper types @test (@allocated chol_upper(ch)) < 33 # allow small overhead for wrapper types + + # Only test dim, `quad`/`invquad`, `whiten`/`unwhiten`, and tri products + @test dim(ch) == size(C, 1) + pdtest_quad(ch, C, invC, X, 0) + pdtest_triprod(ch, C, invC, X, 0) + pdtest_whiten(ch, C, 0) end end diff --git a/test/specialarrays.jl b/test/specialarrays.jl index 3b6417e..cc819bb 100644 --- a/test/specialarrays.jl +++ b/test/specialarrays.jl @@ -4,7 +4,7 @@ using StaticArrays @testset "Special matrix types" begin @testset "StaticArrays" begin # Full matrix - S = (x -> x * x')(@SMatrix(randn(4, 7))) + S = (x -> x * x' + I)(@SMatrix(randn(4, 7))) PDS = PDMat(S) @test PDS isa PDMat{Float64, <:SMatrix{4, 4, Float64}} @test isbits(PDS) @@ -27,12 +27,15 @@ using StaticArrays X = @SMatrix rand(10, 4) Y = @SMatrix rand(4, 10) - for A in (PDS, D, E) - @test A * x isa SVector{4, Float64} - @test A * x ≈ Matrix(A) * Vector(x) + for A in (PDS, D, E, C) + if !(A isa Cholesky) + # `*(::Cholesky, ::SArray)` is not defined + @test A * x isa SVector{4, Float64} + @test A * x ≈ Matrix(A) * Vector(x) - @test A * Y isa SMatrix{4, 10, Float64} - @test A * Y ≈ Matrix(A) * Matrix(Y) + @test A * Y isa SMatrix{4, 10, Float64} + @test A * Y ≈ Matrix(A) * Matrix(Y) + end @test X / A isa SMatrix{10, 4, Float64} @test X / A ≈ Matrix(X) / Matrix(A) diff --git a/test/testutils.jl b/test/testutils.jl index 6bc2d09..c01f5b0 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -17,7 +17,7 @@ function test_pdmat(C, Cmat::Matrix; t_cholesky::Bool=true, # whether to test cholesky method t_scale::Bool=true, # whether to test scaling t_add::Bool=true, # whether to test pdadd - t_det::Bool=true, # whether to test det method + t_det::Bool=true, # whether to test det method t_logdet::Bool=true, # whether to test logdet method t_eig::Bool=true, # whether to test eigmax and eigmin t_mul::Bool=true, # whether to test multiplication