diff --git a/src/arithmetic.jl b/src/arithmetic.jl index ad8c1f3..3ff2693 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -24,20 +24,28 @@ end for op in [:+, :-] @eval begin function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P,Q<:_AE} + I = MA.promote_operation(implicit, Q) return MA.promote_operation( $op, - constant_algebra_element_type(MA.promote_operation(SA.basis, Q), P), - Q, + constant_algebra_element_type(MA.promote_operation(SA.basis, I), P), + I, ) end - Base.$op(p, q::_AE) = $op(constant_algebra_element(typeof(SA.basis(q)), p), q) + function Base.$op(p, q::_AE) + i = implicit(q) + return $op(constant_algebra_element(typeof(SA.basis(i)), p), i) + end function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P<:_AE,Q} + I = MA.promote_operation(implicit, P) return MA.promote_operation( $op, - P, - constant_algebra_element_type(MA.promote_operation(SA.basis, P), Q), + I, + constant_algebra_element_type(MA.promote_operation(SA.basis, I), Q), ) end - Base.$op(p::_AE, q) = $op(p, constant_algebra_element(typeof(SA.basis(p)), q)) + function Base.$op(p::_AE, q) + i = implicit(p) + return $op(i, constant_algebra_element(typeof(SA.basis(i)), q)) + end end end diff --git a/src/monomial.jl b/src/monomial.jl index 714bcd1..c6de2d9 100644 --- a/src/monomial.jl +++ b/src/monomial.jl @@ -127,6 +127,21 @@ end implicit_basis(::SubBasis{B,M}) where {B,M} = FullBasis{B,M}() implicit_basis(basis::FullBasis) = basis +function implicit(a::SA.AlgebraElement) + basis = implicit_basis(SA.basis(a)) + return algebra_element(SA.coeffs(a, basis), basis) +end + +function MA.promote_operation( + ::typeof(implicit), + ::Type{E}, +) where {AG,T,E<:SA.AlgebraElement{AG,T}} + BT = MA.promote_operation(implicit_basis, MA.promote_operation(SA.basis, E)) + A = MA.promote_operation(algebra, BT) + M = MP.monomial_type(BT) + return SA.AlgebraElement{A,T,SA.SparseCoefficients{M,T,Vector{M},Vector{T}}} +end + function MA.promote_operation( ::typeof(implicit_basis), ::Type{<:Union{FullBasis{B,M},SubBasis{B,M}}}, @@ -393,14 +408,25 @@ function MP.polynomial_type(::Type{FullBasis{B,M}}, ::Type{T}) where {T,B,M} return MP.polynomial_type(M, _promote_coef(T, B)) end +_vec(v::Vector) = v +_vec(v::AbstractVector) = collect(v) + # Adapted from SA to incorporate `_promote_coef` function SA.coeffs( cfs, - source::MonomialIndexedBasis{B}, - target::MonomialIndexedBasis{Monomial}, -) where {B} + source::MonomialIndexedBasis{B1}, + target::MonomialIndexedBasis{B2}, +) where {B1,B2} source === target && return cfs source == target && return cfs - res = SA.zero_coeffs(_promote_coef(valtype(cfs), B), target) - return SA.coeffs!(res, cfs, source, target) + if B1 === B2 && target isa FullBasis + # The defaults initialize to zero and then sums which promotes + # `JuMP.VariableRef` to `JuMP.AffExpr` + return SA.SparseCoefficients(_vec(source.monomials), _vec(cfs)) + elseif B2 === Monomial + res = SA.zero_coeffs(_promote_coef(valtype(cfs), B1), target) + return SA.coeffs!(res, cfs, source, target) + else + error("Convertion from $source to $target not implemented yet") + end end diff --git a/test/runtests.jl b/test/runtests.jl index 23c3455..31cc81a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,9 +7,9 @@ const MB = MultivariateBases using LinearAlgebra using DynamicPolynomials -function _test_op(op, a, b) - result = @inferred op(a, b) - @test typeof(result) == MA.promote_operation(op, typeof(a), typeof(b)) +function _test_op(op, args...) + result = @inferred op(args...) + @test typeof(result) == MA.promote_operation(op, typeof.(args)...) return result end @@ -53,6 +53,8 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree) @test length(empty_basis(typeof(basis))) == 0 @test polynomial_type(basis, Float64) == polynomial_type(x[1], Float64) #@test polynomial(i -> 0.0, basis) isa polynomial_type(basis, Float64) + a = MB.algebra_element(ones(length(basis)), basis) + _test_op(MB.implicit, a) end mono = x[1]^2 * x[2]^3 p = MB.Polynomial{B}(mono)