Skip to content

Commit

Permalink
Implement promote_operation (#37)
Browse files Browse the repository at this point in the history
* Implement promote_operation

* Fixes

* Add tests

* Fix format
  • Loading branch information
blegat authored Jul 3, 2024
1 parent c041a60 commit 1046d55
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 34 deletions.
6 changes: 6 additions & 0 deletions src/MultivariateBases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ MP.monomial_type(::Type{<:Algebra{B}}) where {B} = MP.monomial_type(B)
function MP.polynomial_type(::Type{<:Algebra{B}}, ::Type{T}) where {B,T}
return MP.polynomial_type(B, T)
end
function MA.promote_operation(
::typeof(SA.basis),
::Type{<:Algebra{B}},
) where {B}
return B
end
SA.basis(a::Algebra) = a.basis

#Base.:(==)(::Algebra{BT1,B1,M}, ::Algebra{BT2,B2,M}) where {BT1,B1,BT2,B2,M} = true
Expand Down
92 changes: 71 additions & 21 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,76 @@ const _APL = MP.AbstractPolynomialLike
# We don't define it for all `AlgebraElement` as this would be type piracy
const _AE = SA.AlgebraElement{<:Algebra}

Base.:(+)(p::_APL, q::_AE) = +(p, MP.polynomial(q))
Base.:(+)(p::_AE, q::_APL) = +(MP.polynomial(p), q)
Base.:(-)(p::_APL, q::_AE) = -(p, MP.polynomial(q))
Base.:(-)(p::_AE, q::_APL) = -(MP.polynomial(p), q)

Base.:(+)(p, q::_AE) = +(constant_algebra_element(typeof(SA.basis(q)), p), q)
function Base.:(+)(p::_AE, q)
return +(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q))
end
function Base.:(-)(p, q::_AE)
return -(constant_algebra_element(typeof(SA.basis(q)), p), MP.polynomial(q))
end
function Base.:(-)(p::_AE, q)
return -(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q))
for op in [:+, :-, :*]
@eval begin
function MA.promote_operation(
::typeof($op),
::Type{P},
::Type{Q},
) where {P<:_APL,Q<:_AE}
return MA.promote_operation($op, P, MP.polynomial_type(Q))
end
Base.$op(p::_APL, q::_AE) = $op(p, MP.polynomial(q))
function MA.promote_operation(
::typeof($op),
::Type{P},
::Type{Q},
) where {P<:_AE,Q<:_APL}
return MA.promote_operation($op, MP.polynomial_type(P), Q)
end
Base.$op(p::_AE, q::_APL) = $op(MP.polynomial(p), q)
# Break ambiguity between the two defined below and the generic one in SA
function MA.promote_operation(
::typeof($op),
::Type{P},
::Type{Q},
) where {P<:_AE,Q<:_AE}
return SA.algebra_promote_operation($op, P, Q)
end
function Base.$op(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output($op, p, q), $op, p, q)
end
end
end

function Base.:(+)(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output(+, p, q), +, p, q)
end

function Base.:(-)(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output(-, p, q), -, p, q)
for op in [:+, :-]
@eval begin
function MA.promote_operation(
::typeof($op),
::Type{P},
::Type{Q},
) where {P,Q<:_AE}
I = MA.promote_operation(implicit, Q)
return MA.promote_operation(
$op,
constant_algebra_element_type(
MA.promote_operation(SA.basis, I),
P,
),
I,
)
end
function Base.$op(p, q::_AE)
i = implicit(q)
return $op(constant_algebra_element(typeof(SA.basis(i)), p), i)
end
function MA.promote_operation(
::typeof($op),
::Type{P},
::Type{Q},
) where {P<:_AE,Q}
I = MA.promote_operation(implicit, P)
return MA.promote_operation(
$op,
I,
constant_algebra_element_type(
MA.promote_operation(SA.basis, I),
Q,
),
)
end
function Base.$op(p::_AE, q)
i = implicit(p)
return $op(i, constant_algebra_element(typeof(SA.basis(i)), q))
end
end
end
52 changes: 47 additions & 5 deletions src/monomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,21 @@ end
implicit_basis(::SubBasis{B,M}) where {B,M} = FullBasis{B,M}()
implicit_basis(basis::FullBasis) = basis

function implicit(a::SA.AlgebraElement)
basis = implicit_basis(SA.basis(a))
return algebra_element(SA.coeffs(a, basis), basis)
end

function MA.promote_operation(
::typeof(implicit),
::Type{E},
) where {AG,T,E<:SA.AlgebraElement{AG,T}}
BT = MA.promote_operation(implicit_basis, MA.promote_operation(SA.basis, E))
A = MA.promote_operation(algebra, BT)
M = MP.monomial_type(BT)
return SA.AlgebraElement{A,T,SA.SparseCoefficients{M,T,Vector{M},Vector{T}}}
end

function MA.promote_operation(
::typeof(implicit_basis),
::Type{<:Union{FullBasis{B,M},SubBasis{B,M}}},
Expand Down Expand Up @@ -201,6 +216,14 @@ end
_one_if_type(α) = α
_one_if_type(::Type{T}) where {T} = one(T)

function constant_algebra_element_type(
::Type{BT},
::Type{T},
) where {B,M,BT<:FullBasis{B,M},T}
A = MA.promote_operation(algebra, BT)
return SA.AlgebraElement{A,T,SA.SparseCoefficients{M,T,Vector{M},Vector{T}}}
end

function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M}
return algebra_element(
sparse_coefficients(
Expand All @@ -210,6 +233,14 @@ function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M}
)
end

function constant_algebra_element_type(
::Type{B},
::Type{T},
) where {B<:SubBasis,T}
A = MA.promote_operation(algebra, B)
return SA.AlgebraElement{A,T,Vector{T}}
end

function constant_algebra_element(::Type{<:SubBasis{B,M}}, α) where {B,M}
return algebra_element(
[_one_if_type(α)],
Expand Down Expand Up @@ -377,14 +408,25 @@ function MP.polynomial_type(::Type{FullBasis{B,M}}, ::Type{T}) where {T,B,M}
return MP.polynomial_type(M, _promote_coef(T, B))
end

_vec(v::Vector) = v
_vec(v::AbstractVector) = collect(v)

# Adapted from SA to incorporate `_promote_coef`
function SA.coeffs(
cfs,
source::MonomialIndexedBasis{B},
target::MonomialIndexedBasis{Monomial},
) where {B}
source::MonomialIndexedBasis{B1},
target::MonomialIndexedBasis{B2},
) where {B1,B2}
source === target && return cfs
source == target && return cfs
res = SA.zero_coeffs(_promote_coef(valtype(cfs), B), target)
return SA.coeffs!(res, cfs, source, target)
if B1 === B2 && target isa FullBasis
# The defaults initialize to zero and then sums which promotes
# `JuMP.VariableRef` to `JuMP.AffExpr`
return SA.SparseCoefficients(_vec(source.monomials), _vec(cfs))
elseif B2 === Monomial
res = SA.zero_coeffs(_promote_coef(valtype(cfs), B1), target)
return SA.coeffs!(res, cfs, source, target)
else
error("Convertion from `$source` to `$target` not implemented yet")
end
end
8 changes: 8 additions & 0 deletions test/hermite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ end
end

@testset "Coefficients" begin
@polyvar x
coefficient_test(
MB.ProbabilistsHermite,
[4, 6, 6, 1, 9, 1, 1, 1];
Expand All @@ -44,4 +45,11 @@ end
1.0,
]),
)
M = typeof(x^2)
mono = MB.FullBasis{MB.Monomial,M}()
basis = MB.FullBasis{MB.PhysicistsHermite,M}()
err = ErrorException(
"Convertion from `$mono` to `$basis` not implemented yet",
)
@test_throws err SA.coeffs(MB.algebra_element(x + 1), basis)
end
29 changes: 21 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,26 @@ const MB = MultivariateBases
using LinearAlgebra
using DynamicPolynomials

function _test_op(op, args...)
result = @inferred op(args...)
@test typeof(result) == MA.promote_operation(op, typeof.(args)...)
return result
end

function _test_basis(basis)
B = typeof(basis)
@test typeof(MB.algebra(basis)) == MA.promote_operation(MB.algebra, B)
@test typeof(MB.constant_algebra_element(B, 1)) ==
MB.constant_algebra_element_type(B, Int)
end

function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
@polyvar x[1:2]
M = typeof(prod(x))
full_basis = FullBasis{B,M}()
_test_basis(full_basis)
@test sprint(show, MB.algebra(full_basis)) ==
"Polynomial algebra of $B basis"
@test typeof(MB.algebra(full_basis)) ==
MA.promote_operation(MB.algebra, typeof(full_basis))
for basis in [
maxdegree_basis(full_basis, x, degree),
explicit_basis_covering(
Expand All @@ -26,8 +38,7 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
MB.SubBasis{ScaledMonomial}(monomials(x, 0:degree)),
),
]
@test typeof(MB.algebra(basis)) ==
MA.promote_operation(MB.algebra, typeof(basis))
_test_basis(basis)
@test basis isa MB.explicit_basis_type(typeof(full_basis))
for i in eachindex(basis)
mono = basis.monomials[i]
Expand All @@ -47,6 +58,8 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
@test length(empty_basis(typeof(basis))) == 0
@test polynomial_type(basis, Float64) == polynomial_type(x[1], Float64)
#@test polynomial(i -> 0.0, basis) isa polynomial_type(basis, Float64)
a = MB.algebra_element(ones(length(basis)), basis)
_test_op(MB.implicit, a)
end
mono = x[1]^2 * x[2]^3
p = MB.Polynomial{B}(mono)
Expand Down Expand Up @@ -77,10 +90,10 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
const_poly = MB.Polynomial{B}(const_mono)
const_alg_el = MB.algebra_element(const_poly)
for other in (const_mono, 1, const_alg_el)
@test other + const_alg_el 2 * other
@test const_alg_el + other 2 * other
@test iszero(other - const_alg_el)
@test iszero(const_alg_el - other)
@test _test_op(+, other, const_alg_el) _test_op(*, 2, other)
@test _test_op(+, const_alg_el, other) _test_op(*, 2, other)
@test iszero(_test_op(-, other, const_alg_el))
@test iszero(_test_op(-, const_alg_el, other))
end
@test typeof(MB.sparse_coefficients(sum(x))) ==
MA.promote_operation(MB.sparse_coefficients, typeof(sum(x)))
Expand Down

0 comments on commit 1046d55

Please sign in to comment.