diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index e62e8d026c..bebd25643e 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -428,19 +428,3 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } - -// ----- - -llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { - // expected-error @+1 {{'triton_gen.simdblockread' op unsupported vector type}} - %ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16> - llvm.return -} - -// ----- - -llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val: vector<64xi16>) { - // expected-error @+1 {{'triton_gen.simdblockwrite' op unsupported vector type}} - triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>) - llvm.return -} diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index e3a4cdbfac..8c2ff913bb 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -241,10 +241,10 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b // CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {memory_effects = #llvm.memory_effects, no_unwind, will_return} -llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { - // CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) { +llvm.func @triton_gen.sub_group_block_read(%ptr: !llvm.ptr<3>) { + // CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<3>) { // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<2xi16> - %ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<2xi16> + %ret = triton_gen.sub_group_block_read %ptr : !llvm.ptr<3> -> vector<2xi16> llvm.return } @@ -252,9 +252,18 @@ llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) { // CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {memory_effects = #llvm.memory_effects, no_unwind, will_return} -llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) { - // CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) { +llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) { + // CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) { // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<2xi16>) -> () - triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<2xi16>) + triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, vector<2xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<1>, %val : i32) { + // CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<1>, %arg1: i32) { + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS1jj(%arg0, %arg1) {{.*}} : (!llvm.ptr<1>, i32) -> () + triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, i32 llvm.return } diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 90e2336ded..2041a5da7a 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -125,16 +125,16 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base llvm.return } -llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) { - // CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) { - // CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<2xi16> - triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<2xi16> +llvm.func @triton_gen.sub_group_block_read(%ptr : !llvm.ptr<1>) { + // CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<1>) { + // CHECK-NEXT: triton_gen.sub_group_block_read %arg0 : !llvm.ptr<1> -> vector<2xi16> + triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<2xi16> llvm.return } -llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<2xi16>) { - // CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<2xi16>) { - // CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<2xi16>) - triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<2xi16>) +llvm.func @triton_gen.sub_group_block_write(%ptr : !llvm.ptr<3>, %val : i32) { + // CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: i32) { + // CHECK-NEXT: triton_gen.sub_group_block_write %arg0, %arg1 : !llvm.ptr<3>, i32 + triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, i32 llvm.return } diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index dde9fd97e3..5ace55decd 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -314,46 +314,97 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">, let hasVerifier = 1; } -def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">, - Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>, - Arguments<(ins - Arg:$ptr - )> { - - let summary = "simd block read"; +def TritonGEN_SubGroupBlockMemoryAccessElementType + : AnyTypeOf<[I8, I16, I32, I64], + "Valid sub-group block memory access element type">; + +def TritonGEN_SubGroupBlockMemoryAccessType + : AnyTypeOf<[TritonGEN_SubGroupBlockMemoryAccessElementType, + FixedVectorOfLengthAndType< + [2, 4, 8], + [TritonGEN_SubGroupBlockMemoryAccessElementType]>, + // Vectors of length 16 only allowed for i8 for now. + FixedVectorOfLengthAndType<[16], [I8]>], + "Valid sub-group block memory access type">; + +def TritonGEN_SubGroupBlockMemoryAccessPointerType + : Type($_self)" # + ".getAddressSpace() == " # + "static_cast(kCrossWorkgroup)">, + CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" # + ".getAddressSpace() == " # + "static_cast(kWorkgroup)">]>]>, + "LLVM pointer in local or global OpenCL address space", + "::mlir::LLVM::LLVMPointerType">; + +def TritonGEN_SubGroupBlockReadOp: TritonGEN_Op<"sub_group_block_read"> { + let summary = "Sub-group block read."; let description = [{ - The `triton_gen.simdblockread` operation performs simd block read from - a start address without laneId offset. The parameters are: - $ptr - the base address to read data + The `triton_gen.sub_group_block_read` reads a scalar or vector for each + work-item in the sub-group from pointer `ptr` as a block operation. + The data is read strided, so the first value is read from: + ``` + ptr[sub_group_local_id] + ``` + and the second one is: + ``` + ptr[sub_group_local_id + sub_group_size] + ``` + etc. + + `ptr` must be aligned to the size of the element type of `res`. + + Example: + ```mlir + %0 = triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<4xi32> + ``` }]; + let arguments = (ins + Arg:$ptr); + + let results = (outs TritonGEN_SubGroupBlockMemoryAccessType:$res); + let assemblyFormat = [{ - operands ` ` attr-dict `:` functional-type(operands, results) + $ptr attr-dict `:` qualified(type($ptr)) `->` type($res) }]; - - let hasVerifier = 1; } -def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">, - Arguments<(ins - Arg:$ptr, - FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val - )> { - +def TritonGEN_SubGroupBlockWriteOp : TritonGEN_Op<"sub_group_block_write"> { let summary = "simd block write"; let description = [{ - The `triton_gen.simdblockwrite` operation performs simd block write to - a start address without laneId offset. The parameters are: - $ptr - the base address to be written - $val - the value vector to write + The `triton_gen.sub_group_block_write` writes a scalar or vector for each + work-item in the sub-group from pointer `ptr` as a block operation. + The data is read strided, so the first value is written to: + ``` + ptr[sub_group_local_id] + ``` + and the second one is: + ``` + ptr[sub_group_local_id + sub_group_size] + ``` + etc. + + `ptr` must be aligned to the size of the element type of `res`. + + Example: + ```mlir + %0 = triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, vector<4xi32> + ``` }]; + let arguments = (ins + Arg:$ptr, + TritonGEN_SubGroupBlockMemoryAccessType:$val); + + let results = (outs); + let assemblyFormat = [{ - operands ` ` attr-dict `:` `(` type(operands) `)` + $ptr `,` $val attr-dict `:` qualified(type($ptr)) `,` type($val) }]; - - let hasVerifier = 1; } + #endif // TRITONGEN_OPS diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index 0fc819a774..17ae98733b 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -48,18 +48,6 @@ template static LogicalResult verifyMatrixInput(Op op) { return success(); } -static LogicalResult verifySIMDBlockTy(Operation *op, VectorType vecTy) { - unsigned numElems = vecTy.getNumElements(); - IntegerType elemTy = cast(vecTy.getElementType()); - - // FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it. - if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 && - (elemTy.getWidth() != 8 || numElems != 16)) - return op->emitOpError("unsupported vector type"); - - return success(); -} - //===----------------------------------------------------------------------===// // gen.sub_group_reduce //===----------------------------------------------------------------------===// @@ -438,19 +426,3 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() { return success(); } - -//===----------------------------------------------------------------------===// -// gen.simdblockread -//===----------------------------------------------------------------------===// - -LogicalResult TritonGEN::SIMDBlockReadOp::verify() { - return verifySIMDBlockTy(*this, getRes().getType()); -} - -//===----------------------------------------------------------------------===// -// gen.simdblockwrite -//===----------------------------------------------------------------------===// - -LogicalResult TritonGEN::SIMDBlockWriteOp::verify() { - return verifySIMDBlockTy(*this, getVal().getType()); -} diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 78df567d61..d144345a56 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -28,6 +28,8 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/identity.h" #include "llvm/IR/Attributes.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ModRef.h" @@ -935,42 +937,50 @@ struct TritonMatrix2DBlockPrefetchLowering }; template ::value>> -static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) { + OpType, TritonGEN::SubGroupBlockReadOp, + TritonGEN::SubGroupBlockWriteOp>::value>> +static std::string getSubGroupBlockManglingName(OpType op, Type type) { constexpr bool isWrite = - std::is_same::value; + std::is_same::value; const LLVM::LLVMPointerType ptrTy = op.getPtr().getType(); - const unsigned numElems = vecTy.getNumElements(); // Note: OCL builtin name here differs from regular mangling. std::string funcName = "intel_sub_group_block_"; if constexpr (isWrite) funcName += "write"; else funcName += "read"; - funcName += "_u" + intel::getTypeMangling(vecTy.getElementType()) + - (numElems == 1 ? "" : std::to_string(numElems)); - funcName = - "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" + - std::to_string(ptrTy.getAddressSpace()) + - intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true); + Type elementType = + TypeSwitch(type) + .Case([](VectorType vecType) { return vecType.getElementType(); }) + // Scalar case + .Default(llvm::identity()); + const unsigned numElems = + TypeSwitch(type) + .Case([](VectorType vecType) { return vecType.getNumElements(); }) + // Scalar case + .Default(0u); + funcName += "_u" + intel::getTypeMangling(elementType) + + (numElems ? std::to_string(numElems) : ""); + funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" + + std::to_string(ptrTy.getAddressSpace()) + + intel::getTypeMangling(elementType, /*isUnsigned=*/true); if constexpr (isWrite) - funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true); + funcName += intel::getTypeMangling(type, /*isUnsigned=*/true); return funcName; } -struct TritonSIMDBlockReadLowering - : public ConvertOpToLLVMPattern { +struct TritonSubGroupBlockReadLowering + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - TritonGEN::SIMDBlockReadOp>::ConvertOpToLLVMPattern; + TritonGEN::SubGroupBlockReadOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TritonGEN::SIMDBlockReadOp op, OpAdaptor adaptor, + matchAndRewrite(TritonGEN::SubGroupBlockReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { LLVM::LLVMPointerType ptrTy = op.getPtr().getType(); - VectorType vecTy = op.getRes().getType(); + Type type = op.getRes().getType(); - std::string funcName = getSIMDBlockManglingName(op, vecTy); + std::string funcName = getSubGroupBlockManglingName(op, type); auto memAttr = rewriter.getAttr( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, @@ -978,26 +988,26 @@ struct TritonSIMDBlockReadLowering auto funcAttrs = noUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; LLVM::CallOp call = createDeviceFunctionCall( - rewriter, funcName, vecTy, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {}); + rewriter, funcName, type, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {}); rewriter.replaceOp(op, call.getResult()); return success(); } }; -struct TritonSIMDBlockWriteLowering - : public ConvertOpToLLVMPattern { +struct TritonSubGroupBlockWriteLowering + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - TritonGEN::SIMDBlockWriteOp>::ConvertOpToLLVMPattern; + TritonGEN::SubGroupBlockWriteOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TritonGEN::SIMDBlockWriteOp op, OpAdaptor adaptor, + matchAndRewrite(TritonGEN::SubGroupBlockWriteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *ctx = rewriter.getContext(); LLVM::LLVMPointerType ptrTy = op.getPtr().getType(); - VectorType vecTy = op.getVal().getType(); + Type type = op.getVal().getType(); - std::string funcName = getSIMDBlockManglingName(op, vecTy); + std::string funcName = getSubGroupBlockManglingName(op, type); auto memAttr = rewriter.getAttr( /*other=*/LLVM::ModRefInfo::NoModRef, @@ -1006,7 +1016,7 @@ struct TritonSIMDBlockWriteLowering auto funcAttrs = noUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; LLVM::CallOp call = createDeviceFunctionCall( - rewriter, funcName, void_ty(ctx), {ptrTy, vecTy}, + rewriter, funcName, void_ty(ctx), {ptrTy, type}, {op.getPtr(), op.getVal()}, {}, funcAttrs); rewriter.replaceOp(op, call); @@ -1071,12 +1081,13 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface { void mlir::triton::populateTritonGENToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add< - TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering, - TritonSubGroupReduceLowering, TritonSubGroupScanLowering, - TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering, - TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering, - TritonSIMDBlockReadLowering, TritonSIMDBlockWriteLowering>(converter); + patterns + .add(converter); } void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry ®istry) { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d5e003d90e..c20258b263 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -792,14 +792,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion vals = vals.drop_front(vecWidth)) { ArrayRef curr = vals.take_front(vecWidth); Value vec = wrapInVector(loc, opType, curr, rewriter); - rewriter.create(loc, base, vec); + rewriter.create(loc, base, vec); base = gep(base.getType(), opType, base, ArrayRef{offset}, /*inbounds=*/true); } // Load from matrix, non-trasposed. - // As per SIMD block semantics, we have stored the elements in a matrix of - // `Nxsub_group_size` size, so we need to load back in blocks of + // As per sub-group block semantics, we have stored the elements in a matrix + // of `Nxsub_group_size` size, so we need to load back in blocks of // `sub_group_size` (`N/sub_group_size` loads). Value workItemOffset = mul(wiStride, subGroupLocalId); Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 0af192e560..eabc89531f 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -46,7 +46,7 @@ static void decomposeBlockStore(ConversionPatternRewriter &rewriter, VectorType::get(maxBlockStoreWidth, vecTy.getElementType()); Value offset = i32_val(subGroupSize); for (int i = 0; i < vecTy.getNumElements() / maxBlockStoreWidth; ++i) { - rewriter.create( + rewriter.create( loc, base, rewriter .create(loc, decomposedVecTy, val, i) @@ -313,7 +313,7 @@ class LoadStorePrefetchOpConversion i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); SmallVector values; for (int i = 0; i < 64 / maxBlockLoadi16Width; ++i) { - auto simdRead = rewriter.create( + auto simdRead = rewriter.create( loc, decomposedVecTy, base); values.push_back(simdRead.getRes()); base = gep(ptrToSharedMemTy, decomposedVecTy, base, offset);