Skip to content

Commit

Permalink
Fixes for SIMDPolynomials (#267)
Browse files Browse the repository at this point in the history
* Fixes for SIMDPolynomials

* Fix format
  • Loading branch information
blegat authored Jun 22, 2023
1 parent bcc6d7a commit b01b668
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/monomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ end
Returns whether the monomial of `t` is constant.
"""
isconstant(t::AbstractTermLike) = all(iszero, exponents(t))
isconstant(t::AbstractTerm) = isconstant(monomial(t))
isconstant(t::AbstractMonomialLike) = all(iszero, exponents(t))
isconstant(v::AbstractVariable) = false

"""
Expand Down
6 changes: 3 additions & 3 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ function _show(io::IO, mime::MIME, var::AbstractVariable)
print_subscript(io, mime, indices)
end
end
function _show(io::IO, mime::MIME"text/print", var::AbstractVariable)
return print(io, name(var))
end

function print_subscript(io::IO, ::MIME"text/print", index)
return print(io, "[", join(index, ","), "]")
end
function print_subscript(io::IO, ::MIME"text/latex", index)
return print(io, "_{", join(index, ","), "}")
end
Expand Down
54 changes: 47 additions & 7 deletions src/substitution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,39 @@ function powersubstitute(
)
return powersubstitute(st, s, p) * powersubstitute(st, s, p2...)
end

function _promote_subs(S, T, s::Substitution)
# `T` is a constant
return T
end

function _promote_subs(S, T::Type{<:Union{RationalPoly,_APL}}, s::Substitution)
return MA.promote_operation(substitute, S, T, typeof(s))
end

function _promote_subs(S, T, s::AbstractMultiSubstitution)
return _promote_subs(
S,
T,
pair_zip(_monomial_vector_to_variable_tuple(s))...,
)
end

function _promote_subs(
S,
T,
head::AbstractSubstitution,
tail::Vararg{AbstractSubstitution,N},
) where {N}
return _promote_subs(S, _promote_subs(S, T, head), tail...)
end

function substitute(st::_AST, m::AbstractMonomial, s::Substitutions)
return powersubstitute(st, s, powers(m)...)
if isconstant(m)
return one(_promote_subs(typeof(st), typeof(m), s...))
else
return powersubstitute(st, s, powers(m)...)
end
end

## Terms
Expand All @@ -92,11 +123,20 @@ end

function MA.promote_operation(
::typeof(substitute),
::Type{Subs},
::Type{Eval},
::Type{M},
::Type{Pair{V,T}},
) where {M<:AbstractMonomial,V<:AbstractVariable,T}
return MA.promote_operation(*, T, T)
end

function MA.promote_operation(
::typeof(substitute),
::Type{S},
::Type{T},
args::Vararg{Type,N},
) where {T<:AbstractTerm,N}
M = MA.promote_operation(substitute, Subs, monomial_type(T), args...)
) where {S<:AbstractSubstitutionType,T<:AbstractTerm,N}
M = MA.promote_operation(substitute, S, monomial_type(T), args...)
U = coefficient_type(T)
return MA.promote_operation(*, U, M)
end
Expand All @@ -121,11 +161,11 @@ end

function MA.promote_operation(
::typeof(substitute),
::Type{Subs},
::Type{S},
::Type{P},
args::Vararg{Type,N},
) where {P<:AbstractPolynomial,N}
T = MA.promote_operation(substitute, Subs, term_type(P), args...)
) where {S<:AbstractSubstitutionType,P<:AbstractPolynomial,N}
T = MA.promote_operation(substitute, S, term_type(P), args...)
return MA.promote_operation(+, T, T)
end

Expand Down
10 changes: 10 additions & 0 deletions test/substitution.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using Test
import Test: @inferred

import MutableArithmetics as MA

@testset "Substitution" begin
Mod.@polyvar x[1:3]

Expand Down Expand Up @@ -28,6 +31,8 @@ import Test: @inferred
p = x[1] + x[2] + 2 * x[1]^2 + 3 * x[1] * x[2]^2
#@inferred p((x[1], x[2]) => (1.0, 2.0))

m = x[1] * x[2]
@inferred subs(m, x[2] => 2.0)
@inferred subs(p, x[2] => 2.0)
@test subs(p, x[2] => 2.0) == 13x[1] + 2 + 2x[1]^2
@inferred subs(p, x[2] => x[1])
Expand Down Expand Up @@ -77,4 +82,9 @@ import Test: @inferred
@test t == @inferred subs(t, x => 1.0x)
@test t == @inferred subs(t, x => 1.0)
end

for (p, s) in [(x^1, x => 2x)]
@test MA.promote_operation(substitute, Subs, typeof(p), typeof(s)) ==
typeof(subs(p, s))
end
end

0 comments on commit b01b668

Please sign in to comment.