Skip to content

Commit

Permalink
Decompose large simdblockwrite to smaller simdblockwrites
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <[email protected]>
  • Loading branch information
whitneywhtsang committed Sep 13, 2024
1 parent 71afe7a commit 128861d
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 52 deletions.
34 changes: 31 additions & 3 deletions test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockWrite(!llvm.ptr<3>, vector<64xi16>)
// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t(!llvm.ptr<3>, vector<8xi16>)
// CHECK-LABEL: @slm_store
tt.func public @slm_store(%arg0: !tt.ptr<f16, 3>, %arg1: tensor<16x64xf16, #dot0>) {
%c0_i32 = arith.constant 0 : i32
Expand All @@ -301,7 +301,30 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
%c64_i64 = arith.constant 64 : i64
%ptr = tt.make_tensor_ptr %arg0, [%c0_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #dot0>, 3>
// CHECK: [[CAST:%.*]] = llvm.bitcast {{.*}} : vector<64xf16> to vector<64xi16>
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite({{.*}}, [[CAST]]) {{.*}} : (!llvm.ptr<3>, vector<64xi16>) -> ()
// CHECK: [[SUBGROUP_SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE1:%.*]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE2]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE1]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE2]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [32, 33, 34, 35, 36, 37, 38, 39] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE1]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [40, 41, 42, 43, 44, 45, 46, 47] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE2]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [48, 49, 50, 51, 52, 53, 54, 55] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE1]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector [[CAST]], [[CAST]] [56, 57, 58, 59, 60, 61, 62, 63] : vector<64xi16>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us8PU3AS3tDv8_t([[BASE2]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi16>) -> ()
tt.store %ptr, %arg1 {DotIdx = 0 : i32} : !tt.ptr<tensor<16x64xf16, #dot0>, 3>
tt.return
}
Expand All @@ -322,7 +345,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_3]] : i64
// CHECK: %[[VAL_9:.*]] = llvm.getelementptr inbounds %[[VAL_0]]{{\[}}%[[VAL_8]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_1]] : vector<16xf32> to vector<16xi32>
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite(%[[VAL_9]], %[[VAL_10]]) {{{.*}}} : (!llvm.ptr<3>, vector<16xi32>) -> ()
// CHECK: [[SUBGROUP_SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_9]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: [[BASE:%.*]] = llvm.getelementptr %[[VAL_9]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi32>
// CHECK: [[EXTRACT:%.*]] = llvm.shufflevector %[[VAL_10]], %[[VAL_10]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xi32>
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j([[BASE]], [[EXTRACT]]) {{.*}} : (!llvm.ptr<3>, vector<8xi32>) -> ()
// CHECK: %[[VAL_11:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i64
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr inbounds %[[VAL_9]]{{\[}}%[[VAL_11]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] : !llvm.ptr<3> -> vector<16xf32>
Expand Down
8 changes: 8 additions & 0 deletions test/TritonGEN/tritongen-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,11 @@ llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
llvm.return
}

// -----

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val: vector<64xi16>) {
// A 64-element vector is rejected by the op verifier: only 1/2/4/8 elements
// (or 16 for 8-bit element types) are supported.
// expected-error @+1 {{'triton_gen.simdblockwrite' op unsupported vector type}}
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
llvm.return
}
11 changes: 0 additions & 11 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -289,17 +289,6 @@ llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {

// -----

// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockWrite(!llvm.ptr<3>, vector<64xi16>) attributes {passthrough = ["nounwind", "willreturn", ["memory", "3"]]}

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<64xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<64xi16>) {
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<64xi16>) -> ()
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {passthrough = ["nounwind", "willreturn", ["memory", "3"]]}

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
Expand Down
8 changes: 4 additions & 4 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) {
llvm.return
}

llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<64xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<64xi16>) {
// CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<64xi16>)
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<64xi16>)
llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<2xi16>) {
// Round-trip parse/print test using a 2-element vector, a width the op
// verifier accepts.
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<2xi16>) {
// CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<2xi16>)
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<2xi16>)
llvm.return
}
Original file line number Diff line number Diff line change
Expand Up @@ -435,5 +435,7 @@ def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
let assemblyFormat = [{
operands ` ` attr-dict `:` `(` type(operands) `)`
}];

let hasVerifier = 1;
}
#endif // TRITONGEN_OPS
28 changes: 20 additions & 8 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
return success();
}

/// Shared verifier for SIMD block read/write ops: the value must be a vector
/// of 1, 2, 4 or 8 integer elements; 16 elements are additionally accepted
/// for 8-bit element types.
// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
template <typename Op>
static LogicalResult verifySIMDBlockTy(Op op, VectorType vecTy) {
  unsigned numElems = vecTy.getNumElements();
  IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());

  bool isSupportedWidth =
      numElems == 1 || numElems == 2 || numElems == 4 || numElems == 8;
  bool isByteVec16 = elemTy.getWidth() == 8 && numElems == 16;
  if (isSupportedWidth || isByteVec16)
    return success();

  return op->emitOpError("unsupported vector type");
}

//===----------------------------------------------------------------------===//
// gen.sub_group_reduce
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -441,14 +454,13 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
VectorType vecTy = getRes().getType();
unsigned numElems = vecTy.getNumElements();
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());
return verifySIMDBlockTy(this, getRes().getType());
}

// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
(elemTy.getWidth() != 8 || numElems != 16))
return emitOpError("unsupported vector type");
//===----------------------------------------------------------------------===//
// gen.simdblockwrite
//===----------------------------------------------------------------------===//

return success();
LogicalResult TritonGEN::SIMDBlockWriteOp::verify() {
return verifySIMDBlockTy(this, getVal().getType());
}
18 changes: 1 addition & 17 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1231,19 +1231,6 @@ struct TritonMatrix2DBlockPrefetchLowering
}
};

static bool isTySIMDOCLBuiltinAvailable(VectorType vecTy) {
unsigned numElems = vecTy.getNumElements();
if (numElems == 1 || numElems == 2 || numElems == 4 || numElems == 8)
return true;

// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());
if (elemTy.getWidth() == 8 && numElems == 16)
return true;

return false;
}

template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SIMDBlockReadOp,
TritonGEN::SIMDBlockWriteOp>::value>>
Expand Down Expand Up @@ -1307,10 +1294,7 @@ struct TritonSIMDBlockWriteLowering
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getVal().getType();

// TODO: Remove GenISA lowering after PoC productization is completed.
std::string funcName = "llvm.genx.GenISA.simdBlockWrite";
if (isTySIMDOCLBuiltinAvailable(vecTy))
funcName = getSIMDBlockManglingName(op, vecTy);
std::string funcName = getSIMDBlockManglingName(op, vecTy);

intel::AttributeList attrs = createFunctionAttributes(
{{llvm::Attribute::NoUnwind, std::nullopt},
Expand Down
34 changes: 25 additions & 9 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ VectorType getVectorType(RankedTensorType tensorType, Type elemType) {
return vec_ty(elemType, num);
};

/// Split a wide per-lane vector store into a sequence of SIMD block writes of
/// at most 8 elements each, advancing the base pointer between writes.
static void decomposeBlockStore(ConversionPatternRewriter &rewriter,
                                Location loc, Value base, Value val,
                                VectorType vecTy, unsigned subGroupSize) {
  constexpr unsigned maxBlockStoreWidth = 8;
  VectorType chunkVecTy =
      VectorType::get(maxBlockStoreWidth, vecTy.getElementType());
  // The GEP element type is the 8-wide chunk vector, so stepping by
  // `subGroupSize` skips past the region just written by the whole subgroup.
  Value baseStep = i32_val(subGroupSize);
  int numChunks = vecTy.getNumElements() / maxBlockStoreWidth;
  for (int chunkIdx = 0; chunkIdx < numChunks; ++chunkIdx) {
    Value chunk =
        rewriter.create<ExtractOp>(loc, chunkVecTy, val, chunkIdx).getRes();
    rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, base, chunk);
    base = gep(base.getType(), chunkVecTy, base, baseStep);
  }
}

/// v2i32 [offsetX, offsetY] for 2D tensor desc.
class MakeTensorPtrOpConversion
: public ConvertTritonGPUOpToLLVMPattern<MakeTensorPtrOp> {
Expand Down Expand Up @@ -318,8 +333,10 @@ class LoadStorePrefetchOpConversion
}
val = bitcast(val, v64i16Ty);

TritonGEN::SIMDBlockWriteOp simdWrite =
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, base, val);
auto mod = op->template getParentOfType<mlir::ModuleOp>();
decomposeBlockStore(
rewriter, loc, base, val, v64i16Ty,
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));

rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -798,14 +815,13 @@ class SubGroupTransposeOpConversion
ValueRange{subGroupOffset}, /*inbounds=*/true);

// Store matrix in local memory.
VectorType intVecTy =
vec_ty(int_ty(vecTy.getElementType().getIntOrFloatBitWidth()),
vecTy.getNumElements());
Value val =
vecTy.getElementType().isInteger()
? src
: bitcast(
src,
vec_ty(int_ty(vecTy.getElementType().getIntOrFloatBitWidth()),
vecTy.getNumElements()));
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, subGroupBasePtr, val);
vecTy.getElementType().isInteger() ? src : bitcast(src, intVecTy);
decomposeBlockStore(rewriter, loc, subGroupBasePtr, val, intVecTy,
threadsPerWarp);

// Load from matrix, trasposed.
Value workItemOffset = mul(wiStride, subGroupLocalId);
Expand Down

0 comments on commit 128861d

Please sign in to comment.