Skip to content

Commit

Permalink
add apply-eqsat-pdl pass
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 18, 2024
1 parent 5111b75 commit f5157af
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: xdsl-opt %s -p apply-eqsat-pdl | filecheck %s

// CHECK: func.func @impl() -> i32 {
// CHECK-NEXT: %c4 = arith.constant 4 : i32
// CHECK-NEXT: %c4_eq = eqsat.eclass %c4 : i32
// CHECK-NEXT: %c2 = arith.constant 2 : i32
// CHECK-NEXT: %c2_eq = eqsat.eclass %c2 : i32
// CHECK-NEXT: %0 = arith.addi %c2_eq, %c4_eq : i32
// CHECK-NEXT: %sum = arith.addi %c4_eq, %c2_eq : i32
// CHECK-NEXT: %sum_eq = eqsat.eclass %sum, %0 : i32
// CHECK-NEXT: func.return %sum_eq : i32
// CHECK-NEXT: }

func.func @impl() -> i32 {
%c4 = arith.constant 4 : i32
%c4_eq = eqsat.eclass %c4 : i32
%c2 = arith.constant 2 : i32
%c2_eq = eqsat.eclass %c2 : i32
%sum = arith.addi %c4_eq, %c2_eq : i32
%sum_eq = eqsat.eclass %sum : i32
func.return %sum_eq : i32
}

pdl.pattern : benefit(1) {
%x = pdl.operand
%y = pdl.operand
%type = pdl.type
%x_y = pdl.operation "arith.addi" (%x, %y : !pdl.value, !pdl.value) -> (%type : !pdl.type)
pdl.rewrite %x_y {
%y_x = pdl.operation "arith.addi" (%y, %x : !pdl.value, !pdl.value) -> (%type : !pdl.type)
pdl.replace %x_y with %y_x
}
}
6 changes: 6 additions & 0 deletions xdsl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ def get_apply_individual_rewrite():

return ApplyIndividualRewritePass

def get_apply_eqsat_pdl():
from xdsl.transforms import apply_eqsat_pdl

return apply_eqsat_pdl.ApplyEqsatPDLPass

def get_apply_pdl():
from xdsl.transforms import apply_pdl

Expand Down Expand Up @@ -487,6 +492,7 @@ def get_varith_fuse_repeated_operands():

return {
"apply-individual-rewrite": get_apply_individual_rewrite,
"apply-eqsat-pdl": get_apply_eqsat_pdl,
"apply-pdl": get_apply_pdl,
"arith-add-fastmath": get_arith_add_fastmath,
"canonicalize-dmp": get_canonicalize_dmp,
Expand Down
41 changes: 41 additions & 0 deletions xdsl/transforms/apply_eqsat_pdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import builtin, pdl
from xdsl.interpreters.eqsat_pdl import EqsatPDLRewritePattern
from xdsl.parser import Parser
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
RewritePattern,
)


@dataclass(frozen=True)
class ApplyEqsatPDLPass(ModulePass):
name = "apply-eqsat-pdl"

pdl_file: str | None = None

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
payload_module = op
if self.pdl_file is not None:
assert os.path.exists(self.pdl_file)
with open(self.pdl_file) as f:
pdl_module_str = f.read()
parser = Parser(ctx, pdl_module_str)
pdl_module = parser.parse_module()
else:
pdl_module = payload_module
rewrite_patterns: list[RewritePattern] = [
EqsatPDLRewritePattern(op, ctx, None)
for op in pdl_module.walk()
if isinstance(op, pdl.RewriteOp)
]
pattern_applier = GreedyRewritePatternApplier(rewrite_patterns)
# TODO: remove apply_recursively=False
PatternRewriteWalker(pattern_applier, apply_recursively=False).rewrite_op(
payload_module
)

0 comments on commit f5157af

Please sign in to comment.