From f5157af05cb7c93ffae5ee3bcb6f0ec581740833 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 18 Nov 2024 23:17:13 +0000 Subject: [PATCH] add apply-eqsat-pdl pass --- .../apply_eqsat_pdl_swap_inputs.mlir | 33 +++++++++++++++ xdsl/transforms/__init__.py | 6 +++ xdsl/transforms/apply_eqsat_pdl.py | 41 +++++++++++++++++++ 3 files changed, 80 insertions(+) create mode 100644 tests/filecheck/transforms/apply-eqsat-pdl/apply_eqsat_pdl_swap_inputs.mlir create mode 100644 xdsl/transforms/apply_eqsat_pdl.py diff --git a/tests/filecheck/transforms/apply-eqsat-pdl/apply_eqsat_pdl_swap_inputs.mlir b/tests/filecheck/transforms/apply-eqsat-pdl/apply_eqsat_pdl_swap_inputs.mlir new file mode 100644 index 0000000000..12b4365af1 --- /dev/null +++ b/tests/filecheck/transforms/apply-eqsat-pdl/apply_eqsat_pdl_swap_inputs.mlir @@ -0,0 +1,33 @@ +// RUN: xdsl-opt %s -p apply-eqsat-pdl | filecheck %s + +// CHECK: func.func @impl() -> i32 { +// CHECK-NEXT: %c4 = arith.constant 4 : i32 +// CHECK-NEXT: %c4_eq = eqsat.eclass %c4 : i32 +// CHECK-NEXT: %c2 = arith.constant 2 : i32 +// CHECK-NEXT: %c2_eq = eqsat.eclass %c2 : i32 +// CHECK-NEXT: %0 = arith.addi %c2_eq, %c4_eq : i32 +// CHECK-NEXT: %sum = arith.addi %c4_eq, %c2_eq : i32 +// CHECK-NEXT: %sum_eq = eqsat.eclass %sum, %0 : i32 +// CHECK-NEXT: func.return %sum_eq : i32 +// CHECK-NEXT: } + +func.func @impl() -> i32 { + %c4 = arith.constant 4 : i32 + %c4_eq = eqsat.eclass %c4 : i32 + %c2 = arith.constant 2 : i32 + %c2_eq = eqsat.eclass %c2 : i32 + %sum = arith.addi %c4_eq, %c2_eq : i32 + %sum_eq = eqsat.eclass %sum : i32 + func.return %sum_eq : i32 +} + +pdl.pattern : benefit(1) { + %x = pdl.operand + %y = pdl.operand + %type = pdl.type + %x_y = pdl.operation "arith.addi" (%x, %y : !pdl.value, !pdl.value) -> (%type : !pdl.type) + pdl.rewrite %x_y { + %y_x = pdl.operation "arith.addi" (%y, %x : !pdl.value, !pdl.value) -> (%type : !pdl.type) + pdl.replace %x_y with %y_x + } +} diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index b61dccf042..e8f4792e4f 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -11,6 +11,11 @@ def get_apply_individual_rewrite(): return ApplyIndividualRewritePass + def get_apply_eqsat_pdl(): + from xdsl.transforms import apply_eqsat_pdl + + return apply_eqsat_pdl.ApplyEqsatPDLPass + def get_apply_pdl(): from xdsl.transforms import apply_pdl @@ -487,6 +492,7 @@ def get_varith_fuse_repeated_operands(): return { "apply-individual-rewrite": get_apply_individual_rewrite, + "apply-eqsat-pdl": get_apply_eqsat_pdl, "apply-pdl": get_apply_pdl, "arith-add-fastmath": get_arith_add_fastmath, "canonicalize-dmp": get_canonicalize_dmp, diff --git a/xdsl/transforms/apply_eqsat_pdl.py b/xdsl/transforms/apply_eqsat_pdl.py new file mode 100644 index 0000000000..3bfb350d9d --- /dev/null +++ b/xdsl/transforms/apply_eqsat_pdl.py @@ -0,0 +1,41 @@ +import os +from dataclasses import dataclass + +from xdsl.context import MLContext +from xdsl.dialects import builtin, pdl +from xdsl.interpreters.eqsat_pdl import EqsatPDLRewritePattern +from xdsl.parser import Parser +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + RewritePattern, +) + + +@dataclass(frozen=True) +class ApplyEqsatPDLPass(ModulePass): + name = "apply-eqsat-pdl" + + pdl_file: str | None = None + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + payload_module = op + if self.pdl_file is not None: + assert os.path.exists(self.pdl_file) + with open(self.pdl_file) as f: + pdl_module_str = f.read() + parser = Parser(ctx, pdl_module_str) + pdl_module = parser.parse_module() + else: + pdl_module = payload_module + rewrite_patterns: list[RewritePattern] = [ + EqsatPDLRewritePattern(op, ctx, None) + for op in pdl_module.walk() + if isinstance(op, pdl.RewriteOp) + ] + pattern_applier = GreedyRewritePatternApplier(rewrite_patterns) + # TODO: remove apply_recursively=False + PatternRewriteWalker(pattern_applier, apply_recursively=False).rewrite_op( + payload_module + )