diff --git a/src/polyform.jl b/src/polyform.jl index 21e04ac9b..b80e7a9e3 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -175,11 +175,11 @@ isexpr(x::PolyForm) = true iscall(x::Type{<:PolyForm}) = true iscall(x::PolyForm) = true -function similarterm(t::PolyForm, f, args, symtype; metadata=nothing) - basic_similarterm(t, f, args, symtype; metadata=metadata) +function similarterm(::Type{<:PolyForm}, f, args, symtype, metadata) + basicsymbolic(t, f, args, symtype, metadata) end -function similarterm(::PolyForm, f::Union{typeof(*), typeof(+), typeof(^)}, - args, symtype; metadata=nothing) +function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, + args, symtype, metadata) f(args...) end diff --git a/src/rewriters.jl b/src/rewriters.jl index 81ae2dbe0..24b216422 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -33,7 +33,7 @@ module Rewriters using SymbolicUtils: @timer using TermInterface -import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count +import SymbolicUtils: istree, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough # Cache of printed rules to speed up @timer @@ -167,24 +167,36 @@ end struct Walk{ord, C, F, threaded} rw::C thread_cutoff::Int - similarterm::F + maketerm::F end function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded} irw = instrument(x.rw, f) - Walk{ord, typeof(irw), typeof(x.similarterm), threaded}(irw, + Walk{ord, typeof(irw), typeof(x.maketerm), threaded}(irw, x.thread_cutoff, - x.similarterm) + x.maketerm) end using .Threads -function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm) - Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm) +function compatmaker(similarterm, maketerm) + if similarterm isa Nothing + function (x, f, args, type=_promote_symtype(f, args); metadata) + maketerm(typeof(x), f, args, type, metadata) + end + else + Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm) + similarterm + end +end +function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) + maker = compatmaker(similarterm, maketerm) + Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) end -function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm) - Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm) +function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) + maker = compatmaker(similarterm, maketerm) + Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) end struct PassThrough{C} @@ -203,7 +215,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} end if istree(x) - x = p.similarterm(x, operation(x), map(PassThrough(p), + x = p.maketerm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)), metadata=metadata(x)) end @@ -228,7 +240,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} end end args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.similarterm(x, operation(x), args, metadata=metadata(x)) + t = p.maketerm(x, operation(x), args, metadata=metadata(x)) end return ord === :post ? p.rw(t) : t else