From 73df068b8e24d68f7afe776e798db12a75ba9271 Mon Sep 17 00:00:00 2001 From: tfruan <60765824+tfruan2000@users.noreply.github.com> Date: Mon, 4 Nov 2024 00:35:54 +0800 Subject: [PATCH] [BACKEND] Fix the combineSelectAndIf when the user of select in ifOp. (#5031) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CombineTensorSelectAndIf pass currently doesn’t work correctly **when a user of the select is inside the scf.if block**. For example: ```mlir %select = arith.select %cond, %trueVal, %falseVal : i32 %if = scf.if %cond -> (i32) { %sub = arith.subi %select, %val1 : i32 scf.yield %sub : i32 } else { %mul = arith.muli %select, %val2 : i32 scf.yield %mul : i32 } use %select ``` In this case, dom.dominates(ifOp, user) also returns true for the users nested inside the scf.if, so directly using replaceAllUsesWith rewrites them to consume a result of the very scf.if that contains them. ```mlir // without this PR (the users inside the ifOp use the result of the ifOp itself) %if:2 = scf.if %cond -> (i32, i32) { %sub = arith.subi %if#1, %val1 : i32 scf.yield %sub, %trueVal : i32, i32 } else { %mul = arith.muli %if#1, %val2 : i32 scf.yield %mul, %falseVal : i32, i32 } use %if#1 ``` To address this, we first rewrite each such user’s operand according to the region it is in: the select’s true value in the then-region and its false value in the else-region. 
```mlir // with this PR (the users inside the ifOp are canonicalized first) %if:2 = scf.if %cond -> (i32, i32) { %sub = arith.subi %trueVal, %val1 : i32 scf.yield %sub, %trueVal : i32, i32 } else { %mul = arith.muli %falseVal, %val2 : i32 scf.yield %mul, %falseVal : i32, i32 } use %if#1 ``` --- .../Transforms/CombineTensorSelectAndIf.cpp | 52 +++++++++++- test/TritonGPU/combine-select-if.mlir | 85 +++++++++++++------ 2 files changed, 107 insertions(+), 30 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp index 16183b1af4..203fe01ba6 100644 --- a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp +++ b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -1,4 +1,5 @@ #include "mlir/IR/Dominance.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -14,8 +15,52 @@ namespace gpu { #define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" -// Return true if the select could be merged into the If without breaking SSA -// rules. +/// The user of select may be inside either the ThenRegion or ElseRegion of +/// the scf.if. So, canonicalize user of select in scf.if first. 
+static void canonicalizeSelectUsersInSCFIf(ModuleOp input) {
+  llvm::MapVector<std::pair<Value, Value>, SmallVector<Operation *>>
+      usersNeedreplaced;
+  input.walk([&](arith::SelectOp selectOp) {
+    auto *parentBlock = selectOp->getBlock(); // NOTE(review): unused
+    Value condition = selectOp.getOperand(0);
+    Value trueVal = selectOp.getOperand(1);
+    Value falseVal = selectOp.getOperand(2);
+    Value resVal = selectOp.getResult();
+    for (auto *condUser : condition.getUsers()) {
+      if (!llvm::isa<scf::IfOp>(condUser))
+        continue;
+      scf::IfOp ifOp = llvm::cast<scf::IfOp>(condUser);
+      for (auto *resUser : resVal.getUsers()) {
+        if (ifOp->isProperAncestor(resUser)) {
+          if (ifOp.getThenRegion().findAncestorOpInRegion(*resUser) !=
+              nullptr) {
+            // User sits in the then-region, where the select yields trueVal.
+            usersNeedreplaced[std::make_pair(resVal, trueVal)].push_back(
+                resUser);
+          } else {
+            // User sits in the else-region, where the select yields falseVal.
+            usersNeedreplaced[std::make_pair(resVal, falseVal)].push_back(
+                resUser);
+          }
+        }
+      }
+    }
+  });
+
+  // Rewrite only the recorded users; other uses keep reading the select.
+  for (auto [replacedSrcAndDst, users] :
+       llvm::make_early_inc_range(usersNeedreplaced)) {
+    Value srcVal = replacedSrcAndDst.first;
+    Value dstVal = replacedSrcAndDst.second;
+    for (Operation *user : llvm::make_early_inc_range(users)) {
+      srcVal.replaceUsesWithIf(
+          dstVal, [&](OpOperand &use) { return use.getOwner() == user; });
+    }
+  }
+}
+
+/// Return true if the select could be merged into the If without breaking SSA
+/// rules.
 static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp,
                            DominanceInfo &dom) {
   // If needs to be dominated by the select.
@@ -38,10 +83,11 @@ class CombineTensorSelectAndIfPass
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
-    DominanceInfo dom(m);
+    canonicalizeSelectUsersInSCFIf(m);
 
     // Go over the arith.select ops, look if there is an if
     // with the same condition.
+ DominanceInfo dom(m); llvm::MapVector<scf::IfOp, SmallVector<arith::SelectOp>> selectToIf; m.walk([&](arith::SelectOp selectOp) { // Look if there is an if in the same block, with the same condition. diff --git a/test/TritonGPU/combine-select-if.mlir b/test/TritonGPU/combine-select-if.mlir index 62a9474dcb..f00b971235 100644 --- a/test/TritonGPU/combine-select-if.mlir +++ b/test/TritonGPU/combine-select-if.mlir @@ -1,46 +1,77 @@ // RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s -#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @select_if_combine - tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr<f32>, #blocked>, %cnd: i1) attributes {noinline = false} { - // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> - %cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked> - // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> - %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked> - // CHECK-NOT: arith.select - %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked> - // CHECK: %[[IF_RES:.*]] = scf.if - scf.if %cnd { - tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>, #blocked> - // CHECK: scf.yield %[[CST0]] - } - // CHECK: else - // CHECK: scf.yield %[[CST1]] - // CHECK: tt.store %{{.*}}, %[[IF_RES]] - tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>, #blocked> - tt.return +tt.func public @select_if_combine(%arg0: tensor<64xf32>, %dst_ptr: tensor<64x!tt.ptr<f32>>, %cnd: i1) { + // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> + %cst = arith.constant dense<0.000000e+00> : tensor<64xf32> + // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32> + // CHECK-NOT: arith.select + 
%sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32> + // CHECK: %[[R:.+]] = scf.if %{{.*}} + // CHECK: tt.store %{{.*}}, %{{.*}} + // CHECK: scf.yield %[[CST0]] + // CHECK: } else { + // CHECK: scf.yield %[[CST1]] + // CHECK: } + scf.if %cnd { + tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>> } + // CHECK: tt.store %{{.*}}, %[[R]] + tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>> + tt.return } // ----- - // CHECK-LABEL: @if_multiple_sel tt.func @if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){ -// CHECK-NOT: select -// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { -// CHECK: scf.yield {{.*}} : i32, i32, f32 -// CHECK: } else { -// CHECK: scf.yield {{.*}} : i32, i32, f32 -// CHECK: } -// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 + // CHECK-NOT: arith.select %0 = arith.select %arg0, %arg1, %arg2 : i32 %1 = arith.select %arg0, %arg3, %arg4 : f32 + // CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) { + // CHECK: scf.yield {{.*}} : i32, i32, f32 + // CHECK: } else { + // CHECK: scf.yield {{.*}} : i32, i32, f32 + // CHECK: } %2 = scf.if %arg0 -> (i32) { %3 = arith.subi %arg1, %arg2 : i32 scf.yield %3 : i32 } else { scf.yield %arg1 : i32 } + // CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32 tt.return %0, %1, %2 : i32, f32, i32 } + +// ----- +// CHECK-LABEL: tt.func @users_in_if( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i1 +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32 +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: i32 +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: f32 +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: f32 +tt.func @users_in_if(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32, i32) { + // CHECK: %[[CST:.*]] = arith.constant 8 : i32 + %c8_i32 = arith.constant 8 : i32 + // CHECK-NOT: arith.select + %0 = arith.select %arg0, %arg1, %arg2 : i32 + %1 = arith.select %arg0, %arg3, %arg4 : f32 + // CHECK: %[[R:.+]]:4 = scf.if %[[ARG0]] -> (i32, i32, i32, f32) { + 
// CHECK: %[[MULI:.*]] = arith.muli %[[ARG1]], %[[ARG2]] : i32 + // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[CST]] : i32 + // CHECK: scf.yield %[[MULI]], %[[ADDI]], %[[ARG1]], %[[ARG3]] : i32, i32, i32, f32 + // CHECK: } else { + // CHECK: %[[ADDI:.*]] = arith.subi %[[ARG2]], %[[CST]] : i32 + // CHECK: scf.yield %[[ARG1]], %[[ADDI]], %[[ARG2]], %[[ARG4]] : i32, i32, i32, f32 + // CHECK: } + %2:2 = scf.if %arg0 -> (i32, i32) { + %3 = arith.muli %0, %arg2 : i32 + %4 = arith.addi %0, %c8_i32 : i32 + scf.yield %3, %4 : i32, i32 + } else { + %3 = arith.subi %0, %c8_i32 : i32 + scf.yield %arg1, %3 : i32, i32 + } + // CHECK: tt.return %[[R]]#2, %[[R]]#3, %[[R]]#0, %[[R]]#1 : i32, f32, i32, i32 + tt.return %0, %1, %2#0, %2#1 : i32, f32, i32, i32 +}