From 28546233389999c7034158527e234f3315b1fdab Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 4 Nov 2024 13:08:50 +0000 Subject: [PATCH 1/7] Apply PDL from source --- tests/filecheck/interpreters/extra_file.mlir | 9 ++++ .../interpreters/pdl_definition.mlir | 9 ++++ .../test_pdl_different_files.mlir | 7 +++ .../test_pdl_interpreter_extra_file.mlir | 7 +++ .../test_pdl_interpreter_simple.mlir | 26 +++++++++++ xdsl/transforms/__init__.py | 12 ++--- xdsl/transforms/apply_pdl.py | 45 +++++++++++++++++++ 7 files changed, 109 insertions(+), 6 deletions(-) create mode 100644 tests/filecheck/interpreters/extra_file.mlir create mode 100644 tests/filecheck/interpreters/pdl_definition.mlir create mode 100644 tests/filecheck/interpreters/test_pdl_different_files.mlir create mode 100644 tests/filecheck/interpreters/test_pdl_interpreter_extra_file.mlir create mode 100644 tests/filecheck/interpreters/test_pdl_interpreter_simple.mlir create mode 100644 xdsl/transforms/apply_pdl.py diff --git a/tests/filecheck/interpreters/extra_file.mlir b/tests/filecheck/interpreters/extra_file.mlir new file mode 100644 index 0000000000..5293ae8e68 --- /dev/null +++ b/tests/filecheck/interpreters/extra_file.mlir @@ -0,0 +1,9 @@ +pdl.pattern : benefit(1) { + %zero_attr = pdl.attribute = 0 + %root = pdl.operation "test.op" {"attr" = %zero_attr} + pdl.rewrite %root { + %one_attr = pdl.attribute = 1 + %new_op = pdl.operation "test.op" {"attr" = %one_attr} + pdl.replace %root with %new_op + } +} diff --git a/tests/filecheck/interpreters/pdl_definition.mlir b/tests/filecheck/interpreters/pdl_definition.mlir new file mode 100644 index 0000000000..5293ae8e68 --- /dev/null +++ b/tests/filecheck/interpreters/pdl_definition.mlir @@ -0,0 +1,9 @@ +pdl.pattern : benefit(1) { + %zero_attr = pdl.attribute = 0 + %root = pdl.operation "test.op" {"attr" = %zero_attr} + pdl.rewrite %root { + %one_attr = pdl.attribute = 1 + %new_op = pdl.operation "test.op" {"attr" = %one_attr} + pdl.replace %root with %new_op + } +} diff --git a/tests/filecheck/interpreters/test_pdl_different_files.mlir b/tests/filecheck/interpreters/test_pdl_different_files.mlir new file mode 100644 index 0000000000..8ff283c303 --- /dev/null +++ b/tests/filecheck/interpreters/test_pdl_different_files.mlir @@ -0,0 +1,7 @@ +// RUN: xdsl-opt %s -p 'apply-pdl{pdl_file="%p/pdl_definition.mlir"}' | filecheck %s + +"test.op"() {attr = 0} : () -> () + +//CHECK: builtin.module { +// CHECK-NEXT: "test.op"() {"attr" = 1 : i64} : () -> () +// CHECK-NEXT: } diff --git a/tests/filecheck/interpreters/test_pdl_interpreter_extra_file.mlir b/tests/filecheck/interpreters/test_pdl_interpreter_extra_file.mlir new file mode 100644 index 0000000000..b9ff3f6229 --- /dev/null +++ b/tests/filecheck/interpreters/test_pdl_interpreter_extra_file.mlir @@ -0,0 +1,7 @@ +// RUN: xdsl-opt %s -p 'apply-pdl{pdl_file="%p/extra_file.mlir"}' | filecheck %s + +"test.op"() {attr = 0} : () -> () + +//CHECK: builtin.module { +// CHECK-NEXT: "test.op"() {"attr" = 1 : i64} : () -> () +// CHECK-NEXT: } diff --git a/tests/filecheck/interpreters/test_pdl_interpreter_simple.mlir b/tests/filecheck/interpreters/test_pdl_interpreter_simple.mlir new file mode 100644 index 0000000000..c88535b705 --- /dev/null +++ b/tests/filecheck/interpreters/test_pdl_interpreter_simple.mlir @@ -0,0 +1,26 @@ +// RUN: xdsl-opt %s -p apply-pdl | filecheck %s + +"test.op"() {attr = 0} : () -> () + +pdl.pattern : benefit(1) { + %zero_attr = pdl.attribute = 0 + %root = pdl.operation "test.op" {"attr" = %zero_attr} + pdl.rewrite %root { + %one_attr = pdl.attribute = 1 + %new_op = pdl.operation "test.op" {"attr" = %one_attr} + pdl.replace %root with %new_op + } +} + +//CHECK: builtin.module { +// CHECK-NEXT: "test.op"() {"attr" = 1 : i64} : () -> () +// CHECK-NEXT: pdl.pattern : benefit(1) { +// CHECK-NEXT: %zero_attr = pdl.attribute = 0 : i64 +// CHECK-NEXT: %root = pdl.operation "test.op" {"attr" = %zero_attr} +// CHECK-NEXT: pdl.rewrite %root { +// CHECK-NEXT: %one_attr = pdl.attribute = 1 : i64 +// CHECK-NEXT: %new_op = pdl.operation "test.op" {"attr" = %one_attr} +// CHECK-NEXT: pdl.replace %root with %new_op +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index a4f5d5b6fb..140edad3d3 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -6,6 +6,11 @@ def get_all_passes() -> dict[str, Callable[[], type[ModulePass]]]: """Return the list of all available passes.""" + def get_apply_pdl(): + from xdsl.transforms import apply_pdl + + return apply_pdl.ApplyPDLPass + def get_arith_add_fastmath(): from xdsl.transforms import arith_add_fastmath @@ -236,11 +241,6 @@ def get_memref_to_dsd(): return memref_to_dsd.MemrefToDsdPass - def get_memref_to_ptr(): - from xdsl.transforms import convert_memref_to_ptr - - return convert_memref_to_ptr.ConvertMemrefToPtr - def get_mlir_opt(): from xdsl.transforms import mlir_opt @@ -441,6 +441,7 @@ def get_varith_fuse_repeated_operands(): return varith_transformations.VarithFuseRepeatedOperandsPass return { + "apply-pdl": get_apply_pdl, "arith-add-fastmath": get_arith_add_fastmath, "loop-hoist-memref": get_loop_hoist_memref, "canonicalize-dmp": get_canonicalize_dmp, @@ -506,7 +507,6 @@ def get_varith_fuse_repeated_operands(): "memref-stream-tile-outer-loops": get_memref_stream_tile_outer_loops, "memref-stream-legalize": get_memref_stream_legalize, "memref-to-dsd": get_memref_to_dsd, - "convert-memref-to-ptr": get_memref_to_ptr, "mlir-opt": get_mlir_opt, "printf-to-llvm": get_printf_to_llvm, "printf-to-putchar": get_printf_to_putchar, diff --git a/xdsl/transforms/apply_pdl.py b/xdsl/transforms/apply_pdl.py new file mode 100644 index 0000000000..4cad9ee56a --- /dev/null +++ b/xdsl/transforms/apply_pdl.py @@ -0,0 +1,45 @@ +import os +from dataclasses import dataclass +from io import StringIO +from typing import cast + +from xdsl.context import MLContext +from xdsl.dialects import builtin, pdl +from xdsl.interpreters.pdl import ( + PDLRewritePattern, +) +from xdsl.parser import Parser +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + RewritePattern, +) + + +@dataclass(frozen=True) +class ApplyPDLPass(ModulePass): + name = "apply-pdl" + + pdl_file: str | None = None + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + payload_module = op + if self.pdl_file is not None: + assert os.path.exists(self.pdl_file) + with open(self.pdl_file) as f: + pdl_module_str = f.read() + parser = Parser(ctx, pdl_module_str) + pdl_module = parser.parse_module() + else: + pdl_module = payload_module + stream = StringIO() + rewrite_patterns = [ + cast(RewritePattern, PDLRewritePattern(op, ctx, stream)) + for op in pdl_module.walk() + if isinstance(op, pdl.RewriteOp) + ] + pattern_applier = GreedyRewritePatternApplier(rewrite_patterns) + PatternRewriteWalker(pattern_applier).rewrite_op(payload_module) + # pattern_rewriter = PatternRewriter(payload_module) + # pattern_applier.match_and_rewrite(payload_module, pattern_rewriter) From 29cc6aeb8eaeb659a890b6cf77c8d8f8c5ef825d Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 4 Nov 2024 13:12:19 +0000 Subject: [PATCH 2/7] Remove useless comments --- xdsl/transforms/apply_pdl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xdsl/transforms/apply_pdl.py b/xdsl/transforms/apply_pdl.py index 4cad9ee56a..754134211f 100644 --- a/xdsl/transforms/apply_pdl.py +++ b/xdsl/transforms/apply_pdl.py @@ -41,5 +41,3 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: ] pattern_applier = GreedyRewritePatternApplier(rewrite_patterns) PatternRewriteWalker(pattern_applier).rewrite_op(payload_module) - # pattern_rewriter = PatternRewriter(payload_module) - # pattern_applier.match_and_rewrite(payload_module, pattern_rewriter) From aac466eaa60f89cd4599e4722314210c46fc595d Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 4 Nov 2024 13:25:20 +0000 Subject: [PATCH 3/7] Correct this accidental removal --- xdsl/transforms/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index 140edad3d3..a426cc40e6 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -241,6 +241,11 @@ def get_memref_to_dsd(): return memref_to_dsd.MemrefToDsdPass + def get_memref_to_ptr(): + from xdsl.transforms import convert_memref_to_ptr + + return convert_memref_to_ptr.ConvertMemrefToPtr + def get_mlir_opt(): from xdsl.transforms import mlir_opt @@ -507,6 +512,7 @@ def get_varith_fuse_repeated_operands(): "memref-stream-tile-outer-loops": get_memref_stream_tile_outer_loops, "memref-stream-legalize": get_memref_stream_legalize, "memref-to-dsd": get_memref_to_dsd, + "convert-memref-to-ptr": get_memref_to_ptr, "mlir-opt": get_mlir_opt, "printf-to-llvm": get_printf_to_llvm, "printf-to-putchar": get_printf_to_putchar, From ee0a850ca43585f78dd324fdadaf75e730c1c646 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 4 Nov 2024 16:09:14 +0000 Subject: [PATCH 4/7] Remove the redundant tests --- tests/filecheck/interpreters/pdl_definition.mlir | 9 --------- .../filecheck/interpreters/test_pdl_different_files.mlir | 7 ------- 2 files changed, 16 deletions(-) delete mode 100644 tests/filecheck/interpreters/pdl_definition.mlir delete mode 100644 tests/filecheck/interpreters/test_pdl_different_files.mlir diff --git a/tests/filecheck/interpreters/pdl_definition.mlir b/tests/filecheck/interpreters/pdl_definition.mlir deleted file mode 100644 index 5293ae8e68..0000000000 --- a/tests/filecheck/interpreters/pdl_definition.mlir +++ /dev/null @@ -1,9 +0,0 @@ -pdl.pattern : benefit(1) { - %zero_attr = pdl.attribute = 0 - %root = pdl.operation "test.op" {"attr" = %zero_attr} - pdl.rewrite %root { - %one_attr = pdl.attribute = 1 - %new_op = pdl.operation "test.op" {"attr" = %one_attr} - pdl.replace %root with %new_op - } -} diff --git a/tests/filecheck/interpreters/test_pdl_different_files.mlir b/tests/filecheck/interpreters/test_pdl_different_files.mlir deleted file mode 100644 index 8ff283c303..0000000000 --- a/tests/filecheck/interpreters/test_pdl_different_files.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: xdsl-opt %s -p 'apply-pdl{pdl_file="%p/pdl_definition.mlir"}' | filecheck %s - -"test.op"() {attr = 0} : () -> () - -//CHECK: builtin.module { -// CHECK-NEXT: "test.op"() {"attr" = 1 : i64} : () -> () -// CHECK-NEXT: } From ee374a473fc8a141e8509fadbdcc8d71907ee36a Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 4 Nov 2024 16:13:10 +0000 Subject: [PATCH 5/7] Update xdsl/transforms/apply_pdl.py Co-authored-by: Sasha Lopoukhine --- xdsl/transforms/apply_pdl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/transforms/apply_pdl.py b/xdsl/transforms/apply_pdl.py index 754134211f..0614bb6ab3 100644 --- a/xdsl/transforms/apply_pdl.py +++ b/xdsl/transforms/apply_pdl.py @@ -35,7 +35,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: pdl_module = payload_module stream = StringIO() rewrite_patterns = [ - cast(RewritePattern, PDLRewritePattern(op, ctx, stream)) + cast(RewritePattern, PDLRewritePattern(op, ctx, None)) for op in pdl_module.walk() if isinstance(op, pdl.RewriteOp) ] From 39822f000a5deb70c4f91fec17a584eb953da575 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 4 Nov 2024 16:13:59 +0000 Subject: [PATCH 6/7] Remove not accessed variable --- xdsl/transforms/apply_pdl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xdsl/transforms/apply_pdl.py b/xdsl/transforms/apply_pdl.py index 0614bb6ab3..a426c13058 100644 --- a/xdsl/transforms/apply_pdl.py +++ b/xdsl/transforms/apply_pdl.py @@ -1,6 +1,5 @@ import os from dataclasses import dataclass -from io import StringIO from typing import cast from xdsl.context import MLContext @@ -33,7 +32,6 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: pdl_module = parser.parse_module() else: pdl_module = payload_module - stream = StringIO() rewrite_patterns = [ cast(RewritePattern, PDLRewritePattern(op, ctx, None)) for op in pdl_module.walk() From eeaf9f7f7c56b2b14ce74de6da15fea9639c5def Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Mon, 4 Nov 2024 17:39:43 +0000 Subject: [PATCH 7/7] Type annotation inline --- xdsl/transforms/apply_pdl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xdsl/transforms/apply_pdl.py b/xdsl/transforms/apply_pdl.py index a426c13058..5193d94b05 100644 --- a/xdsl/transforms/apply_pdl.py +++ b/xdsl/transforms/apply_pdl.py @@ -1,6 +1,5 @@ import os from dataclasses import dataclass -from typing import cast from xdsl.context import MLContext from xdsl.dialects import builtin, pdl @@ -32,8 +31,8 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: pdl_module = parser.parse_module() else: pdl_module = payload_module - rewrite_patterns = [ - cast(RewritePattern, PDLRewritePattern(op, ctx, None)) + rewrite_patterns: list[RewritePattern] = [ + PDLRewritePattern(op, ctx, None) for op in pdl_module.walk() if isinstance(op, pdl.RewriteOp) ]