From 73ba8b639226fbc505115db59189a1362c350075 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Tue, 10 Dec 2024 07:37:49 +0100 Subject: [PATCH 1/7] [AMD] NFC: Use TableGen enum to define schedule variants (#5376) This PR refactors the instruction scheduling enums. Now, it is implemented in the MLIR. --- test/TritonGPU/amd/amd-instruction-sched.mlir | 6 +- .../Dialect/TritonAMDGPU/IR/CMakeLists.txt | 2 + .../include/Dialect/TritonAMDGPU/IR/Dialect.h | 1 + .../TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td | 28 ++++++++++ .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 2 + .../TritonAMDGPUToLLVM/SchedInstructions.cpp | 56 +++++++------------ 6 files changed, 56 insertions(+), 39 deletions(-) diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index 8cc3ae64f4..8cf3bdcafd 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -1,8 +1,8 @@ -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 -// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0 +// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm_iglp_1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2 -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions 
-optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local_prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1 // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2 diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt index 25a57075be..094ecfc7d4 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -11,6 +11,8 @@ add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonAMDGPUTableGen) set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td) +mlir_tablegen(TritonAMDGPUEnums.h.inc -gen-enum-decls) +mlir_tablegen(TritonAMDGPUEnums.cpp.inc -gen-enum-defs) mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls) mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index 486fd60293..a771e55609 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -33,6 +33,7 @@ // clang-format off #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" +#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc" // clang-format on #define GET_ATTRDEF_CLASSES diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index c0aa08421b..491989669c 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -26,6 +26,7 @@ include "mlir/IR/AttrTypeBase.td" include "TritonAMDGPUDialect.td" +include "mlir/IR/EnumAttr.td" class TritonAMDGPU_Attr traits = [], string baseCppClass = 
"::mlir::Attribute"> @@ -58,5 +59,32 @@ def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> { let assemblyFormat = "`<` params `>`"; } +class TritonAMDGPU_I32Enum cases> + : I32EnumAttr { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::triton::amdgpu"; +} + +class TritonAMDGPU_I32EnumAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; + let cppNamespace = "::mlir::triton::amdgpu"; +} + +def SchedHintCaseNone : I32EnumAttrCase<"none", 0>; +def SchedHintCaseLLVMIglp0 : I32EnumAttrCase<"llvm_iglp_0", 1>; +def SchedHintCaseLLVMIglp1 : I32EnumAttrCase<"llvm_iglp_1", 2>; +def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 3>; + +def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum< + "SchedHint", "Instruction Scheduling Hints for AMD GPUs", [ + SchedHintCaseNone, + SchedHintCaseLLVMIglp0, + SchedHintCaseLLVMIglp1, + SchedHintCaseLocalPrefetch, + ]>; + +def TritonAMDGPU_SchedHintVariantAttr : + TritonAMDGPU_I32EnumAttr<"SchedHintVariant", TritonAMDGPU_SchedHintsEnum>; #endif diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 0e2a9304eb..3f141ee2ef 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -48,6 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { >(); } +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc" + #define GET_ATTRDEF_CLASSES #include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index d93f2ca6c6..333a78b4c5 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -217,31 +217,22 @@ struct InstructionSchedHintsRewriter : OpRewritePattern(ctx), numStages(numStages) { this->machineDescr = MachineDescr::get(arch); - std::transform(variant.begin(), variant.end(), variant.begin(), - [](unsigned char c) { return std::tolower(c); }); - - this->schedulingType = - llvm::StringSwitch(variant) - .Case("none", SchedulingType::NONE) - .Case("llvm-iglp-0", SchedulingType::LLVM_IGLP_0) - .Case("llvm-iglp-1", SchedulingType::LLVM_IGLP_1) - .Case("local-prefetch", SchedulingType::LOCAL_PREFETCH) - .Default(SchedulingType::UNKNOWN); + this->schedHint = mlir::triton::amdgpu::SchedHint::none; if (this->numStages < 2) { - this->schedulingType = SchedulingType::NONE; LDBG("ignoring instruction scheduling due to a very low num. " "stages value. Must be >= 2"); + return; } - } - enum class SchedulingType : uint32_t { - NONE = 0, - LLVM_IGLP_0, - LLVM_IGLP_1, - LOCAL_PREFETCH, - UNKNOWN - }; + std::transform(variant.begin(), variant.end(), variant.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(variant)) + this->schedHint = maybeSchedHint.value(); + else + LDBG("ignoring instruction scheduling because " + "unknown instruction scheduling variant has been provided"); + } // The following is inspired by ROCm Composable Kernel library's V3 pipelining // (see ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp). 
@@ -416,25 +407,18 @@ struct InstructionSchedHintsRewriter LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, PatternRewriter &rewriter) const override { - if (this->schedulingType == SchedulingType::NONE) { + if (this->schedHint == mlir::triton::amdgpu::SchedHint::none) { rewriter.eraseOp(instructionSchedHint); return success(); } - if (this->schedulingType == SchedulingType::UNKNOWN) { - instructionSchedHint.emitError( - "unknown instruction scheduling variant has been provided"); - return failure(); - } - // The switch controls whether instructions are allowed to cross the basic // block boundaries at the very top and at the very bottom. Note, this is // not supposed to be used together with IGLP OPT according to the AMDGPU // backend documentation. const bool limitSchedulingRange = - !(schedulingType == SchedulingType::NONE || - schedulingType == SchedulingType::LLVM_IGLP_0 || - schedulingType == SchedulingType::LLVM_IGLP_1); + this->schedHint == mlir::triton::amdgpu::SchedHint::local_prefetch; + ; Location loc = instructionSchedHint->getLoc(); Block *block = instructionSchedHint->getBlock(); if (limitSchedulingRange) { @@ -445,15 +429,15 @@ struct InstructionSchedHintsRewriter rewriter.setInsertionPoint(block, std::prev(block->end())); - switch (schedulingType) { - case SchedulingType::LLVM_IGLP_0: - case SchedulingType::LLVM_IGLP_1: - createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1); + switch (this->schedHint) { + case mlir::triton::amdgpu::SchedHint::llvm_iglp_0: + case mlir::triton::amdgpu::SchedHint::llvm_iglp_1: + createIglpOpt(rewriter, loc, static_cast(this->schedHint) - 1); break; - case SchedulingType::LOCAL_PREFETCH: + case mlir::triton::amdgpu::SchedHint::local_prefetch: createLocalPrefetchSchedule(rewriter, loc, instructionSchedHint); break; - case SchedulingType::NONE: + case mlir::triton::amdgpu::SchedHint::none: default: break; } @@ -468,7 +452,7 @@ struct InstructionSchedHintsRewriter private: int32_t numStages; - SchedulingType schedulingType; + mlir::triton::amdgpu::SchedHint schedHint; std::unique_ptr machineDescr; }; From 2ec1d1794b44d865f3cf955db98aeb77550d5d21 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Tue, 10 Dec 2024 10:07:36 -0800 Subject: [PATCH 2/7] [Backend] Implement optimized gather codegen within a warp (#5345) This PR implements a specialized codegen for `tt.gather` when it satisfies the conditions of being "warp local": it is possible to compute the output tensor without data movement across warps. `isWarpLocal` is a new function that checks this condition, and places additional restrictions to simplify codegen / separate concerns from `ttg.convert_layout`. This enables `tt.gather` to generate better code when the layout is suitable. In a subsequent PR, a special pattern will be added to generate optimized layouts for `tt.gather` when possible/profitable to enable the lowering. 
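As a concrete illustration, the simplest warp-local case covered by the new test/Conversion/gather_to_llvm.mlir added in this patch is a 32-element gather whose linear layout places one element per lane of a single warp; the lowering then emits one index shuffle (`nvvm.shfl.sync.idx` on NVIDIA) per output element instead of a shared-memory roundtrip. A trimmed sketch of that test case (excerpted from this patch, not an additional example):

    #trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}>

    tt.func private @gather_warp_local_trivial(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> {
      // Each source element lives in a distinct lane, so a single index shuffle suffices.
      %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout>
      tt.return %0 : tensor<32xf32, #trivial_layout>
    }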
--- include/triton/Analysis/Utility.h | 2 + include/triton/Dialect/TritonGPU/IR/Dialect.h | 5 + lib/Analysis/Utility.cpp | 84 +++++- .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 271 +++++++++++++++++- python/test/unit/language/test_core.py | 115 ++++++-- test/Conversion/allocate_shared_memory.mlir | 6 +- test/Conversion/gather_to_llvm.mlir | 271 ++++++++++++++++++ 7 files changed, 719 insertions(+), 35 deletions(-) create mode 100644 test/Conversion/gather_to_llvm.mlir diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index a3e38e177d..2e4cbbc651 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -161,6 +161,8 @@ class GatherLoweringHelper { // Get the shared memory scratch size required by this op. unsigned getScratchSizeInBytes(); + // Determine if the gather can be performed completely within a warp. + bool isWarpLocal(); private: triton::GatherOp gatherOp; diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 85c789635a..e592a9d6d1 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -214,7 +214,12 @@ LinearLayout ensureLayoutNotSmallerThan( const LinearLayout &layout, const llvm::SmallDenseMap &shape); +// Return a vector of the standard out dimension names for tensor layouts. These +// are "dim0", "dim1", etc. SmallVector standardOutDimNames(MLIRContext *ctx, int rank); +// Return an identity mapping from `inDimName` to the standard out dimensions, +// with the dimensions sized according to the shape. The bases are sorted +// according to `order`, with the most minor dimension first. LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, ArrayRef order); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 3014245e61..69eb196a95 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -413,13 +413,93 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp) : gatherOp(gatherOp) {} unsigned GatherLoweringHelper::getScratchSizeInBytes() { - // For now, lower the gather op by writing the source tensor to shared memory. - // TODO(jeff): Leverage locality to avoid using scratch space when possible. + // If the gather is warp-local, no scratch space is needed. + if (isWarpLocal()) + return 0; + + // Otherwise, performing the gather will require scratch space to communicate + // the source tensor across threads. For now, assume the whole source tensor + // is written back to shared memory. RankedTensorType srcType = gatherOp.getSrc().getType(); return product(srcType.getShape()) * ceil(srcType.getElementTypeBitWidth(), 8); } +bool GatherLoweringHelper::isWarpLocal() { + // The gather is warp-local if for each column along the gather axis in the + // source and index tensors, all the elements are owned by the same warp. + RankedTensorType srcType = gatherOp.getSrc().getType(); + RankedTensorType idxType = gatherOp.getIndices().getType(); + std::optional srcLayout = + toLinearLayout(srcType.getShape(), srcType.getEncoding()); + std::optional idxLayout = + toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + // FIXME: If an unsupported layout was encountered, assume the gather is not + // warp-local. 
+ if (!srcLayout || !idxLayout) + return false; + + Builder b(gatherOp.getContext()); + StringAttr kBlock = b.getStringAttr("block"); + StringAttr kWarp = b.getStringAttr("warp"); + StringAttr kLane = b.getStringAttr("lane"); + StringAttr kGatherDim = + b.getStringAttr("dim" + std::to_string(gatherOp.getAxis())); + + // The tensor layouts must be distributed layouts, where the basis matrix is a + // subpermutation matrix (permutation matrix plus zeros for broadcasting). + // FIXME(jeff): Check this invariant somehow. + // + // We want to know if all elements of a column along the gather axis are + // mapped to the same set of warps, which means the gather can be performed + // entirely within the warp. We need to query + // + // srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp}) + // + // But due to broadcasting, the matrix might not be invertible. But since the + // matrix is a permutation matrix (checked below), we can instead query + // + // srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim}) + // + // Which implies that changing the warp will not change the gather dimension. + // And since there is no swizzling, this applies to all warps. + if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) || + !idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim)) + return false; + + SmallVector otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + if (dim != gatherOp.getAxis()) { + otherDims.push_back(b.getStringAttr("dim" + Twine(dim))); + } + } + + // If the gather axis `dimN` is invariant to the warp, but the `(block, warp)` + // mapping to all other dimensions must be the same for both layouts. If so, + // then the warp that owns a particular index element also owns all the source + // elements it could index into. + if (srcLayout->sublayout({kBlock, kWarp}, otherDims) != + idxLayout->sublayout({kBlock, kWarp}, otherDims)) + return false; + + // The two constraints above ensure that data-movement to perform the gather + // operation are contained within a warp. The subsequent constraints simplify + // codegen. + + // Require that for any given gather column, the threads mapped to the column + // in the index and source tensors are the same. This means we don't need to + // xor shuffle across threads before emitting index shuffles; we push warp + // shuffling to layout conversions. + if (srcLayout->sublayout(kLane, otherDims) != + idxLayout->sublayout(kLane, otherDims)) + return false; + + // Otherwise, the source layout has to be invertible. This primarily means + // the codegen path doesn't support broadcasted source layouts. 
+ return srcLayout->isInvertible(); +} + unsigned getNumScratchElements(ArrayRef shape) { if (shape.empty()) return 0; diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index 5ab81eff81..faf781369e 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -1,8 +1,10 @@ #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" using namespace mlir; using namespace mlir::triton; +using namespace mlir::triton::gpu; namespace { class GatherOpConversion : public ConvertOpToLLVMPattern { @@ -17,12 +19,51 @@ class GatherOpConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override; private: + // Codegen the gather by storing the source tensor into shared memory and then + // gathering directly from shared memory. + void emitGatherInShared(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + // Codegen a warp-local gather by shuffling elements across the warp and + // selecting from them. + void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + const TargetInfoBase &targetInfo; }; LogicalResult GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + GatherLoweringHelper helper(op); + // Specialize the lowering based on the source layout. Given that the cost of + // a warp shuffle is approximately half the cost of a roundtrip to shared + // memory with zero bank conflicts, we will need a more precise heuristic to + // choose between the two codegen paths and rely on the middle end to pick the + // right layout. + if (helper.isWarpLocal()) { + emitWarpLocalGather(op, adaptor, rewriter); + } else { + emitGatherInShared(op, adaptor, rewriter); + } + return success(); +} + +static Value convertIndexToI32(Location loc, Value index, + ConversionPatternRewriter &rewriter) { + unsigned idxWidth = index.getType().getIntOrFloatBitWidth(); + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + index = trunc(i32_ty, index); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. + index = zext(i32_ty, index); + } + return index; +} + +void GatherOpConversion::emitGatherInShared( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); RankedTensorType srcType = op.getSrc().getType(); @@ -78,19 +119,10 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, /*withCTAOffset=*/true); - unsigned idxWidth = op.getIndices().getType().getElementTypeBitWidth(); unsigned axis = op.getAxis(); SmallVector results(dstIndices.size()); for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { - // The LL index computations are performed with 32 bit integers. If the - // indices are something else, cast them to i32. - if (idxWidth > 32) { - idx = trunc(i32_ty, idx); - } else if (idxWidth < 32) { - // Negative indices don't make sense, so zero-extend. 
- idx = zext(i32_ty, idx); - } - indices[axis] = idx; + indices[axis] = convertIndexToI32(loc, idx, rewriter); Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); Value ptr = gep(smemBase.getType(), elemType, smemBase, offset); results[i] = load(elemType, ptr); @@ -99,7 +131,224 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, Value packed = packLLElements(loc, getTypeConverter(), results, rewriter, dstType); rewriter.replaceOp(op, packed); - return success(); +} + +// High-level description of the algorithm: +// +// `isWarpLocal` checks that it is possible to compute each output element +// without data movement across warps. +// +// If the gather dim is `dimN`, then this means +// +// ll^-1(dimN)[(block, warp)] == 0 +// +// for both source and index tensors: moving along the gather axis does not +// change the warp. Broadcasted layouts are not supported, so we know the +// layouts are permutation matrices. +// +// We can check this with `ll((block, warp))[dimN] == 0`. +// +// Let `gatherCol` be a tuple of all dimensions except the gather dimension. +// We also check that the gather columns line up the same way with respect to +// the warp between the source and index tensors with +// +// ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol] +// +// This means that for all index columns, the corresponding column in the source +// tensor is owned by the same warp. +// +// We also check +// +// ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol] +// +// This boils down to the fact that the algorithm essentially emits a series of +// index shuffles for each index value owned by each thread, and then a pile of +// selects to pick the right value. We need to figure out given an index value +// in a particular column, what are the source register values it could read +// from and who owns them. +// +// If this relationship did not hold, then the possible source registers for +// each index value varies with the thread, meaning the value operand provided +// to each shuffle index instruction would depend on the thread ID. This isn't a +// big deal. It just means would have to emit a pile of selects before each +// shuffle as well, to pick the right source register value. But we choose not +// to handle this. +// +// The codegen algorithm emits code: +// - Given the thread ID and a particular index tensor register, figure out +// which gather column it belongs to using a layout. +// - Using the index value itself as the value for `dimN`, use another layout to +// figure out which lane in the warp owns the desired value and which register +// in that lane it is. +// - For the gather column, figure out the source registers in that column, and +// for each of them, emit an index shuffle with the same computed lane ID. +// - Use the register component to select the right value from the shuffle +// results. +void GatherOpConversion::emitWarpLocalGather( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + // Layout dimension names. 
+ StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis())); + SmallVector allDims, otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + allDims.push_back(str_attr("dim" + Twine(dim))); + if (dim != op.getAxis()) { + otherDims.push_back(allDims.back()); + } + } + + // Compute the src and idx layouts. + LinearLayout srcLayout = + *toLinearLayout(srcType.getShape(), srcType.getEncoding()); + LinearLayout idxLayout = + *toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + // Let `ll_src` be the source layout and `ll_idx` be the index layout. + // Let `src_col` be a tuple of dimensions except the gather dimension, + // representing a specific column in the source tensor. Likewise for + // `idx_col`. Let `src_idx` be the index into gather dimension in the source + // tensor. + // + // `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is + // the thread that contains the required element and `src_reg` is the register + // within that thread. + // + // Because `ll_src(block=0, warp=0, lane=0)[otherDims] == + // ll_idx(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the + // index tensor) the thread will need to read from the same column in the + // source tensor. + // + // Thus, we can obtain + // + // (src_lane, src_reg) = (ll_src^-1)( + // ll_idx(black, warp, lane, idx_reg)[otherDims], + // idxValues[idx_reg] + // )[{"lane", "register"}] + // + // And the mapping will be the correct for each thread. + // + // Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for + // each `idx_reg` (the number of index shuffles is quadratic!) and + // `llvm.select` using `src_reg` to get the right one. `K` is the number of + // elements per column owned by a thread. + + // Fully invert the source layout. We know it is invertible because + // `isWarpLocal` checked this. + LinearLayout invSrcLayout = srcLayout.invert(); + + // Sanity check: the warp must be invariant to the index because otherwise the + // gather would need to read across warps! + assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kBlock, kWarp}) && + "expected a warp-local gather"); + invSrcLayout = invSrcLayout.sublayout(allDims, {kLane, kRegister}); + + LinearLayout idxColLayout = + idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + auto [blockId, warpId, laneId] = + emitHardwareTuple(loc, rewriter, targetInfo, /*withCTAOffset=*/true, + srcLayout.getInDimSize(kLane)); + + unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister); + assert(srcRegsPerThread == srcValues.size()); + + // Given a index value, we need to know which sources register values it could + // index into. This is invariant to anything other than the register, which we + // checked already. Compute the full reverse map from + // + // idx_reg -> gather_column -> (src_reg0, src_reg1, ...) + // + LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister}); + // Remove zero bases in the gather dimension to make the function injective + // (for a given column) over the same codomain. 
+ LinearLayout::BasesT newInvertRegMapBases; + for (auto &[inDim, inDimBases] : invertSrcRegMap.getBases()) { + auto &newInDimBases = newInvertRegMapBases[inDim]; + if (inDim != kGatherDim) { + newInDimBases = inDimBases; + continue; + } + for (auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t val) { return val != 0; })) { + newInDimBases.push_back(basis); + } + } + } + invertSrcRegMap = LinearLayout( + newInvertRegMapBases, llvm::to_vector(invertSrcRegMap.getOutDimNames())); + // We are left with only non-zero bases in the gather dimension, which means + // the number of registers per column is the size of the "gather dimension". + unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim); + // Get a map from idx_reg to the column it indexes into. + LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims); + // Now given `idx_reg`, we can compute the column it belongs to in both src + // and index tensors, then partially apply `invertSrcRegMap` with this to + // obtain a function that outputs the corresponding registers in the src + // tensor in the same column. + + // L(column, i) = L(column, 0) xor L(0, i) + LinearLayout invertSrcRegMapColPart = + invertSrcRegMap.sublayout(otherDims, {kRegister}); + LinearLayout invertSrcRegMapRest = + invertSrcRegMap.sublayout({kGatherDim}, {kRegister}); + + SmallVector results; + for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { + SmallVector> column = + applyLinearLayout(loc, rewriter, idxColLayout, + {{kBlock, blockId}, + {kWarp, warpId}, + {kLane, laneId}, + {kRegister, i32_val(idxReg)}}); + assert(column.size() == otherDims.size()); + + // Combine the computed column with the data-dependent gather index. + column.emplace_back(kGatherDim, convertIndexToI32(loc, idxVal, rewriter)); + SmallVector> srcLaneAndReg = + applyLinearLayout(loc, rewriter, invSrcLayout, column); + + auto [srcLaneName, srcLane] = srcLaneAndReg.back(); + auto [srcRegName, srcReg] = srcLaneAndReg.front(); + assert(srcLaneName == kLane && srcRegName == kRegister); + + assert(!srcValues.empty() && "can't gather from an empty tensor"); + + // Figure out which src registers we need to index shuffle from. This is + // invariant to anything else. 
+ SmallVector> normalizedColumn = + idxRegToCol.apply({{kRegister, idxReg}}); + int32_t srcBase = + invertSrcRegMapColPart.apply(normalizedColumn).front().second; + + Value result = undef(srcValues.front().getType()); + for (unsigned i = 0; i != numRegsPerColumn; ++i) { + int32_t rest = + invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second; + int32_t srcRegIdx = srcBase ^ rest; + + Value value = + targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane); + result = select(icmp_eq(i32_val(srcRegIdx), srcReg), value, result); + } + + results.push_back(result); + } + + rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results, + rewriter, op.getType())); } } // namespace diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5e60d9fd14..3ca05bad50 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6201,6 +6201,24 @@ def kernel(In, Out, # assert torch.all(ref == result) +@triton.jit +def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, + src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, + idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, + out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, + out_stride1: tl.constexpr): + src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) + src = tl.load(src_ptr + src_offs) + + idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) + idx = tl.load(idx_ptr + idx_offs) + + out = tl.gather(src, idx, axis) + + out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) + tl.store(out_ptr + out_offs, out) + + @pytest.mark.parametrize("src_shape, indices_shape, axis", [ ([4, 4], [8, 4], 0), ([128, 64], [256, 64], 0), @@ -6208,29 +6226,13 @@ def kernel(In, Out, # ]) def test_gather(src_shape, indices_shape, axis): - @triton.jit - def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr, - src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr, - idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr, - out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr, - out_stride1: tl.constexpr): - src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1) - src = tl.load(src_ptr + src_offs) - - idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) - idx = tl.load(idx_ptr + idx_offs) - - out = tl.gather(src, idx, axis) - - out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) - tl.store(out_ptr + out_offs, out) - def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) - gather_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], - src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), - indices.stride(1), output.shape[0], output.shape[1], output.stride(0), output.stride(1)) + gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], + src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1], + indices.stride(0), 
indices.stride(1), output.shape[0], output.shape[1], + output.stride(0), output.stride(1)) return output @@ -6239,3 +6241,76 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor): ref = torch.gather(src, axis, indices) result = triton_gather(src, axis, indices) torch.testing.assert_close(result, ref, rtol=0, atol=0) + + +# These layouts are specially chosen to trigger the warp shuffle codegen. +@pytest.mark.parametrize("src_shape, indices_shape, axis, src_layout, indices_layout", [ + ([32, 16], [32, 16], 0, + "linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), + ([128, 64], [256, 64], 0, + "linear<{register = [[0, 2], [32, 0], [2, 0], [0, 16], [0, 32], [64, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>", + "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>" + ), +]) +def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path): + if is_hip(): + pytest.skip("warp-local gather has issues on HIP") + + def prepare_kernel(src: torch.Tensor, axis: int, indices: torch.Tensor): + output = torch.empty(indices.shape, dtype=src.dtype, device=src.device) + compiled = gather_test_kernel.warmup(src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0), + src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0), + indices.stride(1), output.shape[0], output.shape[1], output.stride(0), + output.stride(1), grid=(1, )) + return output, compiled + + def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout, idx_layout): + ir = f""" +#src_layout = #ttg.{src_layout} +#idx_layout = #ttg.{idx_layout} +{ir}""" + + dtypes = {torch.int32: "i32", torch.float32: "f32", torch.int64: "i64", torch.float64: "f64"} + + src_spec = f"{src.shape[0]}x{src.shape[1]}x{dtypes[src.dtype]}" + indices_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[indices.dtype]}" + output_spec = f"{indices.shape[0]}x{indices.shape[1]}x{dtypes[src.dtype]}" + + pat = r"(%[0-9]+) = tt.gather (%[0-9]+)\[(%[0-9]+)\] {axis = " + pat += str(axis) + pat += r" : i32} : \(tensor\<" + pat += src_spec + pat += r", (#[a-z]+[0-9]+)\>, tensor\<" + pat += indices_spec + pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<" + pat += output_spec + pat += r", (#[a-z]+[0-9]+)\>" + + repl = r""" + %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout> + %idx = ttg.convert_layout \3 : tensor<""" + indices_spec + r""", \5> -> tensor<""" + indices_spec + r""", #idx_layout> + %out = tt.gather %src[%idx] {axis = """ + str( + axis + ) + r""" : i32} : (tensor<""" + src_spec + r""", #src_layout>, tensor<""" + indices_spec + r""", #idx_layout>) -> tensor<""" + output_spec + r""", #idx_layout> + \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>""" + return re.sub(pat, repl, ir) + + src = torch.randn(src_shape, device='cuda') + indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda') + ref = torch.gather(src, axis, indices) + + output, compiled = prepare_kernel(src, axis, indices) + ir = compiled.asm["ttgir"] + ir = inject_layout(ir, src, axis, indices, 
src_layout, indices_layout) + + temp_file = tmp_path / "test_warp_gather.ttgir" + temp_file.write_text(ir) + + kernel = triton.compile(str(temp_file)) + assert ("nvvm.shfl.sync.idx" in kernel.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in kernel.asm["llir"]) + + kernel[(1, 1, 1)](src, indices, output) + + torch.testing.assert_close(output, ref, rtol=0, atol=0) diff --git a/test/Conversion/allocate_shared_memory.mlir b/test/Conversion/allocate_shared_memory.mlir index 345714f5b2..f3c2ed7033 100644 --- a/test/Conversion/allocate_shared_memory.mlir +++ b/test/Conversion/allocate_shared_memory.mlir @@ -1,14 +1,16 @@ // RUN: triton-opt %s --allocate-shared-memory | FileCheck %s +#blocked = #ttg.blocked<{sizePerThread = [32, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> + // CHECK-LABEL: module // CHECK-SAME: ttg.shared = 131072 : i32 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @gather_op // TODO(jeff): Optimize the lowering to reduce shared memory usage. -tt.func @gather_op(%arg0: tensor<1024x256xi32>, %arg1: tensor<128x256xf32>) { +tt.func @gather_op(%arg0: tensor<1024x256xi32, #blocked>, %arg1: tensor<128x256xf32, #blocked>) { // CHECK-NEXT: allocation.offset = 0 : i32 - %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32>, tensor<1024x256xi32>) -> tensor<1024x256xf32> + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<128x256xf32, #blocked>, tensor<1024x256xi32, #blocked>) -> tensor<1024x256xf32, #blocked> tt.return } diff --git a/test/Conversion/gather_to_llvm.mlir b/test/Conversion/gather_to_llvm.mlir new file mode 100644 index 0000000000..28a8a7e6b2 --- /dev/null +++ b/test/Conversion/gather_to_llvm.mlir @@ -0,0 +1,271 @@ +// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s + +// Check the optimized LLVMIR, since InstCombine makes the linear layout +// logic understandable enough (in simple cases) to check correctness by eye. + +#trivial_layout = #ttg.linear<{register = [], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider = #ttg.linear<{register = [[32]], lane = [[1], [2], [4], [8], [16]], warp = [], block = []}> + +#trivial_layout_wider_reg_stride_1 = #ttg.linear<{register = [[1]], lane = [[2], [4], [8], [16], [32]], warp = [], block = []}> + +#trivial_2d_one_col = #ttg.linear<{register = [[0, 1]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [], block = []}> + +#span_2d_cols = #ttg.linear<{register = [[1, 0]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [], block = []}> + +#crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> +#crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + +// Each source element is mapped to a single thread, so we expect one index shuffle. 
+// CHECK-LABEL: @gather_warp_local_trivial +tt.func private @gather_warp_local_trivial(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, but there are two index elements per thread. Expect 2 index shuffles +// with the results packed together. +// CHECK-LABEL: @gather_warp_local_larger_output +tt.func private @gather_warp_local_larger_output(%arg0: tensor<64xi32, #trivial_layout_wider>, %arg1: tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> { + // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32xf32, #trivial_layout>, tensor<64xi32, #trivial_layout_wider>) -> tensor<64xf32, #trivial_layout_wider> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<64xf32, #trivial_layout_wider> +} + +// Each thread has 2 elements of the source tensor, strided 32 apart, so we +// expect two index shuffles, using the MSB to select between the two. 
+// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 31 + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 32 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Same as above, except the RegID comes from the LSB. +// CHECK-LABEL: @gather_warp_local_larger_input +tt.func private @gather_warp_local_larger_input_stride_1(%arg0: tensor<32xi32, #trivial_layout>, %arg1: tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0 + + // CHECK-NEXT: [[REGID:%.*]] = and i32 [[IDX]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX]], 1 + // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[TMP]], 31 + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + + // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31) + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64xf32, #trivial_layout_wider_reg_stride_1>, tensor<32xi32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + + // CHECK-NEXT: [[PICK:%.*]] = icmp eq i32 [[REGID]], 0 + // CHECK-NEXT: [[RES_i32:%.*]] = select i1 [[PICK]], i32 [[RES0]], i32 [[RES1]] + // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float + + // CHECK-NEXT: ret float [[RES]] + tt.return %0 : tensor<32xf32, #trivial_layout> +} + +// Each thread has 1 element in 2 gather columns, so this is the same as the +// trivial case except now it's 2D. We expect 2 independent index shuffles. 
+// CHECK-LABEL: @gather_2d_trivial +tt.func private @gather_2d_trivial(%arg0: tensor<32x2xi32, #trivial_2d_one_col>, %arg1: tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[IDX0]], 31 + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31) + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[IDX1]], 31 + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31) + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #trivial_2d_one_col>, tensor<32x2xi32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #trivial_2d_one_col> +} + +// The single warp is split into two columns. Each column has half contiguous +// threads, each with 2 contiguous elements. Expect 4 index shuffles: two per +// column. Thus, the index should be dependent on the thread id, since the +// register alone is not enough to determine the column. 
+// CHECK-LABEL: @gather_2d_span_2 +tt.func private @gather_2d_span_2(%arg0: tensor<32x2xi32, #span_2d_cols>, %arg1: tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> { + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1 + // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0 + // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1 + + // This uses tid to select between the two columns: + // CHECK-NEXT: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + // CHECK-NEXT: [[COL:%.*]] = and i32 [[TID]], 16 + + // Break the index into reg and thread (within column) components: + // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1 + // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID0]], [[COL]] + + // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // Use the reg id to select between the two results: + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0 + // CHECK-NEXT: [[RES0_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES0:%.*]] = bitcast i32 [[RES0_i32]] to float + + // CHECK-NEXT: [[REGID1:%.*]] = and i32 [[IDX1]], 1 + // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1 + // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 15 + + // CHECK-NEXT: [[SHUFFLE_IDX:%.*]] = or disjoint i32 [[LANEID1]], [[COL]] + + // CHECK-NEXT: [[SRES0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[SHUFFLE_IDX]], i32 31) + // CHECK-NEXT: [[SRES1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[SHUFFLE_IDX]], i32 31) + + // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID1]], 0 + // CHECK-NEXT: [[RES1_i32:%.*]] = select i1 [[PICK0]], i32 [[SRES0]], i32 [[SRES1]] + // CHECK-NEXT: [[RES1:%.*]] = bitcast i32 [[RES1_i32]] to float + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x2xf32, #span_2d_cols>, tensor<32x2xi32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + + // CHECK-NEXT: [[PACKED0:%.*]] = insertvalue { float, float } undef, float [[RES0]], 0 + // CHECK-NEXT: [[PACKED1:%.*]] = insertvalue { float, float } [[PACKED0]], float [[RES1]], 1 + // CHECK-NEXT: ret { float, float } [[PACKED1]] + tt.return %0 : tensor<32x2xf32, #span_2d_cols> +} + +// CHECK-LABEL: @gather_2d_crazy +tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1: tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> { + // The specific logic becomes hard to grasp here. Just check the shuffles. 
+ + // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float, float, float } %1, 0 + // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float, float, float } %1, 1 + // CHECK-NEXT: [[SRC2:%.*]] = extractvalue { float, float, float, float } %1, 2 + // CHECK-NEXT: [[SRC3:%.*]] = extractvalue { float, float, float, float } %1, 3 + + // CHECK: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: [[VALUE2:%.*]] = bitcast float [[SRC2]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]], + + // CHECK: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: [[VALUE3:%.*]] = bitcast float [[SRC3]] to i32 + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + // CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], + // CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]], + + %0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x16xf32, #crazy_2d_src>, tensor<32x16xi32, #crazy_2d_idx>) -> tensor<32x16xf32, #crazy_2d_idx> + tt.return %0 : tensor<32x16xf32, #crazy_2d_idx> +} + +// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM +// from removing unused function results. +tt.func @anchor(%ptr: !llvm.ptr, + %arg0: tensor<32xi32, #trivial_layout>, + %arg1: tensor<32xf32, #trivial_layout>, + %arg2: tensor<64xi32, #trivial_layout_wider>, + %arg3: tensor<64xf32, #trivial_layout_wider>, + %arg4: tensor<64xf32, #trivial_layout_wider_reg_stride_1>, + %arg5: tensor<32x2xi32, #trivial_2d_one_col>, + %arg6: tensor<32x2xf32, #trivial_2d_one_col>, + %arg7: tensor<32x2xi32, #span_2d_cols>, + %arg8: tensor<32x2xf32, #span_2d_cols>, + %arg9: tensor<32x16xi32, #crazy_2d_idx>, + %arg10: tensor<32x16xf32, #crazy_2d_src>) { + + %0 = tt.call @gather_warp_local_trivial(%arg0, %arg1) : (tensor<32xi32, #trivial_layout>, tensor<32xf32, #trivial_layout>) -> tensor<32xf32, #trivial_layout> + %1 = builtin.unrealized_conversion_cast %0 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %1, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %2 = tt.call @gather_warp_local_larger_output(%arg2, %arg1) : (tensor<64xi32, #trivial_layout_wider>, tensor<32xf32, #trivial_layout>) -> tensor<64xf32, #trivial_layout_wider> + %3 = builtin.unrealized_conversion_cast %2 : tensor<64xf32, #trivial_layout_wider> to !llvm.struct<(f32, f32)> + llvm.store volatile %3, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %4 = tt.call @gather_warp_local_larger_input(%arg0, %arg3) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider>) -> tensor<32xf32, #trivial_layout> + %5 = builtin.unrealized_conversion_cast %4 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %5, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %6 = tt.call @gather_warp_local_larger_input_stride_1(%arg0, %arg4) : (tensor<32xi32, #trivial_layout>, tensor<64xf32, #trivial_layout_wider_reg_stride_1>) -> tensor<32xf32, #trivial_layout> + %7 = builtin.unrealized_conversion_cast %6 : tensor<32xf32, #trivial_layout> to !llvm.struct<(f32)> + llvm.store volatile %7, %ptr : !llvm.struct<(f32)>, !llvm.ptr + + %8 = tt.call 
@gather_2d_trivial(%arg5, %arg6) : (tensor<32x2xi32, #trivial_2d_one_col>, tensor<32x2xf32, #trivial_2d_one_col>) -> tensor<32x2xf32, #trivial_2d_one_col> + %9 = builtin.unrealized_conversion_cast %8 : tensor<32x2xf32, #trivial_2d_one_col> to !llvm.struct<(f32, f32)> + llvm.store volatile %9, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %10 = tt.call @gather_2d_span_2(%arg7, %arg8) : (tensor<32x2xi32, #span_2d_cols>, tensor<32x2xf32, #span_2d_cols>) -> tensor<32x2xf32, #span_2d_cols> + %11 = builtin.unrealized_conversion_cast %10 : tensor<32x2xf32, #span_2d_cols> to !llvm.struct<(f32, f32)> + llvm.store volatile %11, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr + + %12 = tt.call @gather_2d_crazy(%arg9, %arg10) : (tensor<32x16xi32, #crazy_2d_idx>, tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> + %13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)> + llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr + + tt.return +} + +} From 99506b7c90613aae79e719426fe0d9b2880a6819 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Tue, 10 Dec 2024 18:15:50 +0000 Subject: [PATCH 3/7] [Pipeliner] Multi-buffer TMA descriptors (#5290) ### Commits in this PR 1. [Pipeliner] Multi-buffer TMA descriptors 2. Add tests for pipelined descriptor creation 3. Be more conservative about number of TMA buffers to allocate 4. Update golden samples 5. Use correct modulus for tma updates --- .../TritonGPU/Transforms/PipeliningUtility.h | 3 + .../TritonNvidiaGPU/Transforms/TMAUtilities.h | 99 +++++ .../Pipeliner/MatmulLoopPipeline.cpp | 205 +++++++++- .../Pipeliner/PipeliningUtility.cpp | 25 +- .../Transforms/TMALowering.cpp | 82 +--- python/src/passes.cc | 1 + .../test/unit/hopper/test_experimental_tma.py | 104 +++++ python/tutorials/09-persistent-matmul.py | 112 +++--- .../samples/simulated-grouped-gemm.mlir | 377 ++++++++++++++++++ .../samples/simulated-grouped-gemm.mlir.in | 104 +++++ .../lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp | 3 +- 11 files changed, 960 insertions(+), 155 deletions(-) create mode 100644 include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h create mode 100644 test/TritonGPU/samples/simulated-grouped-gemm.mlir create mode 100644 test/TritonGPU/samples/simulated-grouped-gemm.mlir.in diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h index cdf22d15d4..414313138b 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -2,6 +2,8 @@ #define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ #include "mlir/Dialect/SCF/IR/SCF.h" +#include +#include #include namespace mlir { @@ -38,6 +40,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, // Return the minClusterId and maxClusterId for the given ForOp. 
std::pair getMinMaxCluster(scf::ForOp &forOp); std::pair getStageCluster(Operation *op); +std::optional> maybeGetStageCluster(Operation *op); void setStageCluster(Operation *op, int stage, int cluster); } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h new file mode 100644 index 0000000000..8f36b7732f --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -0,0 +1,99 @@ +#pragma once +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton::nvidia_gpu { + +constexpr inline int TMA_SIZE_BYTES = 128; +constexpr inline int TMA_ALIGN = 128; + +template +mlir::LogicalResult createTMADesc(mlir::Value tmaPtr, + mlir::triton::MakeTensorDescOp op, + BuilderT &builder) { + using namespace mlir; + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto mkI32Constant = [&](int32_t val) { + return builder.template create( + loc, builder.getI32Type(), builder.getI32IntegerAttr(val)); + }; + + auto elemType = op.getBase().getType().getPointeeType(); + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; + + int32_t contig_dim_size = op.getTensorShape().back(); + int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize; + if (contig_dim_size_in_bytes > 128) { + contig_dim_size = 128 / elemSize; + } + llvm::SmallVector boxDim; + boxDim.push_back(mkI32Constant(contig_dim_size)); + for (int k = op.getTensorShape().size() - 2; k >= 0; --k) { + boxDim.push_back(mkI32Constant(op.getTensorShape()[k])); + } + + int32_t swizzle_mode; + if (contig_dim_size_in_bytes >= 128) { + swizzle_mode = 3; + } else if (contig_dim_size_in_bytes == 64) { + swizzle_mode = 2; + } else if (contig_dim_size_in_bytes == 32) { + swizzle_mode = 1; + } else { + op->emitError() + << "contiguous box dimension must be at least 32 bytes but got " + << contig_dim_size_in_bytes; + return failure(); + } + + Value elemSizeVal = builder.template create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize)); + Value globalStride = builder.template create( + loc, op.getStrides()[0], elemSizeVal); + // TODO: Workaround for ptxas bug, remove when we update ptxas + Value four = builder.template create( + loc, builder.getI64Type(), builder.getI64IntegerAttr(4)); + globalStride = + builder.template create(loc, globalStride, four); + + int elemTypeEnum; + switch (elemSize) { + case 1: { + elemTypeEnum = 0; + break; + } + case 2: { + elemTypeEnum = 1; + break; + } + case 4: { + elemTypeEnum = 2; + break; + } + default: { + op->emitError() + << "Tensor descriptor element type must have size 1, 2, or 4 but got " + << elemSize; + return failure(); + } + } + + auto one = mkI32Constant(1); + builder.template create( + loc, + /*desc_ptr=*/tmaPtr, + /*global_address=*/op.getBase(), + /*box_dim=*/boxDim, + /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]}, + /*global_stride=*/ValueRange{globalStride}, + /*element_strides=*/ValueRange{one, one}, + /*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum), + /*interleave_layout*/ builder.getI32IntegerAttr(0), + /*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode), + /*fill_mode=*/builder.getI32IntegerAttr(0)); + return success(); +} + +} // namespace mlir::triton::nvidia_gpu diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp 
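The createTMADesc helper added in TMAUtilities.h above reduces the descriptor encoding to a small amount of arithmetic on the contiguous dimension: the box is capped at 128 contiguous bytes and the swizzle mode follows from the contiguous width in bytes. A minimal Python sketch of that selection logic (standalone, hypothetical names; not part of the patch):

    def pick_box_and_swizzle(tensor_shape, elem_size_bytes):
        """Mirrors the contiguous-dim clamp and swizzle selection in createTMADesc."""
        contig = tensor_shape[-1]
        contig_bytes = contig * elem_size_bytes
        if contig_bytes > 128:
            # The helper caps the contiguous box dimension at 128 bytes.
            contig = 128 // elem_size_bytes
        # Box dims are emitted innermost dim first, remaining dims in reverse order.
        box_dim = [contig] + list(reversed(tensor_shape[:-1]))
        if contig_bytes >= 128:
            swizzle_mode = 3
        elif contig_bytes == 64:
            swizzle_mode = 2
        elif contig_bytes == 32:
            swizzle_mode = 1
        else:
            raise ValueError(
                f"contiguous box dimension must be at least 32 bytes but got {contig_bytes}")
        return box_dim, swizzle_mode

    # Example: a 128x64 fp16 tile has a 128-byte contiguous row, so the box is
    # kept as-is and swizzle mode 3 is selected.
    assert pick_box_and_swizzle([128, 64], 2) == ([64, 128], 3)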
index 548702fdc0..fc037cd26f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1,28 +1,24 @@ #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" -#include - #define DEBUG_TYPE "triton-matmul-loop-pipeline" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -70,6 +66,30 @@ class OpBuilderWithStage : public OpBuilder { using OpBuilder::create; }; +class OpBuilderForStage : public OpBuilder { + std::optional stage_, cluster_; + +public: + explicit OpBuilderForStage(Operation *op, int stage, int cluster) + : OpBuilder(op, nullptr), stage_(stage), cluster_(cluster) {} + explicit OpBuilderForStage(Operation *op) : OpBuilder(op, nullptr) { + auto sc = tt::maybeGetStageCluster(op); + if (sc) { + stage_ = sc->first; + cluster_ = sc->second; + } + } + + template OpTy create(Args &&...args) { + OpTy op = OpBuilder::create(std::forward(args)...); + + if (stage_ && cluster_) { + tt::setStageCluster(op, *stage_, *cluster_); + } + return op; + } +}; + static bool sameStageCluster(Operation *op1, Operation *op2) { auto [s1, c1] = tt::getStageCluster(op1); auto [s2, c2] = tt::getStageCluster(op2); @@ -708,7 +728,128 @@ getFinalSchedule(scf::ForOp &forOp, int numStages) { return fSchedule; } -// Convert load ops into their asyn version and apply multi-buffering based on +LogicalResult +allocTMABuffers(scf::ForOp forOp, + llvm::MapVector &tmaBufferMapping, + int numStages) { + IRRewriter rewriter(forOp); + + // Create a multi-buffered allocation for each MakeTensorDescOp call in the + // loop + forOp.walk([&](tt::MakeTensorDescOp op) { + // TODO peter: walk to loop yield to find the init value if this is a + // loop-carried value. 
That would save us from allocating another buffer + // just for the init value + auto loc = op.getLoc(); + Value alloc = rewriter.create( + loc, triton::getPointerType(rewriter.getI8Type()), + numStages * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN); + tmaBufferMapping[op.getOperation()] = alloc; + }); + return success(); +} + +template +Value createIncrementModulo(BuilderT &builder, Location loc, Value counter, + Value modulus, Value zero, Value one) { + Value addOne = builder.template create(loc, counter, one); + Value inRangeCond = builder.template create( + loc, arith::CmpIPredicate::slt, addOne, modulus); + return builder.template create(loc, inRangeCond, addOne, + zero); +} + +template +Value subviewTMADescriptor(BuilderT &builder, Location loc, Value alloc, + Value counter) { + Value tmaSizeVal = builder.template create( + loc, ttng::TMA_SIZE_BYTES, 32); + Value offset = + builder.template create(loc, tmaSizeVal, counter); + return builder.template create(loc, alloc.getType(), alloc, + offset); +} + +LogicalResult rewriteTMABufferUpdates( + scf::ForOp forOp, + const llvm::MapVector &tmaBufferMapping, + ArrayRef tmaCounters, int numStages, Value one, Value zero) { + assert(tmaBufferMapping.size() == tmaCounters.size()); + + Value numStagesVal = mlir::OpBuilder(forOp).create( + forOp.getLoc(), numStages, 32); + + for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) { + auto &[op, alloc] = pair; + + // Rewriter MakeTensorDescOp as writing a TMA descriptor + auto makeDescOp = cast(op); + + OpBuilderForStage stageBuilder(makeDescOp); + auto loc = makeDescOp.getLoc(); + + BlockArgument counter = tmaCounters[iOp]; + Value nextBuf = subviewTMADescriptor(stageBuilder, loc, alloc, counter); + if (failed(ttng::createTMADesc(nextBuf, makeDescOp, stageBuilder))) { + return failure(); + } + stageBuilder.create( + loc, nextBuf); + Value nextDesc = stageBuilder.create( + loc, makeDescOp.getType(), nextBuf); + + makeDescOp.getResult().replaceAllUsesWith(nextDesc); + + // Increment the buffer index counter + Value nextCounter = createIncrementModulo(stageBuilder, loc, counter, + numStagesVal, zero, one); + + // If we are in a (potentially nested) if region, propagate the counter + // up to the main for op body scope + Operation *curOp = op; + Operation *parent = op->getParentOp(); + while (parent != forOp.getOperation()) { + auto ifOp = dyn_cast(parent); + if (!ifOp) { + std::string msg; + llvm::raw_string_ostream ss(msg); + ss << "Cannot pipeline MakeTensorDescOp inside:\n"; + parent->print(ss); + ss << "\nOnly scf.if regions are supported"; + return makeDescOp->emitOpError(std::move(msg)); + } + + IRRewriter rewriter(parent); + auto newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, {nextCounter.getType()}); + + auto yieldNewBlock = newIfOp.thenBlock(); + auto yieldOldBlock = newIfOp.elseBlock(); + + if (yieldNewBlock != curOp->getBlock()) { + std::swap(yieldNewBlock, yieldOldBlock); + } + cast(yieldNewBlock->getTerminator()) + .getResultsMutable() + .append(nextCounter); + cast(yieldOldBlock->getTerminator()) + .getResultsMutable() + .append(counter); + + ifOp.erase(); + nextCounter = newIfOp.getResults().back(); + curOp = newIfOp; + parent = newIfOp->getParentOp(); + } + + // Finally, rewrite the loop level yield + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield.setOperand(counter.getArgNumber() - 1, nextCounter); + } + return success(); +} + +// Convert load ops into their async version and apply multi-buffering based on // the required number of buffers. 
static SmallVector createAsyncOps(scf::ForOp &forOp, @@ -732,6 +873,11 @@ createAsyncOps(scf::ForOp &forOp, numBuffers++; }; + llvm::MapVector tmaBufferMapping; + if (failed(allocTMABuffers(forOp, tmaBufferMapping, numStages))) { + llvm_unreachable("TMA pipelining failed"); + } + SmallVector asyncLoads; SmallVector allocs; bool hasTMALoad = false; @@ -755,7 +901,10 @@ createAsyncOps(scf::ForOp &forOp, builder.setInsertionPoint(forOp); Location loc = forOp.getLoc(); - // Create two new counters to index into the allocs. + // Create a counter to index into the allocations per loop iteration. + // NOTE: We create two duplicates values, insertIdx and extractIdx so that the + // pipeliner will re-materialize the value in later stages of the pipeline + // instead of carrying it as a dependency across multiple iterations. Value minusOne = builder.create(loc, -1, 32); Value zero = builder.create(loc, 0, 32); Value one = builder.create(loc, 1, 32); @@ -768,9 +917,19 @@ createAsyncOps(scf::ForOp &forOp, newOperands.push_back(insertIdx); newOperands.push_back(extractIdx); if (hasTMALoad) { + // A single barrier arrival sequence is a "phase" and two phases can + // overlap, provided the phases are differentiated with an alternating + // boolean value. phase = builder.create(loc, 0, 32); newOperands.push_back(phase); } + // Also create one counter per TMA buffer. This allows the descriptors to be + // updated independently without needing to write duplicate of existing tma + // descriptors. + for (int i = 0; i < tmaBufferMapping.size(); ++i) { + newOperands.push_back(zero); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); // Patch the loop to add the new loop carried dependencies. scf::ForOp newForOp = @@ -782,6 +941,20 @@ createAsyncOps(scf::ForOp &forOp, if (phase) { phase = newForOp.getBody()->getArgument(newOperandIndex + 2); } + auto tmaCounters = ArrayRef(newForOp.getBody()->getArguments()) + .slice(newOperandIndex + (phase ? 3 : 2)); + + // Update yield op with temporary yield values + auto forYield = cast(newForOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) { + forYield.getResultsMutable().append(newOperands[i]); + } + + if (failed(rewriteTMABufferUpdates(newForOp, tmaBufferMapping, tmaCounters, + numStages, one, zero))) { + llvm_unreachable("Failed to rewrite TMA ops"); + } + tmaBufferMapping.clear(); // FIXME: loads can be in different (stage, cluster) // Create two counters for the insert and extract indices to avoid creating @@ -815,11 +988,12 @@ createAsyncOps(scf::ForOp &forOp, loadToInfo, numStages); } } - SmallVector newYieldOperands = {insertIdx, extractIdx}; - if (phase) - newYieldOperands.push_back(phase); // Patch the yield with the updated counters. 
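The per-descriptor TMA counters added in this change advance with the wrap-around pattern emitted by createIncrementModulo: add one and select zero once the buffer count is reached, which avoids an integer remainder inside the loop. Each descriptor write then lands TMA_SIZE_BYTES further into its scratch allocation. A standalone Python sketch (hypothetical names, not part of the patch):

    TMA_SIZE_BYTES = 128  # same constant as in TMAUtilities.h

    def advance_counter(counter: int, num_buffers: int) -> int:
        # createIncrementModulo: wrap (counter + 1) back to 0 without a remainder op.
        nxt = counter + 1
        return nxt if nxt < num_buffers else 0

    def descriptor_offset(counter: int) -> int:
        # subviewTMADescriptor: buffered descriptors sit TMA_SIZE_BYTES apart.
        return counter * TMA_SIZE_BYTES

    # With three buffers the counter cycles 0 -> 1 -> 2 -> 0, so descriptor
    # updates rotate through byte offsets 0, 128 and 256 of the allocation.
    c, offsets = 0, []
    for _ in range(4):
        offsets.append(descriptor_offset(c))
        c = advance_counter(c, 3)
    assert offsets == [0, 128, 256, 0]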
- appendToForOpYield(forOp, newYieldOperands); + forYield.setOperand(newOperandIndex + -1, insertIdx); + forYield.setOperand(newOperandIndex + 0, extractIdx); + if (phase) { + forYield.setOperand(newOperandIndex + 1, phase); + } tt::CoarseSchedule coarseSchedule(numStages); coarseSchedule.deSerialize(forOp); @@ -957,12 +1131,11 @@ static int minNumInterleavedCommitOps(Operation *waitOp) { if (thisHistorySum >= minCommitNumber) return minCommitNumber; - // get the value value assigned to the argument coming from outside the - // loop + // get the value assigned to the argument coming from outside the loop Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1]; int min1 = minOverHistories(incomingVal, forOp, thisHistorySum); - // get the value value assigned to the argument coming from the previous + // get the value assigned to the argument coming from the previous // iteration Operation *yieldOp = block->getTerminator(); Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index aab5607707..3ab65c6105 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -7,6 +7,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Casting.h" using namespace mlir; namespace tt = mlir::triton; @@ -198,15 +199,23 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, op->erase(); } -std::pair mlir::triton::getStageCluster(Operation *op) { - auto stage = cast(op->getAttr(mlir::triton::kLoopStageAttrName)) - .getValue() - .getSExtValue(); +std::optional> +mlir::triton::maybeGetStageCluster(Operation *op) { + auto stage = + dyn_cast_if_present(op->getAttr(tt::kLoopStageAttrName)); auto clusterId = - cast(op->getAttr(mlir::triton::kLoopClusterAttrName)) - .getValue() - .getSExtValue(); - return std::make_pair(stage, clusterId); + dyn_cast_if_present(op->getAttr(tt::kLoopClusterAttrName)); + if (!stage || !clusterId) { + return std::nullopt; + } + + return { + {stage.getValue().getSExtValue(), clusterId.getValue().getSExtValue()}}; +} +std::pair mlir::triton::getStageCluster(Operation *op) { + auto res = maybeGetStageCluster(op); + assert(res.has_value() || "Operation is missing stage & cluster attribute"); + return *res; } void mlir::triton::setStageCluster(Operation *op, int stage, int cluster) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index cb9ae9dd0f..f412755d55 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -2,13 +2,13 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" #include @@ -115,87 +115,11 @@ class TMACreateDescLowering : public 
OpRewritePattern { PatternRewriter &rewriter) const override { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); - constexpr auto kTmaNbytes = 128; - constexpr auto kTmaAlignment = 128; auto alloc = rewriter.create( - loc, getPointerType(rewriter.getI8Type()), kTmaNbytes, kTmaAlignment); - auto mkI32Constant = [&](int32_t val) { - return rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(val)); - }; - - auto elemType = op.getBase().getType().getPointeeType(); - auto elemSize = elemType.getIntOrFloatBitWidth() / 8; - - int32_t contig_dim_size = op.getTensorShape().back(); - int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize; - if (contig_dim_size_in_bytes > 128) { - contig_dim_size = 128 / elemSize; - } - llvm::SmallVector boxDim; - boxDim.push_back(mkI32Constant(contig_dim_size)); - for (int k = op.getTensorShape().size() - 2; k >= 0; --k) { - boxDim.push_back(mkI32Constant(op.getTensorShape()[k])); - } - - int32_t swizzle_mode; - if (contig_dim_size_in_bytes >= 128) { - swizzle_mode = 3; - } else if (contig_dim_size_in_bytes == 64) { - swizzle_mode = 2; - } else if (contig_dim_size_in_bytes == 32) { - swizzle_mode = 1; - } else { - op->emitError() - << "contiguous box dimension must be at least 32 bytes but got " - << contig_dim_size_in_bytes; - return failure(); - } - - Value elemSizeVal = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(elemSize)); - Value globalStride = - rewriter.create(loc, op.getStrides()[0], elemSizeVal); - // TODO: Workaround for ptxas bug, remove when we update ptxas - Value four = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(4)); - globalStride = rewriter.create(loc, globalStride, four); - - int elemTypeEnum; - switch (elemSize) { - case 1: { - elemTypeEnum = 0; - break; - } - case 2: { - elemTypeEnum = 1; - break; - } - case 4: { - elemTypeEnum = 2; - break; - } - default: { - op->emitError() - << "Tensor descriptor element type must have size 1, 2, or 4 but got " - << elemSize; + loc, getPointerType(rewriter.getI8Type()), TMA_SIZE_BYTES, TMA_ALIGN); + if (failed(createTMADesc(alloc, op, rewriter))) { return failure(); } - } - - auto one = mkI32Constant(1); - rewriter.create( - loc, - /*desc_ptr=*/alloc.getResult(), - /*global_address=*/op.getBase(), - /*box_dim=*/boxDim, - /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]}, - /*global_stride=*/ValueRange{globalStride}, - /*element_strides=*/ValueRange{one, one}, - /*elem_type*/ rewriter.getI32IntegerAttr(elemTypeEnum), - /*interleave_layout*/ rewriter.getI32IntegerAttr(0), - /*swizzle_mode=*/rewriter.getI32IntegerAttr(swizzle_mode), - /*fill_mode=*/rewriter.getI32IntegerAttr(0)); rewriter.create( loc, alloc.getResult()); auto newDesc = rewriter.create( diff --git a/python/src/passes.cc b/python/src/passes.cc index 235eba4465..b0efc3cb88 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -31,6 +31,7 @@ void init_triton_passes_common(py::module &&m) { ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); ADD_PASS_WRAPPER_0("add_cse", createCSEPass); ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); + ADD_PASS_WRAPPER_0("print_ir", createPrintIRPass); } void init_triton_passes_ttir(py::module &&m) { diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 23065953d6..c32c017dc7 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -538,3 
+538,107 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): ) torch.testing.assert_close(ref_out, A) assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] + + +@triton.jit +def batched_gemm_kernel(a_ptr, b_ptr, c_ptr, # + B, M, N, K, # + dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SMS: tl.constexpr): + start_pid = tl.program_id(axis=0) + num_tiles_m = tl.cdiv(M, BLOCK_M) + num_tiles_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles_per_batch = num_tiles_m * num_tiles_n + num_tiles = B * num_tiles_per_batch + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + tile_m = 0 + tile_n = 0 + tile_b = 0 + + offs_m = 0 + offs_n = 0 + offs_b = 0 + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N]) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for _ in range(k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + tile_b = tile_id // num_tiles_per_batch + tile_m = (tile_id // num_tiles_n) % num_tiles_m + tile_n = tile_id % num_tiles_n + + offs_b = tile_b + offs_m = tile_m * BLOCK_M + offs_n = tile_n * BLOCK_N + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], + [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], + [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], + [BLOCK_M, BLOCK_N]) + + offs_k = ki * BLOCK_K + + a = a_desc.load([offs_m, offs_k]) + b = b_desc.load([offs_n, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + c_desc.store([offs_m, offs_n], c) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@requires_tma +def test_tensor_descriptor_batched_gemm(): + device = "cuda" + B, M, N, K = 2, 1024, 1024, 128 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 + NUM_SMS = 96 + num_stages = 3 + + grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), ) + + a = torch.randn((B, M, K), device=device, dtype=torch.float16) + b = torch.randn((B, N, K), device=device, dtype=torch.float16) + c = torch.empty((B, M, N), device=device, dtype=torch.float16) + + expect = torch.bmm(a, b.mT) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + # TODO: should only need num_stages * 3 descriptors per SM + assert size == 128 * 3 * (num_stages + 1) * grid[0] + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + h = batched_gemm_kernel[grid]( + a, b, c, # + B, M, N, K, # + tl.float16, # + BLOCK_M, BLOCK_N, BLOCK_K, # + NUM_SMS, # + num_stages=num_stages, num_warps=8) + print(h.n_regs) + torch.cuda.synchronize() + + torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 49a8bb32c4..4b4a08857d 100644 --- a/python/tutorials/09-persistent-matmul.py +++ 
b/python/tutorials/09-persistent-matmul.py @@ -28,6 +28,8 @@ import triton.profiler as proton from contextlib import contextmanager +from typing import Optional + if torch.cuda.is_available(): from triton._C.libtriton import nvidia cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) @@ -374,8 +376,7 @@ def matmul_tma_persistent(a, b): @triton.jit(launch_metadata=_matmul_launch_metadata) -def matmul_kernel_device_tma_persistent(workspace_ptr, # - tiles_per_update: tl.constexpr, # +def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # a_ptr, b_ptr, c_ptr, # M, N, K, # BLOCK_SIZE_M: tl.constexpr, # @@ -391,24 +392,24 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - TMA_SIZE: tl.constexpr = 128 - workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - c_desc_ptr = workspace_base + 2 * TMA_SIZE - - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K], - element_ty=a_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr, - load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K], - element_ty=b_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr, - load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N], - element_ty=c_ptr.dtype.element_ty) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) tiles_per_SM = num_tiles // NUM_SMS if start_pid < num_tiles % NUM_SMS: @@ -426,6 +427,9 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # num_pid_in_group = GROUP_SIZE_M * num_pid_n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Create an opaque value to prevent the descriptor creation from being + # hoisted out of the loop + zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1) for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) @@ -434,21 +438,24 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # # Simulate a grouped gemm if ni == tiles_per_update: - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr, - load_size=[BLOCK_SIZE_M, - BLOCK_SIZE_K], global_size=[M, K], - element_ty=a_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr, - load_size=[BLOCK_SIZE_N, - BLOCK_SIZE_K], global_size=[N, K], - element_ty=b_ptr.dtype.element_ty) - tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr, - load_size=[BLOCK_SIZE_M, - BLOCK_SIZE_N], global_size=[M, N], - element_ty=c_ptr.dtype.element_ty) - 
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + a_desc = tl._experimental_make_tensor_descriptor( + a_ptr + zero, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl._experimental_make_tensor_descriptor( + b_ptr + zero, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + c_desc = tl._experimental_make_tensor_descriptor( + c_ptr + zero, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) ni = 0 tile_id += NUM_SMS @@ -463,19 +470,19 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, # offs_k = ki * BLOCK_SIZE_K - a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) - b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) accumulator = tl.dot(a, b.T, accumulator) if ki == k_tiles - 1: c = accumulator.to(dtype) - tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + c_desc.store([offs_am, offs_bn], c) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) -def matmul_device_tma_persistent(a, b, tiles_per_update): +def matmul_descriptor_persistent(a, b, tiles_per_update): # Autotuner does not work with TMA. Use manual config. configs = { torch.float8_e4m3fn: { @@ -497,12 +504,15 @@ def matmul_device_tma_persistent(a, b, tiles_per_update): c = torch.empty((M, N), device=a.device, dtype=dtype) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - tma_size = 128 - workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda") + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) - matmul_kernel_device_tma_persistent[grid]( - workspace, # + matmul_kernel_descriptor_persistent[grid]( tiles_per_update, # a, b, c, # M, N, K, # @@ -576,7 +586,7 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000): bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) if supports_tma(): bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) - bench_fn(reps, warmup_reps, matmul_device_tma_persistent, a, b, tiles_per_update) + bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update) def validate(M, N, K, dtype, tiles_per_update): @@ -589,7 +599,7 @@ def validate(M, N, K, dtype, tiles_per_update): naive_result = matmul(a, b.T) persistent_result = matmul_persistent(a, b.T) tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None - device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None + descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None if torch_result is not None: naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16), @@ -602,9 +612,9 @@ def validate(M, N, K, dtype, tiles_per_update): if tma_persistent_result is not None: naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16), 
tma_persistent_result.to(torch.float16), atol=1.0) else "❌" - if device_tma_persistent_result is not None: - naive_vs_device_tma_persistent = "✅" if torch.allclose(cublas_result.to( - torch.float16), device_tma_persistent_result.to(torch.float16), atol=1.0) else "❌" + if descriptor_persistent_result is not None: + naive_vs_descriptor_persistent = "✅" if torch.allclose(cublas_result.to( + torch.float16), descriptor_persistent_result.to(torch.float16), atol=1.0) else "❌" print(f"M={M}, N={N}, K={K} verification naive vs: ", end="") if torch_result is not None: print(f"torch: {naive_vs_torch} ", end="") @@ -613,8 +623,8 @@ def validate(M, N, K, dtype, tiles_per_update): print(f"persistent: {naive_vs_persistent} ", end="") if tma_persistent_result is not None: print(f"TMA persistent: {naive_vs_tma_persistent} ", end="") - if device_tma_persistent_result is not None: - print(f"Device TMA persistent: {naive_vs_device_tma_persistent} ", end="") + if descriptor_persistent_result is not None: + print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="") print() @@ -639,7 +649,7 @@ def show_profile(precision, profile_name): type=int, default=1, help= - "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel", + "Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel", ) parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16") args = parser.parse_args() diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir b/test/TritonGPU/samples/simulated-grouped-gemm.mlir new file mode 100644 index 0000000000..24cbc4dc01 --- /dev/null +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir @@ -0,0 +1,377 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. 
+ +// CHECK: #[[$ATTR_0:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_4:.+]] = #ttg.shared_memory +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \ +// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// CHECK-LABEL: tt.func public @matmul_kernel_descriptor_persistent( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_6:.*]] = arith.constant 4 : i64 +// CHECK: %[[VAL_7:.*]] = arith.constant 2 : i64 +// CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_10:.*]] = arith.constant false +// CHECK: %[[VAL_11:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant 132 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant -1 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_15:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_17:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_18:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_19:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_20:.*]] = arith.constant 127 : i32 +// CHECK: %[[VAL_21:.*]] = arith.constant 255 : i32 +// CHECK: %[[VAL_22:.*]] = arith.constant 63 : i32 +// CHECK: %[[VAL_23:.*]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_24:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_3]], %[[VAL_20]] : i32 +// CHECK: %[[VAL_26:.*]] = arith.divsi %[[VAL_25]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_4]], %[[VAL_21]] : i32 +// CHECK: %[[VAL_28:.*]] = arith.divsi %[[VAL_27]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_5]], %[[VAL_22]] : i32 +// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_26]], %[[VAL_28]] : i32 +// CHECK: %[[VAL_32:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 +// CHECK: %[[VAL_33:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_32]], %[[VAL_19]]] : , > +// CHECK: %[[VAL_34:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_32]], %[[VAL_19]]] : , > +// CHECK: %[[VAL_35:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 +// CHECK: %[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_35]], %[[VAL_19]]] : , > +// CHECK: 
%[[VAL_37:.*]] = arith.divsi %[[VAL_31]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_38:.*]] = arith.remsi %[[VAL_31]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_39:.*]] = arith.cmpi slt, %[[VAL_24]], %[[VAL_38]] : i32 +// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (i32) { +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_37]], %[[VAL_11]] : i32 +// CHECK: scf.yield %[[VAL_41]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_37]] : i32 +// CHECK: } +// CHECK: %[[VAL_42:.*]] = arith.subi %[[VAL_24]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_43:.*]] = arith.muli %[[VAL_28]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_44:.*]] = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 +// CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_30]], %[[VAL_40]] : i32 +// CHECK: %[[VAL_46:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_47:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_48:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_49:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_50:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_51:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_52:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_53]], 1 : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_54:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_54]], 1 : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_55:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_55]], 1 : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_56:.*]] = arith.cmpi sgt, %[[VAL_45]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[VAL_24]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_14]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_59:.*]]:2 = scf.if %[[VAL_56]] -> (i32, i32) { +// CHECK: %[[VAL_60:.*]] = arith.divsi %[[VAL_24]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_61:.*]] = arith.muli %[[VAL_60]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_62:.*]] = arith.subi %[[VAL_26]], %[[VAL_61]] : i32 +// CHECK: %[[VAL_63:.*]] = arith.minsi %[[VAL_62]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_64:.*]] = arith.remsi %[[VAL_24]], %[[VAL_63]] : i32 +// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_61]], %[[VAL_64]] : i32 +// CHECK: %[[VAL_66:.*]] = arith.remsi %[[VAL_24]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_67:.*]] = arith.divsi %[[VAL_66]], %[[VAL_63]] : i32 +// CHECK: %[[VAL_68:.*]] = arith.muli %[[VAL_65]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_69:.*]] = arith.muli %[[VAL_67]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_68]], %[[VAL_69]] : i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_14]], %[[VAL_14]] : i32, i32 
+// CHECK: } +// CHECK: %[[VAL_70:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_70]], 49152, %[[VAL_56]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_71:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_14]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_72:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_33]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_72]]{{\[}}%[[VAL_73:.*]]#0, %[[VAL_14]]] %[[VAL_71]], %[[VAL_70]], %[[VAL_56]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_74:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_14]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_75:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_34]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_75]]{{\[}}%[[VAL_73]]#1, %[[VAL_14]]] %[[VAL_74]], %[[VAL_70]], %[[VAL_56]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_76:.*]] = arith.cmpi sgt, %[[VAL_45]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_77:.*]] = arith.cmpi ne, %[[VAL_46]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_78:.*]] = arith.extui %[[VAL_77]] : i1 to i32 +// CHECK: %[[VAL_79:.*]] = arith.cmpi eq, %[[VAL_78]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_80:.*]] = arith.andi %[[VAL_76]], %[[VAL_79]] : i1 +// CHECK: %[[VAL_81:.*]]:10 = scf.if %[[VAL_80]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_82:.*]] = arith.addi %[[VAL_58]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_82]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_14]], %[[VAL_82]] : i32 +// CHECK: %[[VAL_85:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_86:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_87:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_88:.*]]:3 = scf.if %[[VAL_83]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>) { +// CHECK: %[[VAL_89:.*]] = tt.addptr %[[VAL_0]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_91:.*]] = arith.shrsi %[[VAL_90]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_47]], %[[VAL_89]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_91]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_47]] : !tt.ptr +// CHECK: %[[VAL_92:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_47]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_93:.*]] = tt.addptr %[[VAL_1]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_94:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_95:.*]] = arith.shrsi %[[VAL_94]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_48]], %[[VAL_93]], 
{{\[}}%[[VAL_18]], %[[VAL_17]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_95]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_48]] : !tt.ptr +// CHECK: %[[VAL_96:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_48]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_97:.*]] = tt.addptr %[[VAL_2]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_98:.*]] = arith.muli %[[VAL_35]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_99:.*]] = arith.shrsi %[[VAL_98]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_49]], %[[VAL_97]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_99]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_49]] : !tt.ptr +// CHECK: %[[VAL_100:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_49]] : !tt.ptr to !tt.tensordesc> +// CHECK: scf.yield %[[VAL_92]], %[[VAL_96]], %[[VAL_100]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_33]], %[[VAL_34]], %[[VAL_36]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } +// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_57]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_102:.*]] = arith.divsi %[[VAL_101]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_103:.*]] = arith.muli %[[VAL_102]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_104:.*]] = arith.subi %[[VAL_26]], %[[VAL_103]] : i32 +// CHECK: %[[VAL_105:.*]] = arith.minsi %[[VAL_104]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_106:.*]] = arith.remsi %[[VAL_101]], %[[VAL_105]] : i32 +// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_103]], %[[VAL_106]] : i32 +// CHECK: %[[VAL_108:.*]] = arith.remsi %[[VAL_101]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_109:.*]] = arith.divsi %[[VAL_108]], %[[VAL_105]] : i32 +// CHECK: %[[VAL_110:.*]] = arith.muli %[[VAL_107]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_111:.*]] = arith.muli %[[VAL_109]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_112:.*]]#0, %[[VAL_112]]#1, %[[VAL_112]]#2, %[[VAL_101]], %[[VAL_84]], %[[VAL_110]], %[[VAL_111]], %[[VAL_85]], %[[VAL_86]], %[[VAL_87]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_33]], %[[VAL_34]], %[[VAL_36]], %[[VAL_57]], %[[VAL_58]], %[[VAL_73]]#0, %[[VAL_73]]#1, %[[VAL_14]], %[[VAL_14]], %[[VAL_14]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_113:.*]] = arith.muli %[[VAL_78]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_114:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_114]], 49152, %[[VAL_76]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_115:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_11]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_116:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_117:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: 
ttng.async_tma_copy_global_to_local %[[VAL_116]]{{\[}}%[[VAL_117]]#5, %[[VAL_113]]] %[[VAL_115]], %[[VAL_114]], %[[VAL_76]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_118:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_11]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_119:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_117]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_119]]{{\[}}%[[VAL_117]]#6, %[[VAL_113]]] %[[VAL_118]], %[[VAL_114]], %[[VAL_76]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_120:.*]] = ttg.local_alloc : () -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_121:.*]]:24 = scf.for %[[VAL_122:.*]] = %[[VAL_14]] to %[[VAL_45]] step %[[VAL_11]] iter_args(%[[VAL_123:.*]] = %[[VAL_78]], %[[VAL_124:.*]] = %[[VAL_117]]#0, %[[VAL_125:.*]] = %[[VAL_117]]#1, %[[VAL_126:.*]] = %[[VAL_117]]#2, %[[VAL_127:.*]] = %[[VAL_117]]#3, %[[VAL_128:.*]] = %[[VAL_117]]#4, %[[VAL_129:.*]] = %[[VAL_117]]#5, %[[VAL_130:.*]] = %[[VAL_117]]#6, %[[VAL_131:.*]] = %[[VAL_23]], %[[VAL_132:.*]] = %[[VAL_10]], %[[VAL_133:.*]] = %[[VAL_11]], %[[VAL_134:.*]] = %[[VAL_13]], %[[VAL_135:.*]] = %[[VAL_14]], %[[VAL_136:.*]] = %[[VAL_117]]#7, %[[VAL_137:.*]] = %[[VAL_117]]#8, %[[VAL_138:.*]] = %[[VAL_117]]#9, %[[VAL_139:.*]] = %[[VAL_14]], %[[VAL_140:.*]] = %[[VAL_78]], %[[VAL_141:.*]] = %[[VAL_36]], %[[VAL_142:.*]] = %[[VAL_117]]#2, %[[VAL_143:.*]] = %[[VAL_73]]#0, %[[VAL_144:.*]] = %[[VAL_117]]#5, %[[VAL_145:.*]] = %[[VAL_73]]#1, %[[VAL_146:.*]] = %[[VAL_117]]#6) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) : i32 { +// CHECK: %[[VAL_147:.*]] = arith.subi %[[VAL_45]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_148:.*]] = arith.cmpi slt, %[[VAL_122]], %[[VAL_147]] : i32 +// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_123]], %[[VAL_46]] : i32 +// CHECK: %[[VAL_150:.*]] = arith.addi %[[VAL_123]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_151:.*]] = arith.select %[[VAL_149]], %[[VAL_14]], %[[VAL_150]] : i32 +// CHECK: %[[VAL_152:.*]] = arith.cmpi eq, %[[VAL_151]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_153:.*]] = arith.andi %[[VAL_148]], %[[VAL_152]] : i1 +// CHECK: %[[VAL_154:.*]]:10 = scf.if %[[VAL_153]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_155:.*]] = arith.addi %[[VAL_128]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_156:.*]] = arith.cmpi eq, %[[VAL_155]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_157:.*]] = arith.select %[[VAL_156]], %[[VAL_14]], %[[VAL_155]] : i32 +// CHECK: %[[VAL_158:.*]]:6 = scf.if %[[VAL_156]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32) { +// CHECK: %[[VAL_159:.*]] = tt.addptr %[[VAL_0]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_160:.*]] = arith.muli %[[VAL_136]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_161:.*]] = tt.addptr %[[VAL_47]], %[[VAL_160]] : !tt.ptr, i32 +// CHECK: %[[VAL_162:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_163:.*]] = arith.shrsi %[[VAL_162]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_161]], %[[VAL_159]], 
{{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_163]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_161]] : !tt.ptr +// CHECK: %[[VAL_164:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_161]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_165:.*]] = arith.addi %[[VAL_136]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_166:.*]] = arith.cmpi slt, %[[VAL_165]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_165]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_168:.*]] = tt.addptr %[[VAL_1]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_169:.*]] = arith.muli %[[VAL_137]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_170:.*]] = tt.addptr %[[VAL_48]], %[[VAL_169]] : !tt.ptr, i32 +// CHECK: %[[VAL_171:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_172:.*]] = arith.shrsi %[[VAL_171]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_170]], %[[VAL_168]], {{\[}}%[[VAL_18]], %[[VAL_17]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_172]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_170]] : !tt.ptr +// CHECK: %[[VAL_173:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_170]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_174:.*]] = arith.addi %[[VAL_137]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_175:.*]] = arith.cmpi slt, %[[VAL_174]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_176:.*]] = arith.select %[[VAL_175]], %[[VAL_174]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_177:.*]] = tt.addptr %[[VAL_2]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_178:.*]] = arith.muli %[[VAL_138]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_179:.*]] = tt.addptr %[[VAL_49]], %[[VAL_178]] : !tt.ptr, i32 +// CHECK: %[[VAL_180:.*]] = arith.muli %[[VAL_35]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_181:.*]] = arith.shrsi %[[VAL_180]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_179]], %[[VAL_177]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_181]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_179]] : !tt.ptr +// CHECK: %[[VAL_182:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_179]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_183:.*]] = arith.addi %[[VAL_138]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_184:.*]] = arith.cmpi slt, %[[VAL_183]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_185:.*]] = arith.select %[[VAL_184]], %[[VAL_183]], %[[VAL_14]] : i32 +// CHECK: scf.yield %[[VAL_164]], %[[VAL_173]], %[[VAL_182]], %[[VAL_167]], %[[VAL_176]], %[[VAL_185]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_136]], %[[VAL_137]], %[[VAL_138]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_186:.*]] = arith.addi %[[VAL_127]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_187:.*]] = arith.divsi %[[VAL_186]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_188:.*]] = arith.muli %[[VAL_187]], 
%[[VAL_15]] : i32 +// CHECK: %[[VAL_189:.*]] = arith.subi %[[VAL_26]], %[[VAL_188]] : i32 +// CHECK: %[[VAL_190:.*]] = arith.minsi %[[VAL_189]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_191:.*]] = arith.remsi %[[VAL_186]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_192:.*]] = arith.addi %[[VAL_188]], %[[VAL_191]] : i32 +// CHECK: %[[VAL_193:.*]] = arith.remsi %[[VAL_186]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_194:.*]] = arith.divsi %[[VAL_193]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_195:.*]] = arith.muli %[[VAL_192]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_196:.*]] = arith.muli %[[VAL_194]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_197:.*]]#0, %[[VAL_197]]#1, %[[VAL_197]]#2, %[[VAL_186]], %[[VAL_157]], %[[VAL_195]], %[[VAL_196]], %[[VAL_197]]#3, %[[VAL_197]]#4, %[[VAL_197]]#5 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_127]], %[[VAL_128]], %[[VAL_129]], %[[VAL_130]], %[[VAL_136]], %[[VAL_137]], %[[VAL_138]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_198:.*]] = arith.addi %[[VAL_134]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_199:.*]] = arith.cmpi slt, %[[VAL_198]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_200:.*]] = arith.select %[[VAL_199]], %[[VAL_198]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_201:.*]] = arith.xori %[[VAL_135]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_202:.*]] = arith.select %[[VAL_199]], %[[VAL_135]], %[[VAL_201]] : i32 +// CHECK: %[[VAL_203:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_200]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.wait_barrier %[[VAL_203]], %[[VAL_202]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_204:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_200]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_205:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_200]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_206:.*]] = ttg.memdesc_trans %[[VAL_204]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_207:.*]] = ttng.warp_group_dot %[[VAL_205]], %[[VAL_206]], %[[VAL_131]], %[[VAL_132]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -> tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_208:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_207]], %[[VAL_205]], %[[VAL_206]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_209:.*]] = arith.addi %[[VAL_133]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_210:.*]] = arith.cmpi slt, %[[VAL_209]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_211:.*]] = arith.select %[[VAL_210]], %[[VAL_209]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_212:.*]] = arith.muli %[[VAL_151]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_213:.*]] = ttg.memdesc_subview 
%[[VAL_52]]{{\[}}%[[VAL_211]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_213]], 49152, %[[VAL_148]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_214:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_211]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_215:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_216:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_215]]{{\[}}%[[VAL_216]]#5, %[[VAL_212]]] %[[VAL_214]], %[[VAL_213]], %[[VAL_148]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_217:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_211]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_218:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_216]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_218]]{{\[}}%[[VAL_216]]#6, %[[VAL_212]]] %[[VAL_217]], %[[VAL_213]], %[[VAL_148]] : , <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> <256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_219:.*]] = arith.cmpi eq, %[[VAL_139]], %[[VAL_46]] : i32 +// CHECK: %[[VAL_220:.*]] = arith.cmpi ne, %[[VAL_139]], %[[VAL_46]] : i32 +// CHECK: scf.if %[[VAL_219]] { +// CHECK: %[[VAL_221:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_208]]#0, %[[VAL_205]], %[[VAL_206]] {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_222:.*]] = arith.truncf %[[VAL_221]]#0 : tensor<128x256xf32, #[[$ATTR_0]]> to tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_store %[[VAL_222]], %[[VAL_120]] : tensor<128x256xf16, #[[$ATTR_0]]> -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: ttng.fence_async_shared {bCluster = false} +// CHECK: %[[VAL_223:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_141]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_223]]{{\[}}%[[VAL_143]], %[[VAL_145]]] %[[VAL_120]] : , <128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: } +// CHECK: scf.yield %[[VAL_151]], %[[VAL_216]]#0, %[[VAL_216]]#1, %[[VAL_216]]#2, %[[VAL_216]]#3, %[[VAL_216]]#4, %[[VAL_216]]#5, %[[VAL_216]]#6, %[[VAL_208]]#0, %[[VAL_220]], %[[VAL_211]], %[[VAL_200]], %[[VAL_202]], %[[VAL_216]]#7, %[[VAL_216]]#8, %[[VAL_216]]#9, %[[VAL_140]], %[[VAL_151]], %[[VAL_142]], %[[VAL_216]]#2, %[[VAL_144]], %[[VAL_216]]#5, %[[VAL_146]], %[[VAL_216]]#6 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 +// CHECK: } +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_dealloc %[[VAL_120]] : !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_224:.*]] = ttng.warp_group_dot_wait %[[VAL_225:.*]]#8 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_226:.*]] = 
ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_227:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_227]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_228:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_228]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_229:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_229]] : <1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttg.local_dealloc %[[VAL_50]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_51]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: tt.return +// CHECK: } +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = 
%12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi %40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : 
i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + scf.yield %true : i1 + } {loop.cluster = 3 : i32, loop.stage = 2 : i32} + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + } + tt.return + } +} diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in new file mode 100644 index 0000000000..7b096e8883 --- /dev/null +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in @@ -0,0 +1,104 @@ +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \ +// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : 
i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi 
%40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + 
scf.yield %true : i1 + } {loop.cluster = 3 : i32, loop.stage = 2 : i32} + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + } + tt.return + } +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp index 459a00c1a1..1bc6eb7cf0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp @@ -9,12 +9,13 @@ #include "mlir/Transforms/DialectConversion.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" using namespace mlir; using namespace mlir::triton; +using namespace mlir::triton::nvidia_gpu; namespace { -constexpr int64_t TMA_SIZE_BYTES = 128; void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value outPtr, From f5d541ce7addb4bbf468d87c67356768b6da7405 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Tue, 10 Dec 2024 11:18:40 -0800 Subject: [PATCH 4/7] [Backend] Reorder return values in `emitHardwareTuple` (NFC) (#5390) @lezcano pointed out in another PR that the order is confusing because typically we list the lane ID, warp ID, and blockID in this order. --- include/triton/Conversion/TritonGPUToLLVM/Utility.h | 4 ++-- lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp | 2 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 4b2179611a..fc1b731d15 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1123,8 +1123,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, return idx; } -// Emit code to compute the (blockId, warpId, laneId) for the current thread. -std::tuple +// Emit code to compute the (laneId, warpId, blockId) for the current thread. +std::tuple emitHardwareTuple(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, bool withCTAOffset, unsigned threadsPerWarp); diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index faf781369e..673ac8e74f 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -258,7 +258,7 @@ void GatherOpConversion::emitWarpLocalGather( SmallVector idxValues = unpackLLElements(loc, adaptor.getIndices(), rewriter); - auto [blockId, warpId, laneId] = + auto [laneId, warpId, blockId] = emitHardwareTuple(loc, rewriter, targetInfo, /*withCTAOffset=*/true, srcLayout.getInDimSize(kLane)); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 02b3b121f4..eb2c82cfc9 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -110,7 +110,7 @@ std::tuple emitHardwareTuple(Location loc, Value warpId = udiv(threadId, threadsPerWarp); Value blockId = withCTAOffset ? 
target.getClusterCTAId(rewriter, loc) : i32_val(0); - return {blockId, warpId, laneId}; + return {laneId, warpId, blockId}; } SmallVector> @@ -130,7 +130,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, StringAttr kWarp = str_attr("warp"); StringAttr kBlock = str_attr("block"); - auto [blockId, warpId, laneId] = emitHardwareTuple( + auto [laneId, warpId, blockId] = emitHardwareTuple( loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane)); unsigned rank = shape.size(); SmallVector> ret; @@ -353,7 +353,7 @@ bool emitTransferBetweenRegistersAndShared( std::min(regToSharedLayout->getNumConsecutiveInOut(), maxVecElems.value_or(std::numeric_limits::max())); - auto [blockId, warpId, laneId] = + auto [laneId, warpId, blockId] = emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false, regToSharedLayout->getInDimSize(kLane)); @@ -746,7 +746,7 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, auto instrShape = mmaLayout.getInstrShape(); SmallVector mmaColIdx(2); SmallVector mmaRowIdx(2); - auto [blockId, warpId, laneId] = emitHardwareTuple( + auto [laneId, warpId, blockId] = emitHardwareTuple( loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32); // TODO: fix the bug in MMAEncodingAttr document SmallVector multiDimWarpId(2); From c646244e2a901ed6f5eaf40f60ae194b18b01b2a Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Tue, 10 Dec 2024 23:38:01 +0000 Subject: [PATCH 5/7] [CI] Cleanup AMD worker at the end of a job (#5391) The AMD runner persists changes to the file system between jobs, so the caches need to be manually cleaned up. Closes #5384 --- .github/workflows/integration-tests.yml | 9 ++++----- .github/workflows/integration-tests.yml.in | 10 ++++------ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 9ae78a64b6..4efd64a5db 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -408,10 +408,6 @@ jobs: cd python ccache --zero-stats pip install -v -e '.[tests]' - - name: Clean up after an unsuccessful build - if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }} - run: | - rm -rf ~/.triton - name: CCache Stats run: ccache --print-stats - name: Run lit tests @@ -477,8 +473,11 @@ jobs: ~/.ccache key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - name: Clean up caches + # Always cleanup the worker, even if builds or tests failed + if: always() run: | - rm -rf ~/.triton/cache + rm -rf ~/.triton + rm -rf ~/.ccache Build-Tests: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-MACOS != '' diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 6b58be6571..fd4b0e980e 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -402,11 +402,6 @@ jobs: ccache --zero-stats pip install -v -e '.[tests]' - - name: Clean up after an unsuccessful build - if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }} - run: | - rm -rf ~/.triton - - *print-ccache-stats - *run-lit-tests-step @@ -442,8 +437,11 @@ jobs: - *save-build-artifacts-step - name: Clean up caches + # Always cleanup the worker, even if builds or tests failed + if: always() run: | - rm -rf ~/.triton/cache + rm -rf ~/.triton + rm -rf ~/.ccache Build-Tests: needs: 
Runner-Preparation From f257479997db6970b3329ba524a02cd508b9d87d Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Tue, 10 Dec 2024 23:42:16 +0000 Subject: [PATCH 6/7] [AMD] Reland "Add gfx950 target definitions" (#5392) This relands https://github.com/triton-lang/triton/pull/5392 to enable new arch target since backend support has been added--it doesn't depend on the reverted LLVM upgrade in https://github.com/triton-lang/triton/pull/5341; basic necessary enablement is already included in the current llvm version we're using. --- third_party/amd/backend/include/hsa/amd_hsa_elf.h | 2 +- third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h index 74f15d7d7a..65a77f041b 100644 --- a/third_party/amd/backend/include/hsa/amd_hsa_elf.h +++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h @@ -136,7 +136,7 @@ enum : unsigned { EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c, EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d, EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e, - EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f, + EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f, EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050, EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051, EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp index 63fb972f79..7ab6fd68a5 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp @@ -11,6 +11,7 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { // CDNA ISA cases switch (kind) { + case llvm::AMDGPU::GK_GFX950: case llvm::AMDGPU::GK_GFX942: case llvm::AMDGPU::GK_GFX941: case llvm::AMDGPU::GK_GFX940: From 6f5baf6801b44e51b7ba8eedaa619e39c912bef6 Mon Sep 17 00:00:00 2001 From: Yuanwei Fang Date: Tue, 10 Dec 2024 17:46:44 -0800 Subject: [PATCH 7/7] Allow TRITON_KERNEL_OVERRIDE on .amdgcn and .hsaco files (#5394) Enable the TRITON_KERNEL_OVERRIDE feature to work on AMD assembly and binary. Currently, for the backends, it only works on Nvidia `ptx` and `cubin`. --------- Co-authored-by: Yuanwei Fang --- README.md | 10 +++++----- python/triton/compiler/compiler.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 4e5bf22346..79f42adf93 100644 --- a/README.md +++ b/README.md @@ -211,10 +211,10 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma). - `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks. -- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx. -- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx when `TRITON_KERNEL_DUMP` is set to 1. -- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx at the beginning of each compilation stage. -- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx files when `TRITON_KERNEL_OVERRIDE` is set to 1. +- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn. 
+- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1. +- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage. +- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1. **Kernel Override Steps** @@ -224,7 +224,7 @@ export TRITON_KERNEL_DUMP=1 export TRITON_DUMP_DIR= export TRITON_KERNEL_OVERRIDE=1 export TRITON_OVERRIDE_DIR= -# Step 1: Run the kernel once to dump kernel's IRs and ptx in $TRITON_DUMP_DIR +# Step 1: Run the kernel once to dump kernel's IRs and ptx/amdgcn in $TRITON_DUMP_DIR # Step 2: Copy $TRITON_DUMP_DIR/ to $TRITON_OVERRIDE_DIR # Step 3: Delete the stages that you do not want to override and modify the stage you do want to override # Step 4: Run the kernel again to see the overridden result diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 52b8afea14..2091109824 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -170,9 +170,9 @@ def parse(full_name, ext, context): module = ir.parse_mlir_module(full_name, context) module.context = context return module - if ext == "llir" or ext == "ptx": + if ext == "llir" or ext == "ptx" or ext == "amdgcn": return Path(full_name).read_text() - if ext == "cubin": + if ext == "cubin" or ext == "hsaco": return Path(full_name).read_bytes()
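
As a rough illustration of the extension dispatch that the last hunk extends (text-like stages such as `llir`/`ptx`/`amdgcn` are read as text, binary stages such as `cubin`/`hsaco` as bytes), here is a minimal standalone sketch; the helper name `load_override_artifact` and the `<override_dir>/kernel.<ext>` layout are assumptions for illustration, not Triton's actual API, and MLIR stages (handled via `ir.parse_mlir_module` in `parse()`) are left out.

    from pathlib import Path

    # Assumed stage groupings, mirroring the branches in parse() above.
    TEXT_EXTS = {"llir", "ptx", "amdgcn"}      # read as text
    BINARY_EXTS = {"cubin", "hsaco"}           # read as raw bytes

    def load_override_artifact(override_dir: str, stage_ext: str):
        """Return the override artifact for a stage, or None if absent."""
        path = Path(override_dir) / f"kernel.{stage_ext}"  # assumed layout
        if not path.exists():
            return None  # stage not overridden; compiler output is used
        if stage_ext in BINARY_EXTS:
            return path.read_bytes()
        if stage_ext in TEXT_EXTS:
            return path.read_text()
        raise ValueError(f"unsupported override extension: {stage_ext}")

    # Example usage (assumes TRITON_OVERRIDE_DIR points at a copied dump dir):
    #   amdgcn_src = load_override_artifact(os.environ["TRITON_OVERRIDE_DIR"], "amdgcn")
    #   hsaco_bin  = load_override_artifact(os.environ["TRITON_OVERRIDE_DIR"], "hsaco")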