diff --git a/src/MultivariateBases.jl b/src/MultivariateBases.jl index 943268e..74590f4 100644 --- a/src/MultivariateBases.jl +++ b/src/MultivariateBases.jl @@ -25,7 +25,10 @@ end SA.basis(a::Algebra) = a.basis #Base.:(==)(::Algebra{BT1,B1,M}, ::Algebra{BT2,B2,M}) where {BT1,B1,BT2,B2,M} = true -#Base.:(==)(::Algebra, ::Algebra) = false +function Base.:(==)(a::Algebra, b::Algebra) + # `===` is a shortcut for speedup + return a.basis === b.basis || a.basis == b.basis +end function Base.show(io::IO, ::Algebra{BT,B}) where {BT,B} ioc = IOContext(io, :limit => true, :compact => true) @@ -72,13 +75,6 @@ function MA.promote_operation( return Algebra{BT,B,M} end -const _APL = MP.AbstractPolynomialLike -# We don't define it for all `AlgebraElement` as this would be type piracy -const _AE = SA.AlgebraElement{<:Algebra} - -Base.:(+)(p::_APL, q::_AE) = +(p, MP.polynomial(q)) -Base.:(+)(p::_AE, q::_APL) = +(MP.polynomial(p), q) -Base.:(-)(p::_APL, q::_AE) = -(p, MP.polynomial(q)) -Base.:(-)(p::_AE, q::_APL) = -(MP.polynomial(p), q) +include("arithmetic.jl") end # module diff --git a/src/arithmetic.jl b/src/arithmetic.jl new file mode 100644 index 0000000..8c1c5ea --- /dev/null +++ b/src/arithmetic.jl @@ -0,0 +1,27 @@ +const _APL = MP.AbstractPolynomialLike +# We don't define it for all `AlgebraElement` as this would be type piracy +const _AE = SA.AlgebraElement{<:Algebra} + +Base.:(+)(p::_APL, q::_AE) = +(p, MP.polynomial(q)) +Base.:(+)(p::_AE, q::_APL) = +(MP.polynomial(p), q) +Base.:(-)(p::_APL, q::_AE) = -(p, MP.polynomial(q)) +Base.:(-)(p::_AE, q::_APL) = -(MP.polynomial(p), q) + +Base.:(+)(p, q::_AE) = +(constant_algebra_element(typeof(SA.basis(q)), p), q) +function Base.:(+)(p::_AE, q) + return +(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q)) +end +function Base.:(-)(p, q::_AE) + return -(constant_algebra_element(typeof(SA.basis(q)), p), MP.polynomial(q)) +end +function Base.:(-)(p::_AE, q) + return -(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q)) +end + +function Base.:(+)(p::_AE, q::_AE) + return MA.operate_to!(SA._preallocate_output(+, p, q), +, p, q) +end + +function Base.:(-)(p::_AE, q::_AE) + return MA.operate_to!(SA._preallocate_output(-, p, q), -, p, q) +end diff --git a/src/chebyshev.jl b/src/chebyshev.jl index fbd4285..7a5dc77 100644 --- a/src/chebyshev.jl +++ b/src/chebyshev.jl @@ -23,7 +23,7 @@ struct ChebyshevFirstKind <: AbstractChebyshev end const Chebyshev = ChebyshevFirstKind # https://en.wikipedia.org/wiki/Chebyshev_polynomials#Properties -# T_n * T_m = T_{n + m} / 2 + T_{|n - m|} / 2 +# `T_n * T_m = T_{n + m} / 2 + T_{|n - m|} / 2` function (::Mul{Chebyshev})(a::MP.AbstractMonomial, b::MP.AbstractMonomial) terms = [MP.term(1 // 1, MP.constant_monomial(a * b))] vars_a = MP.variables(a) diff --git a/src/monomial.jl b/src/monomial.jl index 38d2465..44d6f2d 100644 --- a/src/monomial.jl +++ b/src/monomial.jl @@ -198,23 +198,23 @@ function algebra_element(f::Function, basis::SubBasis) return algebra_element(map(f, eachindex(basis)), basis) end -function constant_algebra_element( - ::Type{FullBasis{B,M}}, - ::Type{T}, -) where {B,M,T} +_one_if_type(α) = α +_one_if_type(::Type{T}) where {T} = one(T) + +function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M} return algebra_element( sparse_coefficients( - MP.polynomial(MP.term(one(T), MP.constant_monomial(M))), + MP.polynomial(MP.term(_one_if_type(α), MP.constant_monomial(M))), ), FullBasis{B,M}(), ) end -function constant_algebra_element( - ::Type{<:SubBasis{B,M}}, - ::Type{T}, -) where {B,M,T} - return algebra_element([one(T)], SubBasis{B}([MP.constant_monomial(M)])) +function constant_algebra_element(::Type{<:SubBasis{B,M}}, α) where {B,M} + return algebra_element( + [_one_if_type(α)], + SubBasis{B}([MP.constant_monomial(M)]), + ) end function _show(io::IO, mime::MIME, basis::SubBasis{B}) where {B} diff --git a/src/orthogonal.jl b/src/orthogonal.jl index 9307c0e..8962297 100644 --- a/src/orthogonal.jl +++ b/src/orthogonal.jl @@ -190,13 +190,15 @@ function SA.coeffs( end function SA.coeffs( - p::Polynomial{B}, + p::Polynomial{B,M}, ::FullBasis{Monomial}, -) where {B<:AbstractMultipleOrthogonal} +) where {B<:AbstractMultipleOrthogonal,M} return sparse_coefficients( prod( - univariate_orthogonal_basis(B, var, deg)[deg+1] for - (var, deg) in MP.powers(p.monomial) - ), + MP.powers(p.monomial); + init = MP.constant_monomial(M), + ) do (var, deg) + return univariate_orthogonal_basis(B, var, deg)[deg+1] + end, ) end diff --git a/src/scaled.jl b/src/scaled.jl index 60403d9..b5d6c7a 100644 --- a/src/scaled.jl +++ b/src/scaled.jl @@ -70,7 +70,9 @@ function Base.promote_rule( end function scaling(m::MP.AbstractMonomial) - return √(factorial(MP.degree(m)) / prod(factorial, MP.exponents(m))) + return √( + factorial(MP.degree(m)) / prod(factorial, MP.exponents(m); init = 1), + ) end unscale_coef(t::MP.AbstractTerm) = MP.coefficient(t) / scaling(MP.monomial(t)) function SA.coeffs( diff --git a/test/runtests.jl b/test/runtests.jl index 041149f..6fd543e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,10 +74,14 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree) _wrap(MB.SA.trim_LaTeX(mime, sprint(show, mime, p.monomial))) * " \$\$" const_mono = constant_monomial(prod(x)) - @test const_mono + MB.algebra_element(MB.Polynomial{B}(const_mono)) == 2 - @test MB.algebra_element(MB.Polynomial{B}(const_mono)) + const_mono == 2 - @test iszero(const_mono - MB.algebra_element(MB.Polynomial{B}(const_mono))) - @test iszero(MB.algebra_element(MB.Polynomial{B}(const_mono)) - const_mono) + const_poly = MB.Polynomial{B}(const_mono) + const_alg_el = MB.algebra_element(const_poly) + for other in (const_mono, 1, const_alg_el) + @test other + const_alg_el ≈ 2 * other + @test const_alg_el + other ≈ 2 * other + @test iszero(other - const_alg_el) + @test iszero(const_alg_el - other) + end @test typeof(MB.sparse_coefficients(sum(x))) == MA.promote_operation(MB.sparse_coefficients, typeof(sum(x))) @test typeof(MB.algebra_element(sum(x))) ==