Skip to content

Commit

Permalink
Support (un)whiten and (inv)quad with static arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Sep 26, 2023
1 parent af62114 commit edada9f
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 73 deletions.
54 changes: 24 additions & 30 deletions src/generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,6 @@ LinearAlgebra.checksquare(a::AbstractPDMat) = size(a, 1)

## whiten and unwhiten

whiten!(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(x, a, x)

function whiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
ldiv!(chol_lower(cholesky(a)), v)
end

function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
lmul!(chol_lower(cholesky(a)), v)
end

"""
whiten(a::AbstractMatrix, x::AbstractVecOrMat)
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat)
Expand Down Expand Up @@ -80,35 +67,41 @@ julia> W * W'
0.0 1.0
```
"""
whiten(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(similar(x), a, x)
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(similar(x), a, x)
whiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten(AbstractPDMat(a), x)
unwhiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten(AbstractPDMat(a), x)

whiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten!(x, a, x)

function whiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return whiten!(r, AbstractPDMat(a), x)
end
function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return unwhiten!(r, AbstractPDMat(a), x)
end

## quad

"""
quad(a::AbstractMatrix, x::AbstractVecOrMat)
Return the value of the quadratic form defined by `a` applied to `x`
Return the value of the quadratic form defined by `a` applied to `x`.
If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
function quad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
quad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
function quad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return quad(AbstractPDMat(a), x)
end

quad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_upper(cholesky(a)) * x)
invquad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_lower(cholesky(a)) \ x)

"""
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)
Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`
Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`.
"""
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a * x)

function quad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix)
return quad!(r, AbstractPDMat(a), x)
end

"""
invquad(a::AbstractMatrix, x::AbstractVecOrMat)
Expand All @@ -120,15 +113,16 @@ For most `PDMat` types this is done in a way that does not require evaluation of
If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
invquad(a::AbstractMatrix, x::AbstractVecOrMat) = x' / a * x
function invquad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
invquad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
function invquad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return invquad(AbstractPDMat(a), x)
end

"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)
Overwrite `r` with the value of the quadratic form defined by `inv(a)` applied columnwise to `x`
"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a \ x)
function invquad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix)
return invquad!(r, AbstractPDMat(a), x)
end

65 changes: 34 additions & 31 deletions src/pdiagmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,45 +86,48 @@ LinearAlgebra.sqrt(a::PDiagMat) = PDiagMat(map(sqrt, a.diag))

### whiten and unwhiten

function whiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
n = a.dim
@check_argdims length(r) == length(x) == n
v = a.diag
for i = 1:n
r[i] = x[i] / sqrt(v[i])
end
return r
function whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
return r .= x ./ sqrt.(a.diag)
end

function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
n = a.dim
@check_argdims length(r) == length(x) == n
v = a.diag
for i = 1:n
r[i] = x[i] * sqrt(v[i])
end
return r
function unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
return r .= x .* sqrt.(a.diag)
end

function whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
r .= x ./ sqrt.(a.diag)
return r
function whiten(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return x ./ sqrt.(a.diag)
end

function unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
r .= x .* sqrt.(a.diag)
return r
function unwhiten(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return x .* sqrt.(a.diag)
end


whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x ./ sqrt.(a.diag)
unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x .* sqrt.(a.diag)


### quadratic forms

quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x)
invquad(a::PDiagMat, x::AbstractVector) = invwsumsq(a.diag, x)
function quad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return wsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) .* a.diag; dims = 1))
end
end
function invquad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return invwsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) ./ a.diag; dims = 1))
end
end

function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
ad = a.diag
Expand Down
70 changes: 70 additions & 0 deletions src/pdmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,76 @@ LinearAlgebra.eigmin(a::PDMat) = eigmin(a.mat)
Base.kron(A::PDMat, B::PDMat) = PDMat(kron(A.mat, B.mat), Cholesky(kron(A.chol.U, B.chol.U), 'U', A.chol.info))
LinearAlgebra.sqrt(A::PDMat) = PDMat(sqrt(Hermitian(A.mat)))

### (un)whitening

function whiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
v = _rcopy!(r, x)
return ldiv!(chol_lower(cholesky(a)), v)
end
function unwhiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
v = _rcopy!(r, x)
return lmul!(chol_lower(cholesky(a)), v)
end

function whiten(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) \ x
end
function unwhiten(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) * x
end

## quad/invquad

function quad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
aU_x = chol_upper(cholesky(a)) * x
if x isa AbstractVector
return sum(abs2, aU_x)
else
return vec(sum(abs2, aU_x; dims = 1))
end
end
function invquad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
inv_aL_x = chol_lower(cholesky(a)) \ x
if x isa AbstractVector
return sum(abs2, inv_aL_x)
else
return vec(sum(abs2, inv_aL_x; dims = 1))
end
end

function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
aU = chol_upper(cholesky(a))
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
copyto!(z, view(x, :, i))
lmul!(aU, z)
r[i] = sum(abs2, z)
end
return r
end
function invquad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
aL = chol_lower(cholesky(a))
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
copyto!(z, view(x, :, i))
ldiv!(aL, z)
r[i] = sum(abs2, z)
end
return r
end

### tri products

function X_A_Xt(a::PDMat, x::AbstractMatrix)
Expand Down
33 changes: 27 additions & 6 deletions src/pdsparsemat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,41 @@ LinearAlgebra.sqrt(A::PDSparseMat) = PDMat(sqrt(Hermitian(Matrix(A))))
### whiten and unwhiten

function whiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat)
r[:] = sparse(chol_lower(a.chol)) \ x
return r
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
# ldiv! is not defined for SparseMatrixCSC
return copyto!(r, sparse(chol_lower(a.chol)) \ x)
end

function unwhiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat)
r[:] = sparse(chol_lower(a.chol)) * x
return r
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
# lmul! is not defined for SparseMatrixCSC
return copyto!(r, sparse(chol_lower(a.chol)) * x)
end

function whiten(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return sparse(chol_lower(cholesky(a))) \ x
end

function unwhiten(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return sparse(chol_lower(cholesky(a))) * x
end

### quadratic forms

quad(a::PDSparseMat, x::AbstractVector) = dot(x, a * x)
invquad(a::PDSparseMat, x::AbstractVector) = dot(x, a \ x)
function quad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
z = sparse(chol_lower(cholesky(a)))' * x
return x isa AbstractVector ? sum(abs2, z) : vec(sum(abs2, z; dims = 1))
end
function invquad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
z = sparse(chol_lower(cholesky(a))) \ x
return x isa AbstractVector ? sum(abs2, z) : vec(sum(abs2, z; dims = 1))
end

function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
Expand Down
54 changes: 48 additions & 6 deletions src/scalmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,65 @@ LinearAlgebra.sqrt(a::ScalMat) = ScalMat(a.dim, sqrt(a.value))
### whiten and unwhiten

function whiten!(r::AbstractVecOrMat, a::ScalMat, x::AbstractVecOrMat)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
_ldiv!(r, sqrt(a.value), x)
end

function unwhiten!(r::AbstractVecOrMat, a::ScalMat, x::AbstractVecOrMat)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
mul!(r, x, sqrt(a.value))
end

function whiten(a::ScalMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return x / sqrt(a.value)
end
function unwhiten(a::ScalMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return sqrt(a.value) * x
end

### quadratic forms

quad(a::ScalMat, x::AbstractVector) = sum(abs2, x) * a.value
invquad(a::ScalMat, x::AbstractVector) = sum(abs2, x) / a.value
function quad(a::ScalMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return sum(abs2, x) * a.value
else
# map(Base.Fix1(quad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
wsq = let w = a.value
x -> w * abs2(x)
end
return vec(sum(wsq, x; dims=1))
end
end
function invquad(a::ScalMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return sum(abs2, x) / a.value
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
wsq = let w = a.value
x -> abs2(x) / w
end
return vec(sum(wsq, x; dims=1))
end
end

quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsq!(r, x, a.value)
invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsqinv!(r, x, a.value)
function quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
return map!(Base.Fix1(quad, a), r, eachcol(x))
end
function invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
return map!(Base.Fix1(invquad, a), r, eachcol(x))
end


### tri products
Expand Down
Loading

0 comments on commit edada9f

Please sign in to comment.