diff --git a/src/generics.jl b/src/generics.jl index 686d190..78e95d2 100644 --- a/src/generics.jl +++ b/src/generics.jl @@ -33,19 +33,6 @@ LinearAlgebra.checksquare(a::AbstractPDMat) = size(a, 1) ## whiten and unwhiten -whiten!(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(x, a, x) -unwhiten!(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(x, a, x) - -function whiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat) - v = _rcopy!(r, x) - ldiv!(chol_lower(cholesky(a)), v) -end - -function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat) - v = _rcopy!(r, x) - lmul!(chol_lower(cholesky(a)), v) -end - """ whiten(a::AbstractMatrix, x::AbstractVecOrMat) unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) @@ -80,35 +67,41 @@ julia> W * W' 0.0 1.0 ``` """ -whiten(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(similar(x), a, x) -unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(similar(x), a, x) +whiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten(AbstractPDMat(a), x) +unwhiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten(AbstractPDMat(a), x) +whiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten!(x, a, x) +unwhiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten!(x, a, x) + +function whiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return whiten!(r, AbstractPDMat(a), x) +end +function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return unwhiten!(r, AbstractPDMat(a), x) +end ## quad """ quad(a::AbstractMatrix, x::AbstractVecOrMat) -Return the value of the quadratic form defined by `a` applied to `x` +Return the value of the quadratic form defined by `a` applied to `x`. If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix the quadratic form is applied column-wise. """ -function quad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real} - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) - quad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x) +function quad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return quad(AbstractPDMat(a), x) end -quad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_upper(cholesky(a)) * x) -invquad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_lower(cholesky(a)) \ x) - """ quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) -Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x` +Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`. """ -quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a * x) - +function quad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix) + return quad!(r, AbstractPDMat(a), x) +end """ invquad(a::AbstractMatrix, x::AbstractVecOrMat) @@ -120,10 +113,8 @@ For most `PDMat` types this is done in a way that does not require evaluation of If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix the quadratic form is applied column-wise. """ -invquad(a::AbstractMatrix, x::AbstractVecOrMat) = x' / a * x -function invquad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real} - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) - invquad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x) +function invquad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) + return invquad(AbstractPDMat(a), x) end """ @@ -131,4 +122,7 @@ end Overwrite `r` with the value of the quadratic form defined by `inv(a)` applied columnwise to `x` """ -invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a \ x) +function invquad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix) + return invquad!(r, AbstractPDMat(a), x) +end + diff --git a/src/pdiagmat.jl b/src/pdiagmat.jl index df6f74f..2147883 100644 --- a/src/pdiagmat.jl +++ b/src/pdiagmat.jl @@ -86,45 +86,48 @@ LinearAlgebra.sqrt(a::PDiagMat) = PDiagMat(map(sqrt, a.diag)) ### whiten and unwhiten -function whiten!(r::StridedVector, a::PDiagMat, x::StridedVector) - n = a.dim - @check_argdims length(r) == length(x) == n - v = a.diag - for i = 1:n - r[i] = x[i] / sqrt(v[i]) - end - return r +function whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + return r .= x ./ sqrt.(a.diag) end - -function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector) - n = a.dim - @check_argdims length(r) == length(x) == n - v = a.diag - for i = 1:n - r[i] = x[i] * sqrt(v[i]) - end - return r +function unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + return r .= x .* sqrt.(a.diag) end -function whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) - r .= x ./ sqrt.(a.diag) - return r +function whiten(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return x ./ sqrt.(a.diag) end - -function unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) - r .= x .* sqrt.(a.diag) - return r +function unwhiten(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return x .* sqrt.(a.diag) end - -whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x ./ sqrt.(a.diag) -unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x .* sqrt.(a.diag) - - ### quadratic forms -quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x) -invquad(a::PDiagMat, x::AbstractVector) = invwsumsq(a.diag, x) +function quad(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return wsumsq(a.diag, x) + else + # map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + return vec(sum(abs2.(x) .* a.diag; dims = 1)) + end +end +function invquad(a::PDiagMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return invwsumsq(a.diag, x) + else + # map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + return vec(sum(abs2.(x) ./ a.diag; dims = 1)) + end +end function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix) ad = a.diag diff --git a/src/pdmat.jl b/src/pdmat.jl index b8fd077..98a75ba 100644 --- a/src/pdmat.jl +++ b/src/pdmat.jl @@ -78,6 +78,76 @@ LinearAlgebra.eigmin(a::PDMat) = eigmin(a.mat) Base.kron(A::PDMat, B::PDMat) = PDMat(kron(A.mat, B.mat), Cholesky(kron(A.chol.U, B.chol.U), 'U', A.chol.info)) LinearAlgebra.sqrt(A::PDMat) = PDMat(sqrt(Hermitian(A.mat))) +### (un)whitening + +function whiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + v = _rcopy!(r, x) + return ldiv!(chol_lower(cholesky(a)), v) +end +function unwhiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + v = _rcopy!(r, x) + return lmul!(chol_lower(cholesky(a)), v) +end + +function whiten(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return chol_lower(cholesky(a)) \ x +end +function unwhiten(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return chol_lower(cholesky(a)) * x +end + +## quad/invquad + +function quad(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + aU_x = chol_upper(cholesky(a)) * x + if x isa AbstractVector + return sum(abs2, aU_x) + else + return vec(sum(abs2, aU_x; dims = 1)) + end +end +function invquad(a::PDMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + inv_aL_x = chol_lower(cholesky(a)) \ x + if x isa AbstractVector + return sum(abs2, inv_aL_x) + else + return vec(sum(abs2, inv_aL_x; dims = 1)) + end +end + +function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix) + @check_argdims axes(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + aU = chol_upper(cholesky(a)) + z = similar(r, a.dim) # buffer to save allocations + @inbounds for i in axes(x, 2) + copyto!(z, view(x, :, i)) + lmul!(aU, z) + r[i] = sum(abs2, z) + end + return r +end +function invquad!(r::AbstractArray, a::PDMat, x::AbstractMatrix) + @check_argdims axes(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + aL = chol_lower(cholesky(a)) + z = similar(r, a.dim) # buffer to save allocations + @inbounds for i in axes(x, 2) + copyto!(z, view(x, :, i)) + ldiv!(aL, z) + r[i] = sum(abs2, z) + end + return r +end + ### tri products function X_A_Xt(a::PDMat, x::AbstractMatrix) diff --git a/src/pdsparsemat.jl b/src/pdsparsemat.jl index d7f10ea..68e4121 100644 --- a/src/pdsparsemat.jl +++ b/src/pdsparsemat.jl @@ -70,20 +70,41 @@ LinearAlgebra.sqrt(A::PDSparseMat) = PDMat(sqrt(Hermitian(Matrix(A)))) ### whiten and unwhiten function whiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat) - r[:] = sparse(chol_lower(a.chol)) \ x - return r + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + # ldiv! is not defined for SparseMatrixCSC + return copyto!(r, sparse(chol_lower(a.chol)) \ x) end function unwhiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat) - r[:] = sparse(chol_lower(a.chol)) * x - return r + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) + # lmul! is not defined for SparseMatrixCSC + return copyto!(r, sparse(chol_lower(a.chol)) * x) end +function whiten(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return sparse(chol_lower(cholesky(a))) \ x +end + +function unwhiten(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return sparse(chol_lower(cholesky(a))) * x +end ### quadratic forms -quad(a::PDSparseMat, x::AbstractVector) = dot(x, a * x) -invquad(a::PDSparseMat, x::AbstractVector) = dot(x, a \ x) +function quad(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + z = sparse(chol_lower(cholesky(a)))' * x + return x isa AbstractVector ? sum(abs2, z) : vec(sum(abs2, z; dims = 1)) +end +function invquad(a::PDSparseMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + z = sparse(chol_lower(cholesky(a))) \ x + return x isa AbstractVector ? sum(abs2, z) : vec(sum(abs2, z; dims = 1)) +end function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix) @check_argdims eachindex(r) == axes(x, 2) diff --git a/src/scalmat.jl b/src/scalmat.jl index e1c5038..f2ddbeb 100644 --- a/src/scalmat.jl +++ b/src/scalmat.jl @@ -76,23 +76,65 @@ LinearAlgebra.sqrt(a::ScalMat) = ScalMat(a.dim, sqrt(a.value)) ### whiten and unwhiten function whiten!(r::AbstractVecOrMat, a::ScalMat, x::AbstractVecOrMat) - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) _ldiv!(r, sqrt(a.value), x) end function unwhiten!(r::AbstractVecOrMat, a::ScalMat, x::AbstractVecOrMat) - @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) + @check_argdims axes(r) == axes(x) + @check_argdims a.dim == size(x, 1) mul!(r, x, sqrt(a.value)) end +function whiten(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return x / sqrt(a.value) +end +function unwhiten(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + return sqrt(a.value) * x +end ### quadratic forms -quad(a::ScalMat, x::AbstractVector) = sum(abs2, x) * a.value -invquad(a::ScalMat, x::AbstractVector) = sum(abs2, x) / a.value +function quad(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return sum(abs2, x) * a.value + else + # map(Base.Fix1(quad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + wsq = let w = a.value + x -> w * abs2(x) + end + return vec(sum(wsq, x; dims=1)) + end +end +function invquad(a::ScalMat, x::AbstractVecOrMat) + @check_argdims a.dim == size(x, 1) + if x isa AbstractVector + return sum(abs2, x) / a.value + else + # map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives + # do NOT return a `SVector` for inputs `x::SMatrix`. + wsq = let w = a.value + x -> abs2(x) / w + end + return vec(sum(wsq, x; dims=1)) + end +end -quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsq!(r, x, a.value) -invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsqinv!(r, x, a.value) +function quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) + @check_argdims eachindex(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + return map!(Base.Fix1(quad, a), r, eachcol(x)) +end +function invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) + @check_argdims eachindex(r) == axes(x, 2) + @check_argdims a.dim == size(x, 1) + return map!(Base.Fix1(invquad, a), r, eachcol(x)) +end ### tri products diff --git a/test/specialarrays.jl b/test/specialarrays.jl index b812a80..3b6417e 100644 --- a/test/specialarrays.jl +++ b/test/specialarrays.jl @@ -43,6 +43,30 @@ using StaticArrays @test A \ Y isa SMatrix{4, 10, Float64} @test A \ Y ≈ Matrix(A) \ Matrix(Y) + @test whiten(A, x) isa SVector{4, Float64} + @test whiten(A, x) ≈ cholesky(Matrix(A)).L \ Vector(x) + + @test whiten(A, Y) isa SMatrix{4, 10, Float64} + @test whiten(A, Y) ≈ cholesky(Matrix(A)).L \ Matrix(Y) + + @test unwhiten(A, x) isa SVector{4, Float64} + @test unwhiten(A, x) ≈ cholesky(Matrix(A)).L * Vector(x) + + @test unwhiten(A, Y) isa SMatrix{4, 10, Float64} + @test unwhiten(A, Y) ≈ cholesky(Matrix(A)).L * Matrix(Y) + + @test quad(A, x) isa Float64 + @test quad(A, x) ≈ Vector(x)' * Matrix(A) * Vector(x) + + @test quad(A, Y) isa SVector{10, Float64} + @test quad(A, Y) ≈ diag(Matrix(Y)' * Matrix(A) * Matrix(Y)) + + @test invquad(A, x) isa Float64 + @test invquad(A, x) ≈ Vector(x)' * (Matrix(A) \ Vector(x)) + + @test invquad(A, Y) isa SVector{10, Float64} + @test invquad(A, Y) ≈ diag(Matrix(Y)' * (Matrix(A) \ Matrix(Y))) + @test X_A_Xt(A, X) isa SMatrix{10, 10, Float64} @test X_A_Xt(A, X) ≈ Matrix(X) * Matrix(A) * Matrix(X)'