Skip to content

Commit

Permalink
core: Make PatternRewriter a Builder
Browse files Browse the repository at this point in the history
This allows us to now use the ImplicitBuilder on PatternRewriter,
or to use `insert` using an InsertPoint.

stack-info: PR: #3540, branch: math-fehr/stack/4
  • Loading branch information
math-fehr committed Nov 29, 2024
1 parent 7ce487c commit a8b1c12
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 17 deletions.
40 changes: 40 additions & 0 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from conftest import assert_print_op

from xdsl.builder import ImplicitBuilder
from xdsl.context import MLContext
from xdsl.dialects import test
from xdsl.dialects.arith import AddiOp, Arith, ConstantOp, MuliOp
Expand All @@ -14,6 +15,7 @@
IntegerType,
ModuleOp,
StringAttr,
UnitAttr,
i32,
i64,
)
Expand Down Expand Up @@ -1411,6 +1413,44 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
)


def test_pattern_rewriter_as_op_builder():
"""Test that the PatternRewriter works as an OpBuilder."""
prog = """
"builtin.module"() ({
"test.op"() : () -> ()
"test.op"() {"nomatch"} : () -> ()
"test.op"() : () -> ()
}) : () -> ()"""
expected = """
"builtin.module"() ({
"test.op"() {"inserted"} : () -> ()
"test.op"() {"replaced"} : () -> ()
"test.op"() {"nomatch"} : () -> ()
"test.op"() {"inserted"} : () -> ()
"test.op"() {"replaced"} : () -> ()
}) : () -> ()"""

class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if "nomatch" in op.attributes:
return
with ImplicitBuilder(rewriter):
test.TestOp.create(attributes={"inserted": UnitAttr()})
rewriter.replace_matched_op(
test.TestOp.create(attributes={"replaced": UnitAttr()})
)

rewrite_and_compare(
prog,
expected,
PatternRewriteWalker(Rewrite(), apply_recursively=False),
op_inserted=4,
op_removed=2,
op_replaced=2,
)


def test_type_conversion():
"""Test rewriter on ops without results"""
prog = """\
Expand Down
6 changes: 3 additions & 3 deletions xdsl/backend/riscv/riscv_scf_to_asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
yield_op = body.last_op
assert isinstance(yield_op, riscv_scf.YieldOp)

body.insert_ops_after(
rewriter.insert_op(
[
riscv.AddOp(get_loop_var, op.step, rd=loop_var_reg),
riscv.BltOp(get_loop_var, op.ub, scf_body),
riscv.LabelOp(scf_body_end),
],
yield_op,
InsertPoint.after(yield_op),
)
body.erase_op(yield_op)
rewriter.erase_op(yield_op)

# We know that the body is not empty now.
assert body.first_op is not None
Expand Down
18 changes: 14 additions & 4 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing_extensions import deprecated

from xdsl.builder import BuilderListener
from xdsl.builder import Builder, BuilderListener
from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, ModuleOp
from xdsl.ir import (
Attribute,
Expand Down Expand Up @@ -77,8 +77,8 @@ def extend_from_listener(self, listener: BuilderListener | PatternRewriterListen
)


@dataclass(eq=False)
class PatternRewriter(PatternRewriterListener):
@dataclass(eq=False, init=False)
class PatternRewriter(Builder, PatternRewriterListener):
"""
A rewriter used during pattern matching.
Once an operation is matched, this rewriter is used to apply
Expand All @@ -91,6 +91,11 @@ class PatternRewriter(PatternRewriterListener):
has_done_action: bool = field(default=False, init=False)
"""Has the rewriter done any action during the current match."""

def __init__(self, current_operation: Operation):
PatternRewriterListener.__init__(self)
self.current_operation = current_operation
Builder.__init__(self, InsertPoint.before(current_operation))

def insert_op(
self, op: Operation | Sequence[Operation], insertion_point: InsertPoint
):
Expand Down Expand Up @@ -726,7 +731,11 @@ def _handle_operation_removal(self, op: Operation) -> None:
"""Handle removal of an operation."""
if self.apply_recursively:
self._add_operands_to_worklist(op.operands)
self._worklist.remove(op)
if op.regions:
for sub_op in op.walk():
self._worklist.remove(sub_op)
else:
self._worklist.remove(op)

def _handle_operation_modification(self, op: Operation) -> None:
"""Handle modification of an operation."""
Expand Down Expand Up @@ -829,6 +838,7 @@ def _process_worklist(self, listener: PatternRewriterListener) -> bool:
# Reset the rewriter on `op`
rewriter.has_done_action = False
rewriter.current_operation = op
rewriter.insertion_point = InsertPoint.before(op)

# Apply the pattern on the operation
try:
Expand Down
8 changes: 4 additions & 4 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def erase_op(op: Operation, safe_erase: bool = True):
If safe_erase is True, check that the operation has no uses.
Otherwise, replace its uses with ErasedSSAValue.
"""
assert op.parent is not None, "Cannot erase an operation that has no parents"

block = op.parent
block.erase_op(op, safe_erase=safe_erase)
if (block := op.parent) is not None:
block.erase_op(op, safe_erase=safe_erase)
else:
op.erase(safe_erase=safe_erase)

@staticmethod
def replace_op(
Expand Down
23 changes: 17 additions & 6 deletions xdsl/transforms/convert_scf_to_openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint


@dataclass
Expand Down Expand Up @@ -43,7 +44,8 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
regions=[Region(Block())],
operands=[[], [], [], [], [], []],
)
with ImplicitBuilder(parallel.region):
rewriter.insertion_point = InsertPoint.at_end(parallel.region.block)
with ImplicitBuilder(rewriter):
if self.chunk is None:
chunk_op = []
else:
Expand All @@ -65,7 +67,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
omp.ScheduleKind(self.schedule)
)
omp.TerminatorOp()
with ImplicitBuilder(wsloop.body):

rewriter.insertion_point = InsertPoint.at_end(wsloop.body.block)
with ImplicitBuilder(rewriter):
loop_nest = omp.LoopNestOp(
operands=[
loop.lowerBound[:collapse],
Expand All @@ -75,15 +79,21 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
regions=[Region(Block(arg_types=[IndexType()] * collapse))],
)
omp.TerminatorOp()
with ImplicitBuilder(loop_nest.body):

rewriter.insertion_point = InsertPoint.at_end(loop_nest.body.block)
with ImplicitBuilder(rewriter):
scope = memref.AllocaScopeOp(result_types=[[]], regions=[Region(Block())])
omp.YieldOp()
with ImplicitBuilder(scope.scope):

rewriter.insertion_point = InsertPoint.at_end(scope.scope.block)
with ImplicitBuilder(rewriter):
scope_terminator = memref.AllocaScopeReturnOp(operands=[[]])

for newarg, oldarg in zip(
loop_nest.body.block.args, loop.body.block.args[:collapse]
):
oldarg.replace_by(newarg)

for _ in range(collapse):
loop.body.block.erase_arg(loop.body.block.args[0])
if collapse < len(loop.lowerBound):
Expand All @@ -96,8 +106,9 @@ def match_and_rewrite(self, loop: scf.ParallelOp, rewriter: PatternRewriter, /):
new_ops = [new_loop]
else:
new_ops = [loop.body.block.detach_op(o) for o in loop.body.block.ops]
new_ops.pop()
scope.scope.block.insert_ops_before(new_ops, scope_terminator)
last_op = new_ops.pop()
rewriter.erase_op(last_op)
rewriter.insert_op(new_ops, InsertPoint.before(scope_terminator))

rewriter.replace_matched_op(parallel)

Expand Down

0 comments on commit a8b1c12

Please sign in to comment.