Skip to content

Commit

Permalink
similarterm -> maketerm in Walk and PolyForm
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi committed May 30, 2024
1 parent 9e3b9de commit 37e54b9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ isexpr(x::PolyForm) = true
iscall(x::Type{<:PolyForm}) = true
iscall(x::PolyForm) = true

function similarterm(t::PolyForm, f, args, symtype; metadata=nothing)
basic_similarterm(t, f, args, symtype; metadata=metadata)
function similarterm(::Type{<:PolyForm}, f, args, symtype, metadata)
basicsymbolic(t, f, args, symtype, metadata)
end
function similarterm(::PolyForm, f::Union{typeof(*), typeof(+), typeof(^)},
args, symtype; metadata=nothing)
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)},
args, symtype, metadata)
f(args...)
end

Expand Down
32 changes: 22 additions & 10 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module Rewriters
using SymbolicUtils: @timer
using TermInterface

import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count
import SymbolicUtils: istree, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough

# Cache of printed rules to speed up @timer
Expand Down Expand Up @@ -167,24 +167,36 @@ end
struct Walk{ord, C, F, threaded}
rw::C
thread_cutoff::Int
similarterm::F
maketerm::F
end

function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded}
irw = instrument(x.rw, f)
Walk{ord, typeof(irw), typeof(x.similarterm), threaded}(irw,
Walk{ord, typeof(irw), typeof(x.maketerm), threaded}(irw,
x.thread_cutoff,
x.similarterm)
x.maketerm)
end

using .Threads

function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
function compatmaker(similarterm, maketerm)
if similarterm isa Nothing
function (x, f, args, type=_promote_symtype(f, args); metadata)
maketerm(typeof(x), f, args, type, metadata)
end
else
Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm)
similarterm
end
end
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
end

function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
end

struct PassThrough{C}
Expand All @@ -203,7 +215,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
end

if istree(x)
x = p.similarterm(x, operation(x), map(PassThrough(p),
x = p.maketerm(x, operation(x), map(PassThrough(p),
unsorted_arguments(x)), metadata=metadata(x))
end

Expand All @@ -228,7 +240,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
t = p.similarterm(x, operation(x), args, metadata=metadata(x))
t = p.maketerm(x, operation(x), args, metadata=metadata(x))
end
return ord === :post ? p.rw(t) : t
else
Expand Down

0 comments on commit 37e54b9

Please sign in to comment.