Skip to content

Commit

Permalink
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 11 deletions.
11 changes: 6 additions & 5 deletions test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(256 : i64) : i64
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_3]] : i64
// CHECK: %[[VAL_9:.*]] = llvm.getelementptr inbounds %[[VAL_0]]{{\[}}%[[VAL_8]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite(%[[VAL_9]], %[[VAL_1]]) {{{.*}}} : (!llvm.ptr<3>, vector<16xf32>) -> ()
// CHECK: %[[VAL_10:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i64
// CHECK: %[[VAL_11:.*]] = llvm.getelementptr inbounds %[[VAL_9]]{{\[}}%[[VAL_10]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
// CHECK: %[[VAL_12:.*]] = llvm.load %[[VAL_11]] : !llvm.ptr<3> -> vector<16xf32>
// CHECK: llvm.return %[[VAL_12]] : vector<16xf32>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_1]] : vector<16xf32> to vector<16xi32>
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.simdBlockWrite(%[[VAL_9]], %[[VAL_10]]) {{{.*}}} : (!llvm.ptr<3>, vector<16xi32>) -> ()
// CHECK: %[[VAL_11:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i64
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr inbounds %[[VAL_9]]{{\[}}%[[VAL_11]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] : !llvm.ptr<3> -> vector<16xf32>
// CHECK: llvm.return %[[VAL_13]] : vector<16xf32>
tt.func @test(%arg0: !tt.ptr<f32, 3>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> {
%0 = triton_intel_gpu.sub_group_transpose %arg0, %arg1 : tensor<16x16xf32>
tt.return %0 : tensor<16x16xf32>
Expand Down
22 changes: 22 additions & 0 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {

// -----

// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {passthrough = ["nounwind", "willreturn", ["memory", "1"]]}

// Lowering test for a SIMD block read that yields vector<2xi16> from a pointer
// in address space 3. The FileCheck directives below expect the op to become a
// spir_funccc call to the mangled intel_sub_group_block_read_us2 builtin.
llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) {
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<2xi16>
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<2xi16>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.simdBlockWrite(!llvm.ptr<3>, vector<64xi16>) attributes {passthrough = ["nounwind", "willreturn", ["memory", "3"]]}

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<64xi16>) {
Expand All @@ -297,3 +308,14 @@ llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<64xi16>)
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {passthrough = ["nounwind", "willreturn", ["memory", "3"]]}

// Lowering test for a SIMD block write of a vector<2xi16> value through a
// pointer in address space 3. The FileCheck directives below expect the op to
// become a spir_funccc call to the mangled intel_sub_group_block_write_us2
// builtin, taking the pointer and the vector as arguments.
llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<2xi16>) -> ()
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<2xi16>)
llvm.return
}
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
}

def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$res)>,
Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
)> {
Expand All @@ -418,7 +418,7 @@ def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
FixedVectorOf<[TritonGEN_MatrixElemType]>:$val
FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
)> {

let summary = "simd block write";
Expand Down
47 changes: 45 additions & 2 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,44 @@ struct TritonMatrix2DBlockPrefetchLowering
}
};

/// Returns true when the SIMD block read/write lowering may use the OpenCL
/// builtin for \p vecTy instead of the GenISA intrinsic.
///
/// Vectors of 1, 2, 4 or 8 elements are always accepted. Otherwise the element
/// type is cast to an IntegerType, and only a 16-element vector of 8-bit
/// integers is accepted.
static bool isTySIMDOCLBuiltinAvailable(VectorType vecTy) {
  const unsigned elemCount = vecTy.getNumElements();
  switch (elemCount) {
  case 1:
  case 2:
  case 4:
  case 8:
    return true;
  default:
    break;
  }

  // FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
  const unsigned bitWidth =
      cast<IntegerType>(vecTy.getElementType()).getWidth();
  return bitWidth == 8 && elemCount == 16;
}

/// Builds the mangled name of the OpenCL sub-group block read/write builtin
/// that corresponds to \p op operating on \p vecTy.
///
/// The result is assembled as:
///   "_Z" <length of builtin name> <builtin name>
///   "PU3AS" <pointer address space> <unsigned element type mangling>
///   [<unsigned vector type mangling>]   (write ops only)
/// where the builtin name is "intel_sub_group_block_{read,write}_u" followed
/// by the element type mangling and, unless the vector has a single element,
/// the element count. For example, a read of vector<2xi16> through a pointer
/// in address space 3 yields "_Z30intel_sub_group_block_read_us2PU3AS3t".
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
                               OpType, TritonGEN::SIMDBlockReadOp,
                               TritonGEN::SIMDBlockWriteOp>::value>>
static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) {
  constexpr bool lowersWrite =
      std::is_same_v<OpType, TritonGEN::SIMDBlockWriteOp>;
  const LLVM::LLVMPointerType pointerTy = op.getPtr().getType();
  const auto elemTy = vecTy.getElementType();
  const unsigned elemCount = vecTy.getNumElements();

  // Note: OCL builtin name here differs from regular mangling.
  std::string builtin = "intel_sub_group_block_";
  builtin += lowersWrite ? "write" : "read";
  builtin += "_u";
  builtin += intel::getTypeMangling(elemTy);
  // The scalar variant of the builtin carries no element-count suffix.
  if (elemCount != 1)
    builtin += std::to_string(elemCount);

  std::string mangled = "_Z";
  mangled += std::to_string(builtin.size());
  mangled += builtin;
  mangled += "PU3AS";
  mangled += std::to_string(pointerTy.getAddressSpace());
  mangled += intel::getTypeMangling(elemTy, true /*isUnsigned*/);
  // Only the write builtin takes the value being stored as a second parameter.
  if constexpr (lowersWrite)
    mangled += intel::getTypeMangling(vecTy, true /*isUnsigned*/);
  return mangled;
}

struct TritonSIMDBlockReadLowering
: public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockReadOp> {
using ConvertOpToLLVMPattern<
Expand All @@ -1243,7 +1281,9 @@ struct TritonSIMDBlockReadLowering
VectorType vecTy = op.getRes().getType();

// TODO: Remove GenISA lowering after PoC productization is completed.
const StringLiteral funcName = "llvm.genx.GenISA.simdBlockRead";
std::string funcName = "llvm.genx.GenISA.simdBlockRead";
if (isTySIMDOCLBuiltinAvailable(vecTy))
funcName = getSIMDBlockManglingName(op, vecTy);

intel::AttributeList attrs = createFunctionAttributes(
{{llvm::Attribute::NoUnwind, std::nullopt},
Expand Down Expand Up @@ -1272,7 +1312,10 @@ struct TritonSIMDBlockWriteLowering
VectorType vecTy = op.getVal().getType();

// TODO: Remove GenISA lowering after PoC productization is completed.
const StringLiteral funcName = "llvm.genx.GenISA.simdBlockWrite";
std::string funcName = "llvm.genx.GenISA.simdBlockWrite";
if (isTySIMDOCLBuiltinAvailable(vecTy))
funcName = getSIMDBlockManglingName(op, vecTy);

intel::AttributeList attrs = createFunctionAttributes(
{{llvm::Attribute::NoUnwind, std::nullopt},
{llvm::Attribute::WillReturn, std::nullopt},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,14 @@ class SubGroupTransposeOpConversion
ValueRange{subGroupOffset}, /*inbounds=*/true);

// Store matrix in local memory.
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, subGroupBasePtr, src);
Value val =
vecTy.getElementType().isInteger()
? src
: bitcast(
src,
vec_ty(int_ty(vecTy.getElementType().getIntOrFloatBitWidth()),
vecTy.getNumElements()));
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, subGroupBasePtr, val);

// Load from matrix, transposed.
Value workItemOffset = mul(wiStride, subGroupLocalId);
Expand Down
2 changes: 1 addition & 1 deletion third_party/intel/lib/Utils/Mangling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::string getTypeMangling(Type ty, bool isUnsigned) {
.Case([isUnsigned](IntegerType ty) -> std::string {
switch (ty.getWidth()) {
case 8:
return "c";
return isUnsigned ? "h" : "c";
case 16:
return isUnsigned ? "t" : "s";
case 32:
Expand Down

0 comments on commit 5481995

Please sign in to comment.