From 39a2e0aac7eea7aa9a2604fdbe864b0abb80ef0c Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Fri, 22 Nov 2024 16:10:49 +0000 Subject: [PATCH] core: Walk regions for rewritings instead of ops Currently, the pattern rewriter walker walks an operation, and tries to apply patterns on all contained operations, including itself. This is not what we would want, as rewriting a toplevel operation is dangerous, if that operation does not have a parent. This commit changes the walker to walk regions instead of operations, so it is clear that the toplevel operation will not be modified. This is similar to what MLIR currently does in applyPatternsAndFoldGreedily. --- docs/database_example.ipynb | 4 +- docs/xdsl-introduction.ipynb | 4 +- .../pattern_rewriter/test_pattern_rewriter.py | 97 ++++++++++++------- xdsl/pattern_rewriter.py | 16 +-- xdsl/transforms/apply_pdl.py | 5 +- xdsl/transforms/convert_qref_to_qssa.py | 2 +- xdsl/transforms/convert_qssa_to_qref.py | 2 +- xdsl/transforms/convert_scf_to_cf.py | 2 +- .../convert_stencil_to_csl_stencil.py | 4 +- xdsl/transforms/dead_code_elimination.py | 8 +- xdsl/transforms/varith_transformations.py | 6 +- 11 files changed, 87 insertions(+), 63 deletions(-) diff --git a/docs/database_example.ipynb b/docs/database_example.ipynb index 57031933c1..700cc4d4fc 100644 --- a/docs/database_example.ipynb +++ b/docs/database_example.ipynb @@ -310,7 +310,7 @@ ")\n", "\n", "\n", - "walker.rewrite_op(sel)" + "walker.rewrite_region(sel.filter)" ] }, { @@ -387,7 +387,7 @@ " walk_reverse=False,\n", ")\n", "\n", - "walker.rewrite_op(sel)" + "walker.rewrite_region(sel.filter)" ] }, { diff --git a/docs/xdsl-introduction.ipynb b/docs/xdsl-introduction.ipynb index b5cb7aee7c..e5ae1afb45 100644 --- a/docs/xdsl-introduction.ipynb +++ b/docs/xdsl-introduction.ipynb @@ -762,7 +762,7 @@ " apply_recursively=True,\n", " walk_reverse=False,\n", ")\n", - "walker.rewrite_module(filtered)\n", + "walker.rewrite_region(filtered.filter)\n", "printer.print_op(filtered)" ] }, @@ -808,7 +808,7 @@ " apply_recursively=True,\n", " walk_reverse=False,\n", ")\n", - "walker.rewrite_module(filtered)\n", + "walker.rewrite_region(filtered.filter)\n", "printer.print_op(filtered)" ] }, diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index a7830d91c2..468fe200fa 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -377,17 +377,21 @@ def test_insert_op_at_start(): """Test rewrites where operations are inserted with a given position.""" prog = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" expected = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 - %1 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 + %1 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" class Rewrite(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, mod: ModuleOp, rewriter: PatternRewriter): + def match_and_rewrite(self, mod: test.TestOp, rewriter: PatternRewriter): new_cst = Constant.from_int_and_width(42, i32) rewriter.insert_op(new_cst, InsertPoint.at_start(mod.regions[0].blocks[0])) @@ -404,20 +408,24 @@ def test_insert_op_before(): """Test rewrites where operations are inserted before a given operation.""" prog = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" expected = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 - %1 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 + %1 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" class Rewrite(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, mod: ModuleOp, rewriter: PatternRewriter): + def match_and_rewrite(self, mod: test.TestOp, rewriter: PatternRewriter): new_cst = Constant.from_int_and_width(42, i32) - first_op = mod.ops.first + first_op = mod.regions[0].block.ops.first assert first_op is not None rewriter.insert_op(new_cst, InsertPoint.before(first_op)) @@ -433,20 +441,24 @@ def test_insert_op_after(): """Test rewrites where operations are inserted after a given operation.""" prog = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" expected = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 - %1 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + %1 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" class Rewrite(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, mod: ModuleOp, rewriter: PatternRewriter): + def match_and_rewrite(self, mod: test.TestOp, rewriter: PatternRewriter): new_cst = Constant.from_int_and_width(42, i32) - first_op = mod.ops.first + first_op = mod.regions[0].block.ops.first assert first_op is not None rewriter.insert_op(new_cst, InsertPoint.after(first_op)) @@ -592,17 +604,21 @@ def test_delete_inner_op(): """Test rewrites where an operation inside a region of the matched op is deleted.""" prog = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" expected = """"builtin.module"() ({ -^0: + "test.op"() ({ + ^0: + }) : () -> () }) : () -> ()""" class Rewrite(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter): - first_op = op.ops.first + def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): + first_op = op.regions[0].block.ops.first assert first_op is not None rewriter.erase_op(first_op) @@ -619,17 +635,21 @@ def test_replace_inner_op(): """Test rewrites where an operation inside a region of the matched op is deleted.""" prog = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 5 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" expected = """"builtin.module"() ({ - %0 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 + "test.op"() ({ + %0 = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 + }) : () -> () }) : () -> ()""" class Rewrite(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter): - first_op = op.ops.first + def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): + first_op = op.regions[0].block.ops.first assert first_op is not None rewriter.replace_op(first_op, [Constant.from_int_and_width(42, i32)]) @@ -1048,29 +1068,36 @@ def test_move_region_contents_to_new_regions(): prog = """\ "builtin.module"() ({ - %0 = "test.op"() : () -> !test.type<"int"> - %1 = "test.op"() ({ - ^0: - %2 = "test.op"() : () -> !test.type<"int"> - }) : () -> !test.type<"int"> + "test.op"() ({ + %0 = "test.op"() : () -> !test.type<"int"> + %1 = "test.op"() ({ + ^0: + %2 = "test.op"() : () -> !test.type<"int"> + }) : () -> !test.type<"int"> + }) : () -> () }) : () -> () """ expected = """\ "builtin.module"() ({ - %0 = "test.op"() : () -> !test.type<"int"> - %1 = "test.op"() ({ - }) : () -> !test.type<"int"> - %2 = "test.op"() ({ - %3 = "test.op"() : () -> !test.type<"int"> - }) : () -> !test.type<"int"> + "test.op"() ({ + %0 = "test.op"() : () -> !test.type<"int"> + %1 = "test.op"() ({ + }) : () -> !test.type<"int"> + %2 = "test.op"() ({ + %3 = "test.op"() : () -> !test.type<"int"> + }) : () -> !test.type<"int"> + }) : () -> () }) : () -> () """ class Rewrite(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter): - ops_iter = iter(op.ops) + def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter): + # Match the toplevel test.op + if not isinstance(op.parent_op(), ModuleOp): + return + ops_iter = iter(op.regions[0].block.ops) _ = next(ops_iter) # skip first op old_op = next(ops_iter) diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index 866b5e7f8b..a8282bc21a 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -690,7 +690,7 @@ class PatternRewriteWalker: That way, all uses are replaced before the definitions. """ - post_walk_func: Callable[[Operation, PatternRewriterListener], bool] | None = field( + post_walk_func: Callable[[Region, PatternRewriterListener], bool] | None = field( default=None ) """ @@ -773,19 +773,19 @@ def rewrite_module(self, module: ModuleOp) -> bool: Rewrite operations nested in the given operation by repeatedly applying the pattern. Returns `True` if the IR was mutated. """ - return self.rewrite_op(module) + return self.rewrite_region(module.body) - def rewrite_op(self, op: Operation) -> bool: + def rewrite_region(self, region: Region) -> bool: """ Rewrite operations nested in the given operation by repeatedly applying the pattern. Returns `True` if the IR was mutated. """ pattern_listener = self._get_rewriter_listener() - self._populate_worklist(op) + self._populate_worklist(region) op_was_modified = self._process_worklist(pattern_listener) if self.post_walk_func is not None: - op_was_modified |= self.post_walk_func(op, pattern_listener) + op_was_modified |= self.post_walk_func(region, pattern_listener) if not self.apply_recursively: return op_was_modified @@ -793,14 +793,14 @@ def rewrite_op(self, op: Operation) -> bool: result = op_was_modified while op_was_modified: - self._populate_worklist(op) + self._populate_worklist(region) op_was_modified = self._process_worklist(pattern_listener) if self.post_walk_func is not None: - op_was_modified |= self.post_walk_func(op, pattern_listener) + op_was_modified |= self.post_walk_func(region, pattern_listener) return result - def _populate_worklist(self, op: Operation) -> None: + def _populate_worklist(self, op: Operation | Region | Block) -> None: """Populate the worklist with all nested operations.""" # We walk in reverse order since we use a stack for our worklist. for sub_op in op.walk( diff --git a/xdsl/transforms/apply_pdl.py b/xdsl/transforms/apply_pdl.py index 5193d94b05..dedee50a0f 100644 --- a/xdsl/transforms/apply_pdl.py +++ b/xdsl/transforms/apply_pdl.py @@ -22,7 +22,6 @@ class ApplyPDLPass(ModulePass): pdl_file: str | None = None def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - payload_module = op if self.pdl_file is not None: assert os.path.exists(self.pdl_file) with open(self.pdl_file) as f: @@ -30,11 +29,11 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: parser = Parser(ctx, pdl_module_str) pdl_module = parser.parse_module() else: - pdl_module = payload_module + pdl_module = op rewrite_patterns: list[RewritePattern] = [ PDLRewritePattern(op, ctx, None) for op in pdl_module.walk() if isinstance(op, pdl.RewriteOp) ] pattern_applier = GreedyRewritePatternApplier(rewrite_patterns) - PatternRewriteWalker(pattern_applier).rewrite_op(payload_module) + PatternRewriteWalker(pattern_applier).rewrite_module(op) diff --git a/xdsl/transforms/convert_qref_to_qssa.py b/xdsl/transforms/convert_qref_to_qssa.py index 4ae11df40f..ed2cb18c31 100644 --- a/xdsl/transforms/convert_qref_to_qssa.py +++ b/xdsl/transforms/convert_qref_to_qssa.py @@ -48,4 +48,4 @@ class ConvertQRefToQssa(ModulePass): def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: PatternRewriteWalker( ConvertQRefToQssaPattern(), apply_recursively=False - ).rewrite_op(op) + ).rewrite_module(op) diff --git a/xdsl/transforms/convert_qssa_to_qref.py b/xdsl/transforms/convert_qssa_to_qref.py index b81ae26fd3..675c7b017f 100644 --- a/xdsl/transforms/convert_qssa_to_qref.py +++ b/xdsl/transforms/convert_qssa_to_qref.py @@ -33,4 +33,4 @@ class ConvertQssaToQRef(ModulePass): name = "convert-qssa-to-qref" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - PatternRewriteWalker(ConvertQssaToQRefPattern()).rewrite_op(op) + PatternRewriteWalker(ConvertQssaToQRefPattern()).rewrite_module(op) diff --git a/xdsl/transforms/convert_scf_to_cf.py b/xdsl/transforms/convert_scf_to_cf.py index 6f0f922ce1..11688e31df 100644 --- a/xdsl/transforms/convert_scf_to_cf.py +++ b/xdsl/transforms/convert_scf_to_cf.py @@ -238,4 +238,4 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: ForLowering(), ] ) - ).rewrite_op(op) + ).rewrite_module(op) diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 38187a6f41..2e94190651 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -252,7 +252,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): ConvertAccessOpFromPrefetchPattern(arg_idx) ) - nested_rewriter.rewrite_op(new_apply_op) + nested_rewriter.rewrite_region(new_apply_op.region) def split_ops( @@ -415,7 +415,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): PatternRewriteWalker( SplitVarithOpPattern(op.region.block.args[prefetch_idx]), apply_recursively=False, - ).rewrite_op(op) + ).rewrite_region(op.region) # determine how ops should be split across the two regions chunk_region_ops, done_exchange_ops = split_ops( diff --git a/xdsl/transforms/dead_code_elimination.py b/xdsl/transforms/dead_code_elimination.py index bf844d9c3e..cfd800db24 100644 --- a/xdsl/transforms/dead_code_elimination.py +++ b/xdsl/transforms/dead_code_elimination.py @@ -159,14 +159,12 @@ def region_dce(region: Region, listener: PatternRewriterListener | None = None) return live_set.changed -def op_dce(op: Operation, listener: PatternRewriterListener | None = None): - changed = tuple(region_dce(region, listener) for region in op.regions) - - return any(changed) +def op_dce(region: Region, listener: PatternRewriterListener | None = None): + return region_dce(region, listener) class DeadCodeElimination(ModulePass): name = "dce" def apply(self, ctx: MLContext, op: ModuleOp) -> None: - op_dce(op) + op_dce(op.body) diff --git a/xdsl/transforms/varith_transformations.py b/xdsl/transforms/varith_transformations.py index 1f6081e60a..52045e1347 100644 --- a/xdsl/transforms/varith_transformations.py +++ b/xdsl/transforms/varith_transformations.py @@ -252,7 +252,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: ] ), walk_reverse=True, - ).rewrite_op(op) + ).rewrite_module(op) class ConvertVarithToArithPass(ModulePass): @@ -268,7 +268,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: PatternRewriteWalker( VarithToArithPattern(), apply_recursively=False, - ).rewrite_op(op) + ).rewrite_module(op) class VarithFuseRepeatedOperandsPass(ModulePass): @@ -285,4 +285,4 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: PatternRewriteWalker( FuseRepeatedAddArgsPattern(self.min_reps), apply_recursively=False, - ).rewrite_op(op) + ).rewrite_module(op)