Skip to content

Commit

Permalink
Fixes to adddiag (#186)
Browse files Browse the repository at this point in the history
* _adddiag! should be _adddiag

* better fix for addition Diag/ScalMat + AbstractPDMat

* test for `AbstractVector` change in `_adddiag!`

* test for `AbstractVector` change in `_adddiag`

* making actual @tests

* Additional fix and tests

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
olivierverdier and devmotion authored Oct 3, 2023
1 parent 967eec6 commit e97e813
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/addition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# between pdmat and pdmat

+(a::PDMat, b::AbstractPDMat) = PDMat(a.mat + Matrix(b))
+(a::PDiagMat, b::AbstractPDMat) = PDMat(_adddiag!(Matrix(b), a.diag))
+(a::PDiagMat, b::AbstractPDMat) = PDMat(_adddiag!(Matrix(b), a.diag, true))
+(a::ScalMat, b::AbstractPDMat) = PDMat(_adddiag!(Matrix(b), a.value))
if HAVE_CHOLMOD
+(a::PDSparseMat, b::AbstractPDMat) = PDMat(a.mat + Matrix(b))
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function _adddiag!(a::Union{Matrix, SparseMatrixCSC}, v::Real)
return a
end

function _adddiag!(a::Union{Matrix, SparseMatrixCSC}, v::Vector, c::Real)
function _adddiag!(a::Union{Matrix, SparseMatrixCSC}, v::AbstractVector, c::Real)
@check_argdims eachindex(v) == axes(a, 1) == axes(a, 2)
if c == one(c)
for i in eachindex(v)
Expand All @@ -45,8 +45,8 @@ function _adddiag!(a::Union{Matrix, SparseMatrixCSC}, v::Vector, c::Real)
end

_adddiag(a::Union{Matrix, SparseMatrixCSC}, v::Real) = _adddiag!(copy(a), v)
_adddiag(a::Union{Matrix, SparseMatrixCSC}, v::Vector, c::Real) = _adddiag!(copy(a), v, c)
_adddiag(a::Union{Matrix, SparseMatrixCSC}, v::Vector{T}) where {T<:Real} = _adddiag!(copy(a), v, one(T))
_adddiag(a::Union{Matrix, SparseMatrixCSC}, v::AbstractVector, c::Real) = _adddiag!(copy(a), v, c)
_adddiag(a::Union{Matrix, SparseMatrixCSC}, v::AbstractVector{T}) where {T<:Real} = _adddiag!(copy(a), v, one(T))


function wsumsq(w::AbstractVector, a::AbstractVector)
Expand Down
39 changes: 30 additions & 9 deletions test/addition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

using PDMats


# New AbstractPDMat type for the tests below
# Supports only functions needed in the tests below
struct ScalMat3D{T<:Real} <: AbstractPDMat{T}
value::T
end
Base.Matrix(a::ScalMat3D) = Matrix(Diagonal(fill(a.value, 3)))
Base.size(::ScalMat3D) = (3, 3)
# Not generally correct
Base.:*(a::ScalMat3D, c::Real) = ScalMat3D(a.value * c)
Base.getindex(a::ScalMat3D, i::Int, j::Int) = i == j ? a.value : zero(a.value)

@testset "addition" begin
for T in (Float64, Float32)
printstyled("Testing addition with eltype = $T\n"; color=:blue)
Expand All @@ -11,26 +23,35 @@ using PDMats

pm1 = PDMat(M)
pm2 = PDiagMat(V)
pm3 = ScalMat(3, X)
pm4 = X * I
pm3 = PDiagMat(sparse(V))
pm4 = ScalMat(3, X)
pm5 = PDSparseMat(sparse(M))
pm6 = ScalMat3D(X)

pmats = Any[pm1, pm2, pm3] #, pm5]
pmats = Any[pm1, pm2, pm3, pm4, pm5, pm6]

for p1 in pmats, p2 in pmats
pr = p1 + p2
@test size(pr) == size(p1)
@test Matrix(pr) Matrix(p1) + Matrix(p2)

pr = pdadd(p1, p2, convert(T, 1.5))
@test size(pr) == size(p1)
@test Matrix(pr) Matrix(p1) + Matrix(p2) * convert(T, 1.5)
if p1 isa ScalMat3D
@test_broken pdadd(p1, p2, convert(T, 1.5))
else
pr = pdadd(p1, p2, convert(T, 1.5))
@test size(pr) == size(p1)
@test Matrix(pr) Matrix(p1) + Matrix(p2) * convert(T, 1.5)
end
end

for p1 in pmats
pr = p1 + pm4
@test size(pr) == size(p1)
@test Matrix(pr) Matrix(p1) + pm4
if p1 isa ScalMat3D
@test_broken p1 + X * I
else
pr = p1 + X * I
@test size(pr) == size(p1)
@test Matrix(pr) Matrix(p1) + X * I
end
end
end
end

0 comments on commit e97e813

Please sign in to comment.