Skip to content

Commit

Permalink
dialects: (cf) cond_br folding (#3283)
Browse files Browse the repository at this point in the history
First two `cond_br` canonicalization patterns:
- Constant folding (`cf.cond_br %true ^0 ^1` == `cf.br ^0`)
- Passthrough (conditional branch to a branch with just a single `br` op
gets forwarded)
  • Loading branch information
alexarice authored Oct 10, 2024
1 parent 76cccc0 commit 7f0f3e6
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 16 deletions.
79 changes: 76 additions & 3 deletions tests/filecheck/dialects/cf/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,82 @@ func.func @br_passthrough(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @br_dead_passthrough() {
cf.br ^1
cf.br ^1
^0:
cf.br ^1
cf.br ^1
^1:
func.return
func.return
}

/// Test the folding of CondBranchOp with a constant condition.
/// This will reduce further with other rewrites

// CHECK: func.func @cond_br_folding(%cond : i1, %a : i32) {
// CHECK-NEXT: cf.cond_br %cond, ^[[#b0:]], ^[[#b0]]
// CHECK-NEXT: ^[[#b0]]:
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @cond_br_folding(%cond : i1, %a : i32) {
%false_cond = arith.constant false
%true_cond = arith.constant true
cf.cond_br %cond, ^bb1, ^bb2(%a : i32)

^bb1:
cf.cond_br %true_cond, ^bb3, ^bb2(%a : i32)

^bb2(%x : i32):
cf.cond_br %false_cond, ^bb2(%x : i32), ^bb3

^bb3:
return
}

/// Test the compound folding of BranchOp and CondBranchOp.
// CHECK-NEXT: func.func @cond_br_and_br_folding(%a : i32) {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @cond_br_and_br_folding(%a : i32) {

%false_cond = arith.constant false
%true_cond = arith.constant true
cf.cond_br %true_cond, ^bb2, ^bb1(%a : i32)

^bb1(%x : i32):
cf.cond_br %false_cond, ^bb1(%x : i32), ^bb2

^bb2:
return
}

/// Test that pass-through successors of CondBranchOp get folded.
// CHECK: func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
// CHECK-NEXT: cf.cond_br %cond, ^[[#b0:]](%arg0, %arg1 : i32, i32), ^[[#b0]](%arg2, %arg2 : i32, i32)
// CHECK-NEXT: ^[[#b0]](%arg4 : i32, %arg5 : i32):
// CHECK-NEXT: func.return %arg4, %arg5 : i32, i32
// CHECK-NEXT: }
func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
^bb1(%arg3: i32):
cf.br ^bb2(%arg3, %arg1 : i32, i32)
^bb2(%arg4: i32, %arg5: i32):
return %arg4, %arg5 : i32, i32
}

/// Test the failure modes of collapsing CondBranchOp pass-throughs successors.

// CHECK-NEXT: func.func @cond_br_pass_through_fail(%cond : i1) {
// CHECK-NEXT: cf.cond_br %cond, ^[[#b0:]], ^[[#b1:]]
// CHECK-NEXT: ^[[#b0]]:
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: cf.br ^[[#b1]]
// CHECK-NEXT: ^[[#b1]]:
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @cond_br_pass_through_fail(%cond : i1) {
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
"test.op"() : () -> ()
cf.br ^bb2
^bb2:
return
}
13 changes: 12 additions & 1 deletion xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ def __init__(self, dest: Block, *ops: Operation | SSAValue):
assembly_format = "$successor (`(` $arguments^ `:` type($arguments) `)`)? attr-dict"


class ConditionalBranchHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.cf import (
SimplifyConstCondBranchPred,
SimplifyPassThroughCondBranch,
)

return (SimplifyConstCondBranchPred(), SimplifyPassThroughCondBranch())


@irdl_op_definition
class ConditionalBranch(IRDLOperation):
"""Conditional branch operation"""
Expand All @@ -111,7 +122,7 @@ class ConditionalBranch(IRDLOperation):
then_block = successor_def()
else_block = successor_def()

traits = frozenset([IsTerminator()])
traits = frozenset([IsTerminator(), ConditionalBranchHasCanonicalizationPatterns()])

def __init__(
self,
Expand Down
52 changes: 52 additions & 0 deletions xdsl/transforms/canonicalization_patterns/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand


class AssertTrue(RewritePattern):
Expand Down Expand Up @@ -120,3 +121,54 @@ def match_and_rewrite(self, op: cf.Branch, rewriter: PatternRewriter):
(block, args) = ret

rewriter.replace_matched_op(cf.Branch(block, *args))


class SimplifyConstCondBranchPred(RewritePattern):
"""
cf.cond_br true, ^bb1, ^bb2
-> br ^bb1
cf.cond_br false, ^bb1, ^bb2
-> br ^bb2
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter):
# Check if cond operand is constant
cond = const_evaluate_operand(op.cond)

if cond == 1:
rewriter.replace_matched_op(cf.Branch(op.then_block, *op.then_arguments))
elif cond == 0:
rewriter.replace_matched_op(cf.Branch(op.else_block, *op.else_arguments))


class SimplifyPassThroughCondBranch(RewritePattern):
"""
cf.cond_br %cond, ^bb1, ^bb2
^bb1
br ^bbN(...)
^bb2
br ^bbK(...)
-> cf.cond_br %cond, ^bbN(...), ^bbK(...)
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.ConditionalBranch, rewriter: PatternRewriter):
# Try to collapse both branches
collapsed_then = collapse_branch(op.then_block, op.then_arguments)
collapsed_else = collapse_branch(op.else_block, op.else_arguments)

# If neither collapsed then we return
if collapsed_then is None and collapsed_else is None:
return

(new_then, new_then_args) = collapsed_then or (op.then_block, op.then_arguments)

(new_else, new_else_args) = collapsed_else or (op.else_block, op.else_arguments)

rewriter.replace_matched_op(
cf.ConditionalBranch(
op.cond, new_then, new_then_args, new_else, new_else_args
)
)
14 changes: 2 additions & 12 deletions xdsl/transforms/canonicalization_patterns/scf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections.abc import Sequence

from xdsl.dialects import arith, scf
from xdsl.dialects.builtin import IntegerAttr
from xdsl.dialects import scf
from xdsl.ir import Operation, Region, SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand


class SimplifyTrivialLoops(RewritePattern):
Expand Down Expand Up @@ -77,13 +77,3 @@ def replace_op_with_region(
rewriter.inline_block(block, InsertPoint.before(op), args)
rewriter.replace_op(op, (), terminator.operands)
rewriter.erase_op(terminator)


def const_evaluate_operand(operand: SSAValue) -> int | None:
"""
Try to constant evaluate an SSA value, returning None on failure.
"""
if isinstance(op := operand.owner, arith.Constant) and isinstance(
val := op.value, IntegerAttr
):
return val.value.data
13 changes: 13 additions & 0 deletions xdsl/transforms/canonicalization_patterns/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from xdsl.dialects import arith
from xdsl.dialects.builtin import IntegerAttr
from xdsl.ir import SSAValue


def const_evaluate_operand(operand: SSAValue) -> int | None:
"""
Try to constant evaluate an SSA value, returning None on failure.
"""
if isinstance(op := operand.owner, arith.Constant) and isinstance(
val := op.value, IntegerAttr
):
return val.value.data

0 comments on commit 7f0f3e6

Please sign in to comment.