From b01b6689d7a2008999654d4d7b0fd11365122c36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 22 Jun 2023 18:02:58 +0200 Subject: [PATCH] Fixes for SIMDPolynomials (#267) * Fixes for SIMDPolynomials * Fix format --- src/monomial.jl | 3 ++- src/show.jl | 6 ++--- src/substitution.jl | 54 ++++++++++++++++++++++++++++++++++++++------ test/substitution.jl | 10 ++++++++ 4 files changed, 62 insertions(+), 11 deletions(-) diff --git a/src/monomial.jl b/src/monomial.jl index 783a9bae..db40488e 100644 --- a/src/monomial.jl +++ b/src/monomial.jl @@ -101,7 +101,8 @@ end Returns whether the monomial of `t` is constant. """ -isconstant(t::AbstractTermLike) = all(iszero, exponents(t)) +isconstant(t::AbstractTerm) = isconstant(monomial(t)) +isconstant(t::AbstractMonomialLike) = all(iszero, exponents(t)) isconstant(v::AbstractVariable) = false """ diff --git a/src/show.jl b/src/show.jl index 7f573e14..92df2050 100644 --- a/src/show.jl +++ b/src/show.jl @@ -34,10 +34,10 @@ function _show(io::IO, mime::MIME, var::AbstractVariable) print_subscript(io, mime, indices) end end -function _show(io::IO, mime::MIME"text/print", var::AbstractVariable) - return print(io, name(var)) -end +function print_subscript(io::IO, ::MIME"text/print", index) + return print(io, "[", join(index, ","), "]") +end function print_subscript(io::IO, ::MIME"text/latex", index) return print(io, "_{", join(index, ","), "}") end diff --git a/src/substitution.jl b/src/substitution.jl index c05766c9..f370bfe1 100644 --- a/src/substitution.jl +++ b/src/substitution.jl @@ -81,8 +81,39 @@ function powersubstitute( ) return powersubstitute(st, s, p) * powersubstitute(st, s, p2...) end + +function _promote_subs(S, T, s::Substitution) + # `T` is a constant + return T +end + +function _promote_subs(S, T::Type{<:Union{RationalPoly,_APL}}, s::Substitution) + return MA.promote_operation(substitute, S, T, typeof(s)) +end + +function _promote_subs(S, T, s::AbstractMultiSubstitution) + return _promote_subs( + S, + T, + pair_zip(_monomial_vector_to_variable_tuple(s))..., + ) +end + +function _promote_subs( + S, + T, + head::AbstractSubstitution, + tail::Vararg{AbstractSubstitution,N}, +) where {N} + return _promote_subs(S, _promote_subs(S, T, head), tail...) +end + function substitute(st::_AST, m::AbstractMonomial, s::Substitutions) - return powersubstitute(st, s, powers(m)...) + if isconstant(m) + return one(_promote_subs(typeof(st), typeof(m), s...)) + else + return powersubstitute(st, s, powers(m)...) + end end ## Terms @@ -92,11 +123,20 @@ end function MA.promote_operation( ::typeof(substitute), - ::Type{Subs}, + ::Type{Eval}, + ::Type{M}, + ::Type{Pair{V,T}}, +) where {M<:AbstractMonomial,V<:AbstractVariable,T} + return MA.promote_operation(*, T, T) +end + +function MA.promote_operation( + ::typeof(substitute), + ::Type{S}, ::Type{T}, args::Vararg{Type,N}, -) where {T<:AbstractTerm,N} - M = MA.promote_operation(substitute, Subs, monomial_type(T), args...) +) where {S<:AbstractSubstitutionType,T<:AbstractTerm,N} + M = MA.promote_operation(substitute, S, monomial_type(T), args...) U = coefficient_type(T) return MA.promote_operation(*, U, M) end @@ -121,11 +161,11 @@ end function MA.promote_operation( ::typeof(substitute), - ::Type{Subs}, + ::Type{S}, ::Type{P}, args::Vararg{Type,N}, -) where {P<:AbstractPolynomial,N} - T = MA.promote_operation(substitute, Subs, term_type(P), args...) +) where {S<:AbstractSubstitutionType,P<:AbstractPolynomial,N} + T = MA.promote_operation(substitute, S, term_type(P), args...) return MA.promote_operation(+, T, T) end diff --git a/test/substitution.jl b/test/substitution.jl index 76488ef1..3e0deb1e 100644 --- a/test/substitution.jl +++ b/test/substitution.jl @@ -1,5 +1,8 @@ +using Test import Test: @inferred +import MutableArithmetics as MA + @testset "Substitution" begin Mod.@polyvar x[1:3] @@ -28,6 +31,8 @@ import Test: @inferred p = x[1] + x[2] + 2 * x[1]^2 + 3 * x[1] * x[2]^2 #@inferred p((x[1], x[2]) => (1.0, 2.0)) + m = x[1] * x[2] + @inferred subs(m, x[2] => 2.0) @inferred subs(p, x[2] => 2.0) @test subs(p, x[2] => 2.0) == 13x[1] + 2 + 2x[1]^2 @inferred subs(p, x[2] => x[1]) @@ -77,4 +82,9 @@ import Test: @inferred @test t == @inferred subs(t, x => 1.0x) @test t == @inferred subs(t, x => 1.0) end + + for (p, s) in [(x^1, x => 2x)] + @test MA.promote_operation(substitute, Subs, typeof(p), typeof(s)) == + typeof(subs(p, s)) + end end