Decompose large simdblockread to smaller simdblockreads #2193

Merged: 9 commits, Sep 13, 2024
27 changes: 25 additions & 2 deletions test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
@@ -248,15 +248,38 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
-// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockRead(!llvm.ptr<3>) -> vector<64xi16>
+// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t(!llvm.ptr<3>) -> vector<8xi16>
// CHECK-LABEL: @slm_load
tt.func public @slm_load(%arg0: !tt.ptr<f16, 3>) {
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%c64_i64 = arith.constant 64 : i64
%ptr = tt.make_tensor_ptr %arg0, [%c0_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #dot0>, 3>
-// CHECK: {{.*}} = llvm.call spir_funccc @llvm.genx.GenISA.simdBlockRead({{.*}}) {{.*}} : (!llvm.ptr<3>) -> vector<64xi16>
+// CHECK-COUNT-2: [[SUBGROUP_SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: [[READ1:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16>
+// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
+// CHECK: [[READ2:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
+// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
+// CHECK: [[READ3:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16>
+// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
+// CHECK: [[READ4:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
+// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
+// CHECK: [[READ5:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16>
+// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
+// CHECK: [[READ6:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
+// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
+// CHECK: [[READ7:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16>
+// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
+// CHECK: [[READ8:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
+// CHECK: [[GLUE1:%.*]] = llvm.shufflevector [[READ1]], [[READ2]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
+// CHECK: [[GLUE2:%.*]] = llvm.shufflevector [[READ3]], [[READ4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
+// CHECK: [[GLUE3:%.*]] = llvm.shufflevector [[READ5]], [[READ6]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
+// CHECK: [[GLUE4:%.*]] = llvm.shufflevector [[READ7]], [[READ8]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
+// CHECK: [[GLUE5:%.*]] = llvm.shufflevector [[GLUE1]], [[GLUE2]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xi16>
+// CHECK: [[GLUE6:%.*]] = llvm.shufflevector [[GLUE3]], [[GLUE4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xi16>
+// CHECK: [[READ:%.*]] = llvm.shufflevector [[GLUE5]], [[GLUE6]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xi16>
+// CHECK: llvm.bitcast [[READ]] : vector<64xi16> to vector<64xf16>
%ld = tt.load %ptr {DotIdx = 0 : i32} : !tt.ptr<tensor<16x64xf16, #dot0>, 3>
tt.return
}
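Note: the CHECK sequence above encodes a pairwise glue tree. Eight vector<8xi16> reads are concatenated in pairs into vector<16xi16>, then vector<32xi16>, then the final vector<64xi16>; the mask for concatenating two N-element vectors is always [0 .. 2N-1]. A minimal standalone C++ sketch of that width/mask progression (illustrative only, not code from this PR):

```cpp
#include <cstdio>

int main() {
  // Gluing 8 reads of width 8 takes three rounds: 8 -> 16 -> 32 -> 64,
  // matching the GLUE1..GLUE6 and final READ shufflevector CHECK lines.
  for (int width = 8; width < 64; width *= 2) {
    int pairs = 64 / (2 * width); // concatenations performed this round
    std::printf("%d glue(s) of two vector<%dxi16> with mask [0..%d]\n",
                pairs, width, 2 * width - 1);
  }
  return 0;
}
```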
8 changes: 8 additions & 0 deletions test/TritonGEN/tritongen-invalid.mlir
@@ -428,3 +428,11 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
llvm.return
}

+// -----
+
+llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
+  // expected-error @+1 {{'triton_gen.simdblockread' op unsupported vector type}}
+  %ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
+  llvm.return
+}
11 changes: 0 additions & 11 deletions test/TritonGEN/tritongen-to-llvm.mlir
@@ -278,17 +278,6 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b

// -----

-// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockRead(!llvm.ptr<3>) -> vector<64xi16> attributes {passthrough = ["nounwind", "willreturn", ["memory", "1"]]}
-
-llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
-  // CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) {
-  // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockRead(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<64xi16>
-  %ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
-  llvm.return
-}
-
-// -----

// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {passthrough = ["nounwind", "willreturn", ["memory", "1"]]}

llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
4 changes: 2 additions & 2 deletions test/TritonGEN/tritongen.mlir
@@ -156,8 +156,8 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base

llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) {
-  // CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<64xi16>
-  triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<64xi16>
+  // CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<2xi16>
+  triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<2xi16>
llvm.return
}

@@ -413,6 +413,8 @@ def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
let assemblyFormat = [{
operands ` ` attr-dict `:` functional-type(operands, results)
}];

+  let hasVerifier = 1;
}

def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
17 changes: 17 additions & 0 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
@@ -435,3 +435,20 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {

return success();
}

+//===----------------------------------------------------------------------===//
+// gen.simdblockread
+//===----------------------------------------------------------------------===//
+
+LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
+  VectorType vecTy = getRes().getType();
+  unsigned numElems = vecTy.getNumElements();
+  IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());
+
+  // FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
+  if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
+      (elemTy.getWidth() != 8 || numElems != 16))
+    return emitOpError("unsupported vector type");
+
+  return success();
+}
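For reference, the condition above accepts element counts 1, 2, 4, and 8 for any integer element type, plus 16 elements only when the element type is i8; everything else, including the former vector<64xi16>, now fails verification. A standalone restatement of the same predicate (names are illustrative, not from the PR):

```cpp
#include <cassert>

// Mirrors the verifier's rule: {1,2,4,8} x iN are valid, and 16 elements
// are additionally valid for i8; all other shapes are rejected.
static bool isSupportedSimdBlockVec(unsigned numElems, unsigned elemWidth) {
  if (numElems == 1 || numElems == 2 || numElems == 4 || numElems == 8)
    return true;
  return numElems == 16 && elemWidth == 8; // vector<16xi8> only
}

int main() {
  assert(isSupportedSimdBlockVec(8, 16));   // vector<8xi16>: valid
  assert(isSupportedSimdBlockVec(16, 8));   // vector<16xi8>: valid
  assert(!isSupportedSimdBlockVec(16, 16)); // vector<16xi16>: see FIXME
  assert(!isSupportedSimdBlockVec(64, 16)); // vector<64xi16>: now invalid
  return 0;
}
```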
10 changes: 3 additions & 7 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -1263,9 +1263,9 @@ static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) {
funcName =
"_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
-      intel::getTypeMangling(vecTy.getElementType(), true /*isUnsigned*/);
+      intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true);
  if constexpr (isWrite)
-    funcName += intel::getTypeMangling(vecTy, true /*isUnsigned*/);
+    funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true);
return funcName;
}
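For readers unfamiliar with the scheme, the function above assembles an Itanium-style name: "_Z", the length of the base builtin name, the name itself, "PU3AS<n>" for a pointer into named address space n, and the unsigned element type code ("t" is unsigned short). A standalone sketch reproducing the name checked in the tests (base name and address space taken from the CHECK lines; not code from this PR):

```cpp
#include <iostream>
#include <string>

int main() {
  const std::string base = "intel_sub_group_block_read_us8"; // 30 characters
  const unsigned addrSpace = 3;                              // shared local memory
  // "_Z" + <name length> + <name> + "PU3AS<n>" + "t" (unsigned short)
  std::string mangled = "_Z" + std::to_string(base.size()) + base + "PU3AS" +
                        std::to_string(addrSpace) + "t";
  std::cout << mangled << "\n"; // _Z30intel_sub_group_block_read_us8PU3AS3t
  return 0;
}
```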

@@ -1280,11 +1280,7 @@ struct TritonSIMDBlockReadLowering
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getRes().getType();

-    // TODO: Remove GenISA lowering after PoC productization is completed.
-    std::string funcName = "llvm.genx.GenISA.simdBlockRead";
-    if (isTySIMDOCLBuiltinAvailable(vecTy))
-      funcName = getSIMDBlockManglingName(op, vecTy);
-
+    std::string funcName = getSIMDBlockManglingName(op, vecTy);
intel::AttributeList attrs = createFunctionAttributes(
{{llvm::Attribute::NoUnwind, std::nullopt},
{llvm::Attribute::WillReturn, std::nullopt},
18 changes: 16 additions & 2 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -288,8 +288,22 @@ class LoadStorePrefetchOpConversion
if constexpr (std::is_same_v<OpType, LoadOp>) {
rewriter.restoreInsertionPoint(insertPoint);

-      TritonGEN::SIMDBlockReadOp simdRead =
-          rewriter.create<TritonGEN::SIMDBlockReadOp>(loc, v64i16Ty, base);
+      constexpr unsigned maxBlockLoadi16Width = 8;
+      VectorType decomposedVecTy =
+          VectorType::get(maxBlockLoadi16Width, i16_ty);
+      auto mod = op->template getParentOfType<mlir::ModuleOp>();
+      Value offset =
+          i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
+      SmallVector<Value> values;
+      for (int i = 0; i < 64 / maxBlockLoadi16Width; ++i) {
+        auto simdRead = rewriter.create<TritonGEN::SIMDBlockReadOp>(
+            loc, decomposedVecTy, base);
+        values.push_back(simdRead.getRes());
+        base = gep(ptrToSharedMemTy, decomposedVecTy, base, offset);
+      }
+      auto simdRead =
+          rewriter.create<triton::gpu::intel::GlueOp>(loc, v64i16Ty, values);

VectorType v64Ty = VectorType::get(64, elemType);
rewriter.replaceOp(op, bitcast(simdRead.getRes(), v64Ty));

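Note on the pointer arithmetic in the loop above: each decomposed block read returns vector<8xi16> per work-item, but across the full subgroup it consumes threadsPerWarp * 8 contiguous i16 values from shared memory, which is why the base advances by threadsPerWarp in units of vector<8xi16> between reads. A quick standalone check of that arithmetic (constants assumed from the test: 16 threads per warp, i16 elements):

```cpp
#include <cstdio>

int main() {
  const int threadsPerWarp = 16; // SIMD width of the subgroup
  const int elemsPerRead = 8;    // vector<8xi16> returned per lane
  const int numReads = 64 / elemsPerRead; // 8 reads rebuild vector<64xi16>
  // Each read consumes threadsPerWarp * elemsPerRead contiguous i16 values,
  // exactly the GEP advance of threadsPerWarp vector<8xi16> elements.
  const int strideInI16 = threadsPerWarp * elemsPerRead;
  std::printf("%d reads, each advancing the base by %d i16 (%d bytes)\n",
              numReads, strideInI16, strideInI16 * 2);
  return 0;
}
```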