Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Various tweaks to the Rewriters #548

Merged
merged 24 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5e883dd
forward metadata in Walk
shashi Aug 8, 2023
657bd9b
create special slot ~MATCH which is only defined on the RHS and holds…
shashi Aug 8, 2023
80bf968
Merge branch 's/inspect' into s/rewriter-tweaks
shashi Aug 9, 2023
9a6e04e
allow early stop in Chain
shashi Aug 10, 2023
5de0612
add inspect for non-symbolic objects so user code does not have to check
shashi Aug 30, 2023
3b53344
prevent canonicalization and hence destruction of terms with metadata
shashi Aug 31, 2023
1a89dd5
fix check in +
shashi Sep 2, 2023
1517c40
fix typo
shashi Sep 22, 2023
9ce7c76
forward metadata in Walk
shashi Aug 8, 2023
05c87d4
create special slot ~MATCH which is only defined on the RHS and holds…
shashi Aug 8, 2023
4885f7b
allow early stop in Chain
shashi Aug 10, 2023
8c3b506
prevent canonicalization and hence destruction of terms with metadata
shashi Aug 31, 2023
1712ba9
fix check in +
shashi Sep 2, 2023
a1b8cf9
fix typo
shashi Sep 22, 2023
93545fb
remove redundant and wrpng nometa check
shashi Mar 15, 2024
5fa8545
messy rebase
Vaibhavdixit02 Apr 12, 2024
ec0383d
Messy rebase
Vaibhavdixit02 Apr 25, 2024
3680e8a
Merge branch 'master' into s/rewriter-tweaks
Vaibhavdixit02 Apr 25, 2024
a9867cc
Merge remote-tracking branch 'origin/s/rewriter-tweaks' into s/rewrit…
shashi Apr 26, 2024
0fd9fbe
left out line
shashi Apr 26, 2024
993d1fa
comment out repeated istree check
Vaibhavdixit02 May 6, 2024
f387fe6
Bring back istree check in walking
Vaibhavdixit02 May 6, 2024
6f76a28
Add test and fix branch in walk
Vaibhavdixit02 May 6, 2024
1457894
more test
Vaibhavdixit02 May 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module Rewriters
using SymbolicUtils: @timer
using TermInterface

import SymbolicUtils: similarterm
import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough

# Cache of printed rules to speed up @timer
Expand Down Expand Up @@ -63,18 +63,24 @@ If(f, x) = IfElse(f, x, Empty())

struct Chain
rws
stop_on_match::Bool
end
Chain(rws) = Chain(rws, false)

function (rw::Chain)(x)
for f in rw.rws
y = @timer cached_repr(f) f(x)
if rw.stop_on_match && !isnothing(y) && !isequal(y, x)
return y
end

if y !== nothing
x = y
end
end
return x
end

end
instrument(c::Chain, f) = Chain(map(x->instrument(x,f), c.rws))

struct RestartedChain{Cs}
Expand Down Expand Up @@ -145,8 +151,8 @@ function (rw::FixpointNoCycle)(x)
f = rw.rw
push!(rw.hist, hash(x))
y = @timer cached_repr(f) f(x)
while x !== y && hash(x) ∉ hist
if y === nothing
while x !== y && hash(x) ∉ rw.hist
if y === nothing
empty!(rw.hist)
return x
end
Expand Down Expand Up @@ -195,9 +201,10 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
if ord === :pre
x = p.rw(x)
end
if iscall(x)
x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)))
end

x = p.similarterm(x, operation(x), map(PassThrough(p),
unsorted_arguments(x)), metadata=metadata(x))

return ord === :post ? p.rw(x) : x
else
return p.rw(x)
Expand All @@ -219,7 +226,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
t = p.similarterm(x, operation(x), args)
t = p.similarterm(x, operation(x), args, metadata=metadata(x))
end
return ord === :post ? p.rw(t) : t
else
Expand Down
2 changes: 1 addition & 1 deletion src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function (r::Rule)(term)

try
# n == 1 means that exactly one term of the input (term,) was matched
success(bindings, n) = n == 1 ? (@timer "RHS" rhs(bindings)) : nothing
success(bindings, n) = n == 1 ? (@timer "RHS" rhs(assoc(bindings, :MATCH, term))) : nothing
return r.matcher(success, (term,), EMPTY_IMMUTABLE_DICT)
catch err
throw(RuleRewriteError(r, term))
Expand Down
29 changes: 29 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,32 @@ using SymbolicUtils: @capture
@eqtest f(b^b) == b
@test f(b+b) == nothing
end

@testset "Rewriter tweaks #548" begin
struct MetaData end
ex = a + b
ex = setmetadata(ex, MetaData, :metadata)
ex1 = ex + c

@test SymbolicUtils.isterm(ex1)
@test getmetadata(arguments(ex1)[1], MetaData) == :metadata

ex = a
ex = setmetadata(ex, MetaData, :metadata)
ex1 = ex + b

@test getmetadata(arguments(ex1)[1], MetaData) == :metadata

ex = a * b
ex = setmetadata(ex, MetaData, :metadata)
ex1 = ex * c

@test SymbolicUtils.isterm(ex1)
@test getmetadata(arguments(ex1)[1], MetaData) == :metadata

ex = a
ex = setmetadata(ex, MetaData, :metadata)
ex1 = ex * b

@test getmetadata(arguments(ex1)[1], MetaData) == :metadata
end
Loading