diff --git a/src/rewriters.jl b/src/rewriters.jl index 464898d36..003b55575 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -33,7 +33,7 @@ module Rewriters using SymbolicUtils: @timer using TermInterface -import SymbolicUtils: similarterm +import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough # Cache of printed rules to speed up @timer @@ -63,18 +63,24 @@ If(f, x) = IfElse(f, x, Empty()) struct Chain rws + stop_on_match::Bool end +Chain(rws) = Chain(rws, false) function (rw::Chain)(x) for f in rw.rws y = @timer cached_repr(f) f(x) + if rw.stop_on_match && !isnothing(y) && !isequal(y, x) + return y + end + if y !== nothing x = y end end return x -end +end instrument(c::Chain, f) = Chain(map(x->instrument(x,f), c.rws)) struct RestartedChain{Cs} @@ -145,8 +151,8 @@ function (rw::FixpointNoCycle)(x) f = rw.rw push!(rw.hist, hash(x)) y = @timer cached_repr(f) f(x) - while x !== y && hash(x) ∉ hist - if y === nothing + while x !== y && hash(x) ∉ rw.hist + if y === nothing empty!(rw.hist) return x end @@ -195,9 +201,10 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} if ord === :pre x = p.rw(x) end - if iscall(x) - x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x))) - end + + x = p.similarterm(x, operation(x), map(PassThrough(p), + unsorted_arguments(x)), metadata=metadata(x)) + return ord === :post ? p.rw(x) : x else return p.rw(x) @@ -219,7 +226,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} end end args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.similarterm(x, operation(x), args) + t = p.similarterm(x, operation(x), args, metadata=metadata(x)) end return ord === :post ? p.rw(t) : t else diff --git a/src/rule.jl b/src/rule.jl index a529979e7..05941b764 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -139,7 +139,7 @@ function (r::Rule)(term) try # n == 1 means that exactly one term of the input (term,) was matched - success(bindings, n) = n == 1 ? (@timer "RHS" rhs(bindings)) : nothing + success(bindings, n) = n == 1 ? (@timer "RHS" rhs(assoc(bindings, :MATCH, term))) : nothing return r.matcher(success, (term,), EMPTY_IMMUTABLE_DICT) catch err throw(RuleRewriteError(r, term)) diff --git a/test/rewrite.jl b/test/rewrite.jl index f39b4e58b..3bb2621e3 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -76,3 +76,32 @@ using SymbolicUtils: @capture @eqtest f(b^b) == b @test f(b+b) == nothing end + +@testset "Rewriter tweaks #548" begin + struct MetaData end + ex = a + b + ex = setmetadata(ex, MetaData, :metadata) + ex1 = ex + c + + @test SymbolicUtils.isterm(ex1) + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + + ex = a + ex = setmetadata(ex, MetaData, :metadata) + ex1 = ex + b + + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + + ex = a * b + ex = setmetadata(ex, MetaData, :metadata) + ex1 = ex * c + + @test SymbolicUtils.isterm(ex1) + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + + ex = a + ex = setmetadata(ex, MetaData, :metadata) + ex1 = ex * b + + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata +end \ No newline at end of file