Skip to content

Commit

Permalink
[BACKEND] Fix the combineSelectAndIf when the user of select in ifOp.…
Browse files Browse the repository at this point in the history
… (#5031)

The CombineTensorSelectAndIf pass currently doesn’t work correctly
**when the user of select is inside the scf.if block**.

For example:

```mlir
%select = arith.select %cond, %trueVal, %falseVal : i32
%if = scf.if %cond -> (i32) {
  %sub = arith.subi %select, %val1 : i32
  scf.yield %sub : i32
} else {
  %mul = arith.muli %select, %val2 : i32
  scf.yield %mul : i32
}
use %select
```

In this case, dom.dominates(ifOp, user) will return true, but directly
using replaceAllUsesWith would lead to incorrect replacement behavior.

```mlir
// without this pr (the user in ifOp use the result of ifOp) 
%if:2 = scf.if %cond -> (i32, i32) {
  %sub = arith.subi %if#1, %val1 : i32
  scf.yield %sub, %trueVal : i32, i32
} else {
  %mul = arith.muli %if#1, %val2 : i32
  scf.yield %mul, %falseVal : i32, i32
}
use %if#1
```


To address this, we need to adjust the user’s operand based on the
specific region it is in.

```mlir
// with this pr (the user in ifOp be canonicaled first)
%if:2 = scf.if %cond -> (i32, i32) {
  %sub = arith.subi %trueVal, %val1 : i32
  scf.yield %sub, %trueVal : i32, i32
} else {
  %mul = arith.muli %falseVal, %val2 : i32
  scf.yield %mul, %falseVal : i32, i32
}
use %if#1
```
  • Loading branch information
tfruan2000 authored Nov 3, 2024
1 parent 0b443ce commit 73df068
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 30 deletions.
52 changes: 49 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
Expand All @@ -14,8 +15,52 @@ namespace gpu {
#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

// Return true if the select could be merged into the If without breaking SSA
// rules.
/// The user of select maybe inside either the ThenRegion or ElseRegion of
/// the scf.if. So, canonicalize user of select in scf.if first.
static void canonicalizeSelectUsersInSCFIf(ModuleOp input) {
llvm::MapVector<std::pair<Value, Value>, SmallVector<Operation *>>
usersNeedreplaced;
input.walk([&](arith::SelectOp selectOp) {
auto *parentBlock = selectOp->getBlock();
Value condition = selectOp.getOperand(0);
Value trueVal = selectOp.getOperand(1);
Value falseVal = selectOp.getOperand(2);
Value resVal = selectOp.getResult();
for (auto *condUser : condition.getUsers()) {
if (!llvm::isa<scf::IfOp>(condUser))
continue;
scf::IfOp ifOp = llvm::cast<scf::IfOp>(condUser);
for (auto *resUser : resVal.getUsers()) {
if (ifOp->isProperAncestor(resUser)) {
if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) !=
nullptr) {
// The user is inside the ThenRegion of the scf.if.
usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back(
resUser);
} else {
// The user is inside the ElseRegion of the scf.if.
usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back(
resUser);
}
}
}
}
});

// Replace the operand of user.
for (auto [replacedSrcAndDst, users] :
llvm::make_early_inc_range(usersNeedreplaced)) {
Value srcVal = replacedSrcAndDst.first;
Value dstVal = replacedSrcAndDst.second;
for (Operation *user : llvm::make_early_inc_range(users)) {
srcVal.replaceUsesWithIf(
dstVal, [&](OpOperand &use) { return use.getOwner() == user; });
}
}
}

/// Return true if the select could be merged into the If without breaking SSA
/// rules.
static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp,
DominanceInfo &dom) {
// If needs to be dominated by the select.
Expand All @@ -38,10 +83,11 @@ class CombineTensorSelectAndIfPass
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
DominanceInfo dom(m);
canonicalizeSelectUsersInSCFIf(m);

// Go over the arith.select ops, look if there is an if
// with the same condition.
DominanceInfo dom(m);
llvm::MapVector<scf::IfOp, SmallVector<arith::SelectOp>> selectToIf;
m.walk([&](arith::SelectOp selectOp) {
// Look if there is an if in the same block, with the same condition.
Expand Down
85 changes: 58 additions & 27 deletions test/TritonGPU/combine-select-if.mlir
Original file line number Diff line number Diff line change
@@ -1,46 +1,77 @@
// RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @select_if_combine
tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr<f32>, #blocked>, %cnd: i1) attributes {noinline = false} {
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00>
%cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked>
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00>
%cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked>
// CHECK-NOT: arith.select
%sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked>
// CHECK: %[[IF_RES:.*]] = scf.if
scf.if %cnd {
tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>, #blocked>
// CHECK: scf.yield %[[CST0]]
}
// CHECK: else
// CHECK: scf.yield %[[CST1]]
// CHECK: tt.store %{{.*}}, %[[IF_RES]]
tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>, #blocked>
tt.return
tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr<f32>>, %cnd: i1) {
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00>
%cst = arith.constant dense<0.000000e+00> : tensor<64xf32>
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00>
%cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32>
// CHECK-NOT: arith.select
%sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32>
// CHECK: %[[R:.+]] = scf.if %{{.*}}
// CHECK: tt.store %{{.*}}, %{{.*}}
// CHECK: scf.yield %[[CST0]]
// CHECK: } else {
// CHECK: scf.yield %[[CST1]]
// CHECK: }
scf.if %cnd {
tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>>
}
// CHECK: tt.store %{{.*}}, %[[R]]
tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>>
tt.return
}

// -----

// CHECK-LABEL: @if_multiple_sel
tt.func @if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){
// CHECK-NOT: select
// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) {
// CHECK: scf.yield {{.*}} : i32, i32, f32
// CHECK: } else {
// CHECK: scf.yield {{.*}} : i32, i32, f32
// CHECK: }
// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32
// CHECK-NOT: arith.select
%0 = arith.select %arg0, %arg1, %arg2 : i32
%1 = arith.select %arg0, %arg3, %arg4 : f32
// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) {
// CHECK: scf.yield {{.*}} : i32, i32, f32
// CHECK: } else {
// CHECK: scf.yield {{.*}} : i32, i32, f32
// CHECK: }
%2 = scf.if %arg0 -> (i32) {
%3 = arith.subi %arg1, %arg2 : i32
scf.yield %3 : i32
} else {
scf.yield %arg1 : i32
}
// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32
tt.return %0, %1, %2 : i32, f32, i32
}

// -----
// CHECK-LABEL: tt.func @users_in_if(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: i32
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: f32
tt.func @users_in_if(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32, i32) {
// CHECK: %[[CST:.*]] = arith.constant 8 : i32
%c8_i32 = arith.constant 8 : i32
// CHECK-NOT: arith.select
%0 = arith.select %arg0, %arg1, %arg2 : i32
%1 = arith.select %arg0, %arg3, %arg4 : f32
// CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (i32, i32, i32, f32) {
// CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : i32
// CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : i32
// CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : i32, i32, i32, f32
// CHECK: } else {
// CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : i32
// CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : i32, i32, i32, f32
// CHECK: }
%2:2 = scf.if %arg0 -> (i32, i32) {
%3 = arith.muli %0, %arg2 : i32
%4 = arith.addi %0, %c8_i32 : i32
scf.yield %3, %4 : i32, i32
} else {
%3 = arith.subi %0, %c8_i32 : i32
scf.yield %arg1, %3 : i32, i32
}
// CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : i32, f32, i32, i32
tt.return %0, %1, %2#0, %2#1 : i32, f32, i32, i32
}

0 comments on commit 73df068

Please sign in to comment.