From 8cd371004ae3ac39b90b33d9cc7cf320b3ec85bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 21 Nov 2024 16:42:14 +0100 Subject: [PATCH] Refactor substitution --- src/substitution.jl | 23 ++++++++++++++++------- test/substitution.jl | 19 +++++++++++++++++-- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/substitution.jl b/src/substitution.jl index 8dfa0b74..4b695528 100644 --- a/src/substitution.jl +++ b/src/substitution.jl @@ -27,6 +27,7 @@ const AbstractMultiSubstitution = Union{ } const AbstractSubstitution = Union{Substitution,AbstractMultiSubstitution} const Substitutions = Tuple{Vararg{AbstractSubstitution}} +const _Substitutions = Tuple{Vararg{Substitution}} abstract type AbstractSubstitutionType end struct Subs <: AbstractSubstitutionType end @@ -40,12 +41,20 @@ is equivalent to: subs(polynomial, (x=>1, y=>2)) """ -function substitute(st::_AST, p::_APL, s::AbstractMultiSubstitution) - return substitute(st, p, _flatten_subs(s)) +function substitute_fallback(st::_AST, p::_APL, s::Substitutions) + return substitute_fallback(st, p, _flatten_subs(s...)) +end + +function substitute(st::_AST, p::_APL, s::AbstractSubstitution...) + return substitute(st, p, s) +end + +function substitute(st::_AST, p::_APL, s::Substitutions) + return substitute_fallback(st, p, s) end ## Variables -function substitute(st::_AST, v::AbstractVariable, s::Substitutions) +function substitute_fallback(st::_AST, v::AbstractVariable, s::Substitutions) return substitute(st, v, s...) end @@ -127,7 +136,7 @@ function power_promote( ) end -function substitute(st::_AST, m::AbstractMonomial, s::Substitutions) +function substitute_fallback(st::_AST, m::AbstractMonomial, s::Substitutions) if isconstant(m) return one(power_promote(typeof(st), variables(m), s)) else @@ -136,7 +145,7 @@ function substitute(st::_AST, m::AbstractMonomial, s::Substitutions) end ## Terms -function substitute(st::_AST, t::AbstractTerm, s::Substitutions) +function substitute_fallback(st::_AST, t::AbstractTerm, s::Substitutions) return coefficient(t) * substitute(st, monomial(t), s) end @@ -163,7 +172,7 @@ end ## Polynomials _polynomial(α) = α _polynomial(p::_APL) = polynomial(p) -function substitute(st::_AST, p::AbstractPolynomial, s::Substitutions) +function substitute_fallback(st::_AST, p::AbstractPolynomial, s::Substitutions) if iszero(p) _polynomial(substitute(st, zero_term(p), s)) else @@ -189,7 +198,7 @@ function MA.promote_operation( end ## Fallbacks -function substitute(st::_AST, p::_APL, s::Substitutions) +function substitute_fallback(st::_AST, p::_APL, s::Substitutions) return substitute(st, polynomial(p), s) end function substitute(st::_AST, q::RationalPoly, s::Substitutions) diff --git a/test/substitution.jl b/test/substitution.jl index dbc04921..2a0d1130 100644 --- a/test/substitution.jl +++ b/test/substitution.jl @@ -3,7 +3,9 @@ import Test: @inferred import MutableArithmetics as MA -@testset "Substitution" begin +Mod = DynamicPolynomials + +#@testset "Substitution" begin Mod.@polyvar x[1:3] @test subs(2, x[1] => 3) == 2 @@ -107,4 +109,17 @@ import MutableArithmetics as MA @test subs(F, x[3] => T(0)) == x[1] * x[2]^2 @test subs(F, x[3] => 0) == x[1] * x[2]^2 end -end + + function subs_alloc(x) + p = sum(x) + v = map(_ -> 1, x) + @inferred subs(p, x => v) + @time subs(p, x => v) + @time subs(p, x => v) + @time substitute(Eval(), p, x => v) + @time substitute(Eval(), p, x => v) + @time substitute(Eval(), p, v) + @time p(v) + end + subs_alloc(x) +#end