From 5481995293c748fd3023ad91fa0d8def6f2dc71a Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 11 Sep 2024 12:08:41 -0400 Subject: [PATCH] [TritonGEN] Use OCL builtins for subgroup block read/write (#2178) Use `intel_sub_group_block_[read|write]` defined in https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html, https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_char.html, https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_short.html, https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups_long.html, and https://github.com/KhronosGroup/OpenCL-Docs/blob/main/extensions/cl_intel_subgroup_local_block_io.asciidoc. --------- Signed-off-by: Whitney Tsang --- ...tritongpu_to_llvm_intel_advanced_path.mlir | 11 +++-- test/TritonGEN/tritongen-to-llvm.mlir | 22 +++++++++ .../Dialect/TritonGEN/IR/TritonGENOps.td | 4 +- .../TritonGENToLLVM/TritonGENToLLVMPass.cpp | 47 ++++++++++++++++++- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 9 +++- third_party/intel/lib/Utils/Mangling.cpp | 2 +- 6 files changed, 84 insertions(+), 11 deletions(-) diff --git a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir index ed2a7b5484..30eca7568c 100644 --- a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir +++ b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir @@ -281,11 +281,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(256 : i64) : i64 // CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_3]] : i64 // CHECK: %[[VAL_9:.*]] = llvm.getelementptr inbounds %[[VAL_0]]{{\[}}%[[VAL_8]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 -// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite(%[[VAL_9]], %[[VAL_1]]) {{{.*}}} : (!llvm.ptr<3>, vector<16xf32>) -> () -// CHECK: %[[VAL_10:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i64 -// CHECK: %[[VAL_11:.*]] = llvm.getelementptr inbounds %[[VAL_9]]{{\[}}%[[VAL_10]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 -// CHECK: %[[VAL_12:.*]] = llvm.load %[[VAL_11]] : !llvm.ptr<3> -> vector<16xf32> -// CHECK: llvm.return %[[VAL_12]] : vector<16xf32> +// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_1]] : vector<16xf32> to vector<16xi32> +// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite(%[[VAL_9]], %[[VAL_10]]) {{{.*}}} : (!llvm.ptr<3>, vector<16xi32>) -> () +// CHECK: %[[VAL_11:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i64 +// CHECK: %[[VAL_12:.*]] = llvm.getelementptr inbounds %[[VAL_9]]{{\[}}%[[VAL_11]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 +// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] : !llvm.ptr<3> -> vector<16xf32> +// CHECK: llvm.return %[[VAL_13]] : vector<16xf32> tt.func @test(%arg0: !tt.ptr, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { %0 = triton_intel_gpu.sub_group_transpose %arg0, %arg1 : tensor<16x16xf32> tt.return %0 : tensor<16x16xf32> diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 30a2dd6f3b..469a764de2 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -289,6 +289,17 @@ llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { // ----- +// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {passthrough = ["nounwind", "willreturn", ["memory", "1"]]} + +llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { + // CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) { + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<2xi16> + %ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<2xi16> + llvm.return +} + +// ----- + // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockWrite(!llvm.ptr<3>, vector<64xi16>) attributes {passthrough = ["nounwind", "willreturn", ["memory", "3"]]} llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<64xi16>) { @@ -297,3 +308,14 @@ llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<64xi16>) triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>) llvm.return } + +// ----- + +// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {passthrough = ["nounwind", "willreturn", ["memory", "3"]]} + +llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) { + // CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) { + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<2xi16>) -> () + triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<2xi16>) + llvm.return +} diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index 941657fd6d..66cde56ea1 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -397,7 +397,7 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">, } def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">, - Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$res)>, + Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>, Arguments<(ins Arg:$ptr )> { @@ -418,7 +418,7 @@ def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">, def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">, Arguments<(ins Arg:$ptr, - FixedVectorOf<[TritonGEN_MatrixElemType]>:$val + FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val )> { let summary = "simd block write"; diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 97062dc9bc..696a5b4430 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -1231,6 +1231,44 @@ struct TritonMatrix2DBlockPrefetchLowering } }; +static bool isTySIMDOCLBuiltinAvailable(VectorType vecTy) { + unsigned numElems = vecTy.getNumElements(); + if (numElems == 1 || numElems == 2 || numElems == 4 || numElems == 8) + return true; + + // FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it. + IntegerType elemTy = cast(vecTy.getElementType()); + if (elemTy.getWidth() == 8 && numElems == 16) + return true; + + return false; +} + +template ::value>> +static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) { + constexpr bool isWrite = + std::is_same::value; + const LLVM::LLVMPointerType ptrTy = op.getPtr().getType(); + const unsigned numElems = vecTy.getNumElements(); + // Note: OCL builtin name here differs from regular mangling. + std::string funcName = "intel_sub_group_block_"; + if constexpr (isWrite) + funcName += "write"; + else + funcName += "read"; + funcName += "_u" + intel::getTypeMangling(vecTy.getElementType()) + + (numElems == 1 ? "" : std::to_string(numElems)); + funcName = + "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" + + std::to_string(ptrTy.getAddressSpace()) + + intel::getTypeMangling(vecTy.getElementType(), true /*isUnsigned*/); + if constexpr (isWrite) + funcName += intel::getTypeMangling(vecTy, true /*isUnsigned*/); + return funcName; +} + struct TritonSIMDBlockReadLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -1243,7 +1281,9 @@ struct TritonSIMDBlockReadLowering VectorType vecTy = op.getRes().getType(); // TODO: Remove GenISA lowering after PoC productization is completed. - const StringLiteral funcName = "llvm.genx.GenISA.simdBlockRead"; + std::string funcName = "llvm.genx.GenISA.simdBlockRead"; + if (isTySIMDOCLBuiltinAvailable(vecTy)) + funcName = getSIMDBlockManglingName(op, vecTy); intel::AttributeList attrs = createFunctionAttributes( {{llvm::Attribute::NoUnwind, std::nullopt}, @@ -1272,7 +1312,10 @@ struct TritonSIMDBlockWriteLowering VectorType vecTy = op.getVal().getType(); // TODO: Remove GenISA lowering after PoC productization is completed. - const StringLiteral funcName = "llvm.genx.GenISA.simdBlockWrite"; + std::string funcName = "llvm.genx.GenISA.simdBlockWrite"; + if (isTySIMDOCLBuiltinAvailable(vecTy)) + funcName = getSIMDBlockManglingName(op, vecTy); + intel::AttributeList attrs = createFunctionAttributes( {{llvm::Attribute::NoUnwind, std::nullopt}, {llvm::Attribute::WillReturn, std::nullopt}, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index f23b098393..e3da9dc1a8 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -619,7 +619,14 @@ class SubGroupTransposeOpConversion ValueRange{subGroupOffset}, /*inbounds=*/true); // Store matrix in local memory. - rewriter.create(loc, subGroupBasePtr, src); + Value val = + vecTy.getElementType().isInteger() + ? src + : bitcast( + src, + vec_ty(int_ty(vecTy.getElementType().getIntOrFloatBitWidth()), + vecTy.getNumElements())); + rewriter.create(loc, subGroupBasePtr, val); // Load from matrix, trasposed. Value workItemOffset = mul(wiStride, subGroupLocalId); diff --git a/third_party/intel/lib/Utils/Mangling.cpp b/third_party/intel/lib/Utils/Mangling.cpp index 6ddeb644cc..b6ef04c2d3 100644 --- a/third_party/intel/lib/Utils/Mangling.cpp +++ b/third_party/intel/lib/Utils/Mangling.cpp @@ -20,7 +20,7 @@ std::string getTypeMangling(Type ty, bool isUnsigned) { .Case([isUnsigned](IntegerType ty) -> std::string { switch (ty.getWidth()) { case 8: - return "c"; + return isUnsigned ? "h" : "c"; case 16: return isUnsigned ? "t" : "s"; case 32: