Skip to content

Commit

Permalink
Implement promote_operation
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jul 2, 2024
1 parent c041a60 commit 8f9cb74
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 25 deletions.
3 changes: 3 additions & 0 deletions src/MultivariateBases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ MP.monomial_type(::Type{<:Algebra{B}}) where {B} = MP.monomial_type(B)
function MP.polynomial_type(::Type{<:Algebra{B}}, ::Type{T}) where {B,T}
return MP.polynomial_type(B, T)
end
function MA.promote_operation(::typeof(SA.basis), ::Type{<:Algebra{B}}) where {B}
return B
end
SA.basis(a::Algebra) = a.basis

#Base.:(==)(::Algebra{BT1,B1,M}, ::Algebra{BT2,B2,M}) where {BT1,B1,BT2,B2,M} = true
Expand Down
58 changes: 37 additions & 21 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,42 @@ const _APL = MP.AbstractPolynomialLike
# We don't define it for all `AlgebraElement` as this would be type piracy
const _AE = SA.AlgebraElement{<:Algebra}

Base.:(+)(p::_APL, q::_AE) = +(p, MP.polynomial(q))
Base.:(+)(p::_AE, q::_APL) = +(MP.polynomial(p), q)
Base.:(-)(p::_APL, q::_AE) = -(p, MP.polynomial(q))
Base.:(-)(p::_AE, q::_APL) = -(MP.polynomial(p), q)

Base.:(+)(p, q::_AE) = +(constant_algebra_element(typeof(SA.basis(q)), p), q)
function Base.:(+)(p::_AE, q)
return +(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q))
end
function Base.:(-)(p, q::_AE)
return -(constant_algebra_element(typeof(SA.basis(q)), p), MP.polynomial(q))
end
function Base.:(-)(p::_AE, q)
return -(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q))
for op in [:+, :-, :*]
@eval begin
function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P<:_APL,Q<:_AE}
return MA.promote_operation($op, P, MP.polynomial_type(Q))
end
Base.$op(p::_APL, q::_AE) = $op(p, MP.polynomial(q))
function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P<:_AE,Q<:_APL}
return MA.promote_operation($op, MP.polynomial_type(P), Q)
end
Base.$op(p::_AE, q::_APL) = $op(MP.polynomial(p), q)
# Break ambiguity between the two defined below and the generic one in SA
function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P<:_AE,Q<:_AE}
return SA.algebra_promote_operation($op, P, Q)
end
function Base.$op(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output($op, p, q), $op, p, q)
end
end
end

function Base.:(+)(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output(+, p, q), +, p, q)
end

function Base.:(-)(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output(-, p, q), -, p, q)
for op in [:+, :-]
@eval begin
function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P,Q<:_AE}
return MA.promote_operation(
$op,
constant_algebra_element_type(MA.promote_operation(SA.basis, Q), P),
Q,
)
end
Base.$op(p, q::_AE) = $op(constant_algebra_element(typeof(SA.basis(q)), p), q)
function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P<:_AE,Q}
return MA.promote_operation(
$op,
P,
constant_algebra_element_type(MA.promote_operation(SA.basis, P), Q),
)
end
Base.$op(p::_AE, q) = $op(p, constant_algebra_element(typeof(SA.basis(p)), q))
end
end
16 changes: 16 additions & 0 deletions src/monomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ end
_one_if_type(α) = α
_one_if_type(::Type{T}) where {T} = one(T)

function constant_algebra_element_type(
::Type{BT},
::Type{T},
) where {B,M,BT<:FullBasis{B,M},T}
A = MA.promote_operation(algebra, BT)
return SA.AlgebraElement{A,T,SA.SparseCoefficients{M,T,Vector{M},Vector{T}}}
end

function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M}
return algebra_element(
sparse_coefficients(
Expand All @@ -210,6 +218,14 @@ function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M}
)
end

function constant_algebra_element_type(
::Type{B},
::Type{T},
) where {B<:SubBasis,T}
A = MA.promote_operation(algebra, B)
return SA.AlgebraElement{A,T,Vector{T}}
end

function constant_algebra_element(::Type{<:SubBasis{B,M}}, α) where {B,M}
return algebra_element(
[_one_if_type(α)],
Expand Down
14 changes: 10 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ const MB = MultivariateBases
using LinearAlgebra
using DynamicPolynomials

function _test_op(op, a, b)
result = @inferred op(a, b)
@test typeof(result) == MA.promote_operation(op, typeof(a), typeof(b))
return result
end

function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
@polyvar x[1:2]
M = typeof(prod(x))
Expand Down Expand Up @@ -77,10 +83,10 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
const_poly = MB.Polynomial{B}(const_mono)
const_alg_el = MB.algebra_element(const_poly)
for other in (const_mono, 1, const_alg_el)
@test other + const_alg_el 2 * other
@test const_alg_el + other 2 * other
@test iszero(other - const_alg_el)
@test iszero(const_alg_el - other)
@test _test_op(+, other, const_alg_el) _test_op(*, 2, other)
@test _test_op(+, const_alg_el, other) _test_op(*, 2, other)
@test iszero(_test_op(-, other, const_alg_el))
@test iszero(_test_op(-, const_alg_el, other))
end
@test typeof(MB.sparse_coefficients(sum(x))) ==
MA.promote_operation(MB.sparse_coefficients, typeof(sum(x)))
Expand Down

0 comments on commit 8f9cb74

Please sign in to comment.