Skip to content

Commit

Permalink
[TritonIntelGPUToLLVM] Detect more sub-group shuffle convert_layout (
Browse files Browse the repository at this point in the history
…#2573)

Detect sub-group shuffle `convert_layout` cases of more than one element
per thread.

---------

Signed-off-by: victor-eds <[email protected]>
  • Loading branch information
victor-eds authored Oct 30, 2024
1 parent e438919 commit 5d9774c
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 60 deletions.
103 changes: 103 additions & 0 deletions test/Conversion/intel/sub-group-shuffle.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,106 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return %0 : tensor<32xf32, #sliced1>
}
}

// -----

// Case of more than one element per thread in the non-sliced dimension.

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>

// The 32xf64 tensor is carried as two f64 values per thread (the
// !llvm.struct<(f64, f64)> argument below). The layout conversion is lowered
// to sub-group shuffles: 16 index-shuffles for the first struct element
// (%[[VAL_2]]) followed by 16 for the second (%[[VAL_3]]), i.e. one shuffle
// per lane index 0..15 for each element.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>,
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_5]])
// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_8]])
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_11]])
// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_14]])
// CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_17]])
// CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_20]])
// CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(6 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_23]])
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(7 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_26]])
// CHECK: %[[VAL_29:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_29]])
// CHECK: %[[VAL_32:.*]] = llvm.mlir.constant(9 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_32]])
// CHECK: %[[VAL_35:.*]] = llvm.mlir.constant(10 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_35]])
// CHECK: %[[VAL_38:.*]] = llvm.mlir.constant(11 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_38]])
// CHECK: %[[VAL_41:.*]] = llvm.mlir.constant(12 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_41]])
// CHECK: %[[VAL_44:.*]] = llvm.mlir.constant(13 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_44]])
// CHECK: %[[VAL_47:.*]] = llvm.mlir.constant(14 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_47]])
// CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_50]])
// Second element of the struct: same 16 shuffles, now on %[[VAL_3]].
// CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_53]])
// CHECK: %[[VAL_56:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_56]])
// CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_59]])
// CHECK: %[[VAL_62:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_62]])
// CHECK: %[[VAL_65:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_65]])
// CHECK: %[[VAL_68:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_68]])
// CHECK: %[[VAL_71:.*]] = llvm.mlir.constant(6 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_71]])
// CHECK: %[[VAL_74:.*]] = llvm.mlir.constant(7 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_74]])
// CHECK: %[[VAL_77:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_77]])
// CHECK: %[[VAL_80:.*]] = llvm.mlir.constant(9 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_80]])
// CHECK: %[[VAL_83:.*]] = llvm.mlir.constant(10 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_83]])
// CHECK: %[[VAL_86:.*]] = llvm.mlir.constant(11 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_86]])
// CHECK: %[[VAL_89:.*]] = llvm.mlir.constant(12 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_89]])
// CHECK: %[[VAL_92:.*]] = llvm.mlir.constant(13 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_92]])
// CHECK: %[[VAL_95:.*]] = llvm.mlir.constant(14 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_95]])
// CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_98]])
// Conversion between the two sliced layouts; must lower to shuffles only.
tt.func @test_non_sliced_multi_register(%arg0: tensor<32xf64, #sliced>) -> tensor<32xf64, #sliced1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<32xf64, #sliced> -> tensor<32xf64, #sliced1>
tt.return %0 : tensor<32xf64, #sliced1>
}
}

// -----

// Case of more than one element per thread and 2 warps in the non-sliced dimension.

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>

// Same conversion as above but with 2 warps and i32 elements; we only count
// the shuffle calls (64 per warp) instead of pinning each constant.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register_multi_warp
// CHECK-COUNT-64: llvm.call spir_funccc @_Z17sub_group_shuffleij
tt.func @test_non_sliced_multi_register_multi_warp(%arg0: tensor<128xi32, #sliced>) -> tensor<128xi32, #sliced1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<128xi32, #sliced> -> tensor<128xi32, #sliced1>
tt.return %0 : tensor<128xi32, #sliced1>
}
}
151 changes: 91 additions & 60 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

// Build the expected register-input bases of a sub-group transpose layout
// conversion:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
// [registerSize / 2, 0]],
// i.e., register bits below laneSize map onto lane bits, and register bits
// from laneSize upwards map onto themselves (identity).
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Register bits [1, laneSize) are routed to the lane dimension.
  for (int32_t bit = 1; bit < laneSize; bit *= 2)
    result.push_back({0, bit});
  // Register bits [laneSize, registerSize) stay in the register dimension.
  for (int32_t bit = laneSize; bit < registerSize; bit *= 2)
    result.push_back({bit, 0});
  return result;
}

// Build the expected register-input bases of a sub-group shuffle layout
// conversion:
// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
// [registerSize / (2 * laneSize), 0]]
// i.e., register bits below laneSize map onto lane bits, and the remaining
// register bits restart the register pattern from 1.
static std::vector<std::vector<int32_t>>
buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  // Register bits [1, laneSize) are routed to the lane dimension.
  for (int32_t bit = 1; bit < laneSize; bit *= 2)
    result.push_back({0, bit});
  // Remaining register bits map to registers 1, 2, ...,
  // registerSize / (2 * laneSize).
  for (int32_t reg = 1; reg * laneSize < registerSize; reg *= 2)
    result.push_back({reg, 0});
  return result;
}

// Build the expected lane-input bases of a sub-group transpose layout
// conversion:
// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
// i.e., every lane bit maps onto the corresponding register bit.
static std::vector<std::vector<int32_t>>
buildSubGroupTransposeLaneBases(int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  for (int32_t bit = 1; bit < laneSize; bit *= 2)
    result.push_back({bit, 0});
  return result;
}

bool isSubGroupTranspose(const LinearLayout &srcLayout,
const LinearLayout &dstLayout) const {
MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext();
Expand Down Expand Up @@ -476,35 +532,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
//
// With N >= M.
const auto buildBasis = [&](int32_t size, std::size_t index) {
std::vector<std::vector<int32_t>> basis;
std::vector<int32_t> curr(2);
for (int32_t i = 1; i < size; i *= 2) {
curr[index] = i;
basis.push_back(curr);
}
return basis;
};
constexpr std::size_t laneIndex = 0;
constexpr std::size_t registerIndex = 1;
int32_t laneSize = conversion->getInDimSize(kLane);
std::vector<std::vector<int32_t>> registerBases =
buildBasis(laneSize, registerIndex);
{
// Populate register bases for N > M.
std::vector<int32_t> base(2);
for (int32_t i = laneSize,
registerSize = conversion->getInDimSize(kRegister);
i < registerSize; i *= 2) {
base[laneIndex] = i;
registerBases.push_back(base);
}
}
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
bases{{{kRegister, std::move(registerBases)},
{kLane, buildBasis(laneSize, laneIndex)}}};
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
return conversion == LinearLayout(bases, outDimNames);
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
int32_t laneInDimSize = conversion->getInDimSize(kLane);
return conversion->getBases().lookup(kRegister) ==
buildSubGroupTransposeRegisterBases(registerInDimSize,
laneInDimSize) &&
conversion->getBases().lookup(kLane) ==
buildSubGroupTransposeLaneBases(laneInDimSize);
}

LogicalResult
Expand Down Expand Up @@ -619,32 +653,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// Expected conversion is:
// - register=1 -> (0, 1)
// ...
// register=i -> (0, i)
// - register=2**i -> (0, 2**i)
// ...
// - register=M -> (0, 2**M)
// ...
// - register=2**k -> (2**(k-M), 0)
// ...
// register=N -> (0, N)
// - register=2**N -> (2**(N-M), 0)
// - lane=1 -> (0, 0)
// ...
// lane=i -> (0, 0)
// - lane=2**j -> (0, 0)
// ...
// lane=N -> (0, 0)
// where out dims are: [register (size 1), lane (size N)]
std::vector<std::vector<int32_t>> registerBases;
{
constexpr std::size_t registerIndex = 1;
std::vector<int32_t> base(2);
for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) {
base[registerIndex] = i;
registerBases.push_back(base);
}
}

std::vector<std::vector<int32_t>> laneBases(
conversion->getInDimSizeLog2(kLane), std::vector<int32_t>{0, 0});
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
bases{{{kRegister, std::move(registerBases)},
{kLane, std::move(laneBases)}}};
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
return conversion == LinearLayout(bases, outDimNames);
// lane=2**M -> (0, 0)
// where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
//
// With N >= M.
int32_t registerInDimSize = conversion->getInDimSize(kRegister);
int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
conversion->getBases().lookup(kRegister) ==
buildSubGroupShuffleRegisterBases(registerInDimSize,
laneOutDimSize);
}

bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const {
Expand Down Expand Up @@ -674,7 +703,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

SmallVector<Value> inVals =
unpackLLElements(loc, adaptor.getSrc(), rewriter);
assert(inVals.size() == 1 && "Expecting single element");

// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
// upstream level. We are not enabling support for all types here as that
Expand Down Expand Up @@ -703,7 +731,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
});

SmallVector<Value> outVals =
performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter);
performSubGroupShuffle(loc, inVals, subGroupSize, rewriter);

// TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
// upstream level. We are not enabling support for all types here as that
Expand Down Expand Up @@ -734,16 +762,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}

SmallVector<Value>
performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize,
performSubGroupShuffle(Location loc, ArrayRef<Value> inVals,
int32_t subGroupSize,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> res;
Value width = i32_val(subGroupSize);
for (int32_t i = 0; i < subGroupSize; ++i)
res.push_back(
rewriter
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
mlir::gpu::ShuffleMode::IDX)
.getShuffleResult());
for (Value val : inVals) {
for (int32_t i = 0; i < subGroupSize; ++i)
res.push_back(
rewriter
.create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
mlir::gpu::ShuffleMode::IDX)
.getShuffleResult());
}
return res;
}

Expand Down

0 comments on commit 5d9774c

Please sign in to comment.