diff --git a/tests/filecheck/interpreters/extra_file.mlir b/tests/filecheck/interpreters/extra_file.mlir new file mode 100644 index 0000000000..5293ae8e68 --- /dev/null +++ b/tests/filecheck/interpreters/extra_file.mlir @@ -0,0 +1,9 @@ +pdl.pattern : benefit(1) { + %zero_attr = pdl.attribute = 0 + %root = pdl.operation "test.op" {"attr" = %zero_attr} + pdl.rewrite %root { + %one_attr = pdl.attribute = 1 + %new_op = pdl.operation "test.op" {"attr" = %one_attr} + pdl.replace %root with %new_op + } +} diff --git a/tests/filecheck/interpreters/test_pdl_interpreter_extra_file.mlir b/tests/filecheck/interpreters/test_pdl_interpreter_extra_file.mlir new file mode 100644 index 0000000000..b9ff3f6229 --- /dev/null +++ b/tests/filecheck/interpreters/test_pdl_interpreter_extra_file.mlir @@ -0,0 +1,7 @@ +// RUN: xdsl-opt %s -p 'apply-pdl{pdl_file="%p/extra_file.mlir"}' | filecheck %s + +"test.op"() {attr = 0} : () -> () + +//CHECK: builtin.module { +// CHECK-NEXT: "test.op"() {"attr" = 1 : i64} : () -> () +// CHECK-NEXT: } diff --git a/tests/filecheck/interpreters/test_pdl_interpreter_simple.mlir b/tests/filecheck/interpreters/test_pdl_interpreter_simple.mlir new file mode 100644 index 0000000000..c88535b705 --- /dev/null +++ b/tests/filecheck/interpreters/test_pdl_interpreter_simple.mlir @@ -0,0 +1,26 @@ +// RUN: xdsl-opt %s -p apply-pdl | filecheck %s + +"test.op"() {attr = 0} : () -> () + +pdl.pattern : benefit(1) { + %zero_attr = pdl.attribute = 0 + %root = pdl.operation "test.op" {"attr" = %zero_attr} + pdl.rewrite %root { + %one_attr = pdl.attribute = 1 + %new_op = pdl.operation "test.op" {"attr" = %one_attr} + pdl.replace %root with %new_op + } +} + +//CHECK: builtin.module { +// CHECK-NEXT: "test.op"() {"attr" = 1 : i64} : () -> () +// CHECK-NEXT: pdl.pattern : benefit(1) { +// CHECK-NEXT: %zero_attr = pdl.attribute = 0 : i64 +// CHECK-NEXT: %root = pdl.operation "test.op" {"attr" = %zero_attr} +// CHECK-NEXT: pdl.rewrite %root { +// CHECK-NEXT: %one_attr = pdl.attribute = 1 : i64 +// CHECK-NEXT: %new_op = pdl.operation "test.op" {"attr" = %one_attr} +// CHECK-NEXT: pdl.replace %root with %new_op +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index c2456a4d03..0cbe05ac2a 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -6,6 +6,11 @@ def get_all_passes() -> dict[str, Callable[[], type[ModulePass]]]: """Return the list of all available passes.""" + def get_apply_pdl(): + from xdsl.transforms import apply_pdl + + return apply_pdl.ApplyPDLPass + def get_arith_add_fastmath(): from xdsl.transforms import arith_add_fastmath @@ -446,6 +451,7 @@ def get_varith_fuse_repeated_operands(): return varith_transformations.VarithFuseRepeatedOperandsPass return { + "apply-pdl": get_apply_pdl, "arith-add-fastmath": get_arith_add_fastmath, "loop-hoist-memref": get_loop_hoist_memref, "canonicalize-dmp": get_canonicalize_dmp, diff --git a/xdsl/transforms/apply_pdl.py b/xdsl/transforms/apply_pdl.py new file mode 100644 index 0000000000..5193d94b05 --- /dev/null +++ b/xdsl/transforms/apply_pdl.py @@ -0,0 +1,40 @@ +import os +from dataclasses import dataclass + +from xdsl.context import MLContext +from xdsl.dialects import builtin, pdl +from xdsl.interpreters.pdl import ( + PDLRewritePattern, +) +from xdsl.parser import Parser +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + RewritePattern, +) + + +@dataclass(frozen=True) +class ApplyPDLPass(ModulePass): + name = "apply-pdl" + + pdl_file: str | None = None + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + payload_module = op + if self.pdl_file is not None: + assert os.path.exists(self.pdl_file) + with open(self.pdl_file) as f: + pdl_module_str = f.read() + parser = Parser(ctx, pdl_module_str) + pdl_module = parser.parse_module() + else: + pdl_module = payload_module + rewrite_patterns: list[RewritePattern] = [ + PDLRewritePattern(op, ctx, None) + for op in pdl_module.walk() + if isinstance(op, pdl.RewriteOp) + ] + pattern_applier = GreedyRewritePatternApplier(rewrite_patterns) + PatternRewriteWalker(pattern_applier).rewrite_op(payload_module)