Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jul 2, 2024
1 parent 8f9cb74 commit 6ce35a9
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
20 changes: 14 additions & 6 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,28 @@ end
for op in [:+, :-]
@eval begin
function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P,Q<:_AE}
I = MA.promote_operation(implicit, Q)
return MA.promote_operation(
$op,
constant_algebra_element_type(MA.promote_operation(SA.basis, Q), P),
Q,
constant_algebra_element_type(MA.promote_operation(SA.basis, I), P),
I,
)
end
Base.$op(p, q::_AE) = $op(constant_algebra_element(typeof(SA.basis(q)), p), q)
function Base.$op(p, q::_AE)
i = implicit(q)
return $op(constant_algebra_element(typeof(SA.basis(i)), p), i)
end
function MA.promote_operation(::typeof($op), ::Type{P}, ::Type{Q}) where {P<:_AE,Q}
I = MA.promote_operation(implicit, P)
return MA.promote_operation(
$op,
P,
constant_algebra_element_type(MA.promote_operation(SA.basis, P), Q),
I,
constant_algebra_element_type(MA.promote_operation(SA.basis, I), Q),
)
end
Base.$op(p::_AE, q) = $op(p, constant_algebra_element(typeof(SA.basis(p)), q))
function Base.$op(p::_AE, q)
i = implicit(p)
return $op(i, constant_algebra_element(typeof(SA.basis(i)), q))
end
end
end
36 changes: 31 additions & 5 deletions src/monomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,21 @@ end
implicit_basis(::SubBasis{B,M}) where {B,M} = FullBasis{B,M}()
implicit_basis(basis::FullBasis) = basis

function implicit(a::SA.AlgebraElement)
basis = implicit_basis(SA.basis(a))
return algebra_element(SA.coeffs(a, basis), basis)
end

function MA.promote_operation(
::typeof(implicit),
::Type{E},
) where {AG,T,E<:SA.AlgebraElement{AG,T}}
BT = MA.promote_operation(implicit_basis, MA.promote_operation(SA.basis, E))
A = MA.promote_operation(algebra, BT)
M = MP.monomial_type(BT)
return SA.AlgebraElement{A,T,SA.SparseCoefficients{M,T,Vector{M},Vector{T}}}
end

function MA.promote_operation(
::typeof(implicit_basis),
::Type{<:Union{FullBasis{B,M},SubBasis{B,M}}},
Expand Down Expand Up @@ -393,14 +408,25 @@ function MP.polynomial_type(::Type{FullBasis{B,M}}, ::Type{T}) where {T,B,M}
return MP.polynomial_type(M, _promote_coef(T, B))
end

_vec(v::Vector) = v
_vec(v::AbstractVector) = collect(v)

# Adapted from SA to incorporate `_promote_coef`
function SA.coeffs(
cfs,
source::MonomialIndexedBasis{B},
target::MonomialIndexedBasis{Monomial},
) where {B}
source::MonomialIndexedBasis{B1},
target::MonomialIndexedBasis{B2},
) where {B1,B2}
source === target && return cfs
source == target && return cfs
res = SA.zero_coeffs(_promote_coef(valtype(cfs), B), target)
return SA.coeffs!(res, cfs, source, target)
if B1 === B2 && target isa FullBasis
# The defaults initialize to zero and then sums which promotes
# `JuMP.VariableRef` to `JuMP.AffExpr`
return SA.SparseCoefficients(_vec(source.monomials), _vec(cfs))
elseif B2 === Monomial
res = SA.zero_coeffs(_promote_coef(valtype(cfs), B1), target)
return SA.coeffs!(res, cfs, source, target)
else
error("Convertion from $source to $target not implemented yet")

Check warning on line 430 in src/monomial.jl

View check run for this annotation

Codecov / codecov/patch

src/monomial.jl#L430

Added line #L430 was not covered by tests
end
end
8 changes: 5 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ const MB = MultivariateBases
using LinearAlgebra
using DynamicPolynomials

function _test_op(op, a, b)
result = @inferred op(a, b)
@test typeof(result) == MA.promote_operation(op, typeof(a), typeof(b))
function _test_op(op, args...)
result = @inferred op(args...)
@test typeof(result) == MA.promote_operation(op, typeof.(args)...)
return result
end

Expand Down Expand Up @@ -53,6 +53,8 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
@test length(empty_basis(typeof(basis))) == 0
@test polynomial_type(basis, Float64) == polynomial_type(x[1], Float64)
#@test polynomial(i -> 0.0, basis) isa polynomial_type(basis, Float64)
a = MB.algebra_element(ones(length(basis)), basis)
_test_op(MB.implicit, a)
end
mono = x[1]^2 * x[2]^3
p = MB.Polynomial{B}(mono)
Expand Down

0 comments on commit 6ce35a9

Please sign in to comment.